diff --git a/ci/minimum_versions.py b/ci/minimum_versions.py index 08808d002d9..21123bffcd6 100644 --- a/ci/minimum_versions.py +++ b/ci/minimum_versions.py @@ -30,6 +30,7 @@ "coveralls", "pip", "pytest", + "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-mypy-plugins", diff --git a/ci/requirements/all-but-dask.yml b/ci/requirements/all-but-dask.yml index 5f5db4a0f18..65780d91949 100644 --- a/ci/requirements/all-but-dask.yml +++ b/ci/requirements/all-but-dask.yml @@ -28,6 +28,7 @@ dependencies: - pip - pydap - pytest + - pytest-asyncio - pytest-cov - pytest-env - pytest-mypy-plugins diff --git a/ci/requirements/all-but-numba.yml b/ci/requirements/all-but-numba.yml index 712055a0ec2..23c38cc8267 100644 --- a/ci/requirements/all-but-numba.yml +++ b/ci/requirements/all-but-numba.yml @@ -41,6 +41,7 @@ dependencies: - pyarrow # pandas raises a deprecation warning without this, breaking doctests - pydap - pytest + - pytest-asyncio - pytest-cov - pytest-env - pytest-mypy-plugins diff --git a/ci/requirements/bare-min-and-scipy.yml b/ci/requirements/bare-min-and-scipy.yml index bb25af67651..d4a61586d82 100644 --- a/ci/requirements/bare-min-and-scipy.yml +++ b/ci/requirements/bare-min-and-scipy.yml @@ -7,6 +7,7 @@ dependencies: - coveralls - pip - pytest + - pytest-asyncio - pytest-cov - pytest-env - pytest-mypy-plugins diff --git a/ci/requirements/bare-minimum.yml b/ci/requirements/bare-minimum.yml index fafc1aa034a..777ff09b3e6 100644 --- a/ci/requirements/bare-minimum.yml +++ b/ci/requirements/bare-minimum.yml @@ -7,6 +7,7 @@ dependencies: - coveralls - pip - pytest + - pytest-asyncio - pytest-cov - pytest-env - pytest-mypy-plugins diff --git a/ci/requirements/environment-3.14.yml b/ci/requirements/environment-3.14.yml index 06c4df82663..d4d47d85536 100644 --- a/ci/requirements/environment-3.14.yml +++ b/ci/requirements/environment-3.14.yml @@ -37,6 +37,7 @@ dependencies: - pyarrow # pandas raises a deprecation warning without this, breaking doctests - pydap - pytest + - pytest-asyncio - pytest-cov - pytest-env - pytest-mypy-plugins diff --git a/ci/requirements/environment-windows-3.14.yml b/ci/requirements/environment-windows-3.14.yml index dd48add6b73..e86d57beb95 100644 --- a/ci/requirements/environment-windows-3.14.yml +++ b/ci/requirements/environment-windows-3.14.yml @@ -32,6 +32,7 @@ dependencies: - pyarrow # importing dask.dataframe raises an ImportError without this - pydap - pytest + - pytest-asyncio - pytest-cov - pytest-env - pytest-mypy-plugins diff --git a/ci/requirements/environment-windows.yml b/ci/requirements/environment-windows.yml index 3213ef687d3..7c0d4dd9231 100644 --- a/ci/requirements/environment-windows.yml +++ b/ci/requirements/environment-windows.yml @@ -32,6 +32,7 @@ dependencies: - pyarrow # importing dask.dataframe raises an ImportError without this - pydap - pytest + - pytest-asyncio - pytest-cov - pytest-env - pytest-mypy-plugins diff --git a/ci/requirements/environment.yml b/ci/requirements/environment.yml index cc33d8b4681..84441625e4c 100644 --- a/ci/requirements/environment.yml +++ b/ci/requirements/environment.yml @@ -39,6 +39,7 @@ dependencies: - pydap - pydap-server - pytest + - pytest-asyncio - pytest-cov - pytest-env - pytest-mypy-plugins diff --git a/ci/requirements/min-all-deps.yml b/ci/requirements/min-all-deps.yml index 9183433e801..add738630f1 100644 --- a/ci/requirements/min-all-deps.yml +++ b/ci/requirements/min-all-deps.yml @@ -40,6 +40,7 @@ dependencies: - pip - pydap=3.5.0 - pytest + - pytest-asyncio - pytest-cov - pytest-env - pytest-mypy-plugins diff --git a/doc/api-hidden.rst b/doc/api-hidden.rst index 5b9fa70d6b7..9f56fca1472 100644 --- a/doc/api-hidden.rst +++ b/doc/api-hidden.rst @@ -228,6 +228,7 @@ Variable.isnull Variable.item Variable.load + Variable.load_async Variable.max Variable.mean Variable.median diff --git a/doc/api.rst b/doc/api.rst index fc862c21e4c..f7bb382e922 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -11,30 +11,3 @@ and examples, refer to the relevant chapters in the main part of the documentation. See also: :ref:`public-api` and :ref:`api-stability`. - -.. toctree:: - :maxdepth: 1 - - api/top-level - api/dataset - api/dataarray - api/datatree - api/coordinates - api/indexes - api/ufuncs - api/io - api/encoding - api/plotting - api/groupby - api/rolling - api/coarsen - api/rolling-exp - api/weighted - api/resample - api/accessors - api/tutorial - api/testing - api/backends - api/exceptions - api/advanced - api/deprecated diff --git a/doc/internals/how-to-add-new-backend.rst b/doc/internals/how-to-add-new-backend.rst index d3b5c3a9267..b5dfe3b5f8e 100644 --- a/doc/internals/how-to-add-new-backend.rst +++ b/doc/internals/how-to-add-new-backend.rst @@ -331,10 +331,12 @@ information on plugins. How to support lazy loading +++++++++++++++++++++++++++ -If you want to make your backend effective with big datasets, then you should -support lazy loading. -Basically, you shall replace the :py:class:`numpy.ndarray` inside the -variables with a custom class that supports lazy loading indexing. +If you want to make your backend effective with big datasets, then you should take advantage of xarray's +support for lazy loading and indexing. + +Basically, when your backend constructs the ``Variable`` objects, +you need to replace the :py:class:`numpy.ndarray` inside the +variables with a custom :py:class:`~xarray.backends.BackendArray` subclass that supports lazy loading and indexing. See the example below: .. code-block:: python @@ -345,25 +347,27 @@ See the example below: Where: -- :py:class:`~xarray.core.indexing.LazilyIndexedArray` is a class - provided by Xarray that manages the lazy loading. -- ``MyBackendArray`` shall be implemented by the backend and shall inherit +- :py:class:`~xarray.core.indexing.LazilyIndexedArray` is a wrapper class + provided by Xarray that manages the lazy loading and indexing. +- ``MyBackendArray`` should be implemented by the backend and must inherit from :py:class:`~xarray.backends.BackendArray`. BackendArray subclassing ^^^^^^^^^^^^^^^^^^^^^^^^ -The BackendArray subclass shall implement the following method and attributes: +The BackendArray subclass must implement the following method and attributes: -- the ``__getitem__`` method that takes in input an index and returns a - `NumPy `__ array -- the ``shape`` attribute +- the ``__getitem__`` method that takes an index as an input and returns a + `NumPy `__ array, +- the ``shape`` attribute, - the ``dtype`` attribute. -Xarray supports different type of :doc:`/user-guide/indexing`, that can be -grouped in three types of indexes +It may also optionally implement an additional ``async_getitem`` method. + +Xarray supports different types of :doc:`/user-guide/indexing`, that can be +grouped in three types of indexes: :py:class:`~xarray.core.indexing.BasicIndexer`, -:py:class:`~xarray.core.indexing.OuterIndexer` and +:py:class:`~xarray.core.indexing.OuterIndexer`, and :py:class:`~xarray.core.indexing.VectorizedIndexer`. This implies that the implementation of the method ``__getitem__`` can be tricky. In order to simplify this task, Xarray provides a helper function, @@ -419,8 +423,22 @@ input the ``key``, the array ``shape`` and the following parameters: For more details see :py:class:`~xarray.core.indexing.IndexingSupport` and :ref:`RST indexing`. +Async support +^^^^^^^^^^^^^ + +Backends can also optionally support loading data asynchronously via xarray's asynchronous loading methods +(e.g. ``~xarray.Dataset.load_async``). +To support async loading the ``BackendArray`` subclass must additionally implement the ``BackendArray.async_getitem`` method. + +Note that implementing this method is only necessary if you want to be able to load data from different xarray objects concurrently. +Even without this method your ``BackendArray`` implementation is still free to concurrently load chunks of data for a single ``Variable`` itself, +so long as it does so behind the synchronous ``__getitem__`` interface. + +Dask support +^^^^^^^^^^^^ + In order to support `Dask Distributed `__ and -:py:mod:`multiprocessing`, ``BackendArray`` subclass should be serializable +:py:mod:`multiprocessing`, the ``BackendArray`` subclass should be serializable either with :ref:`io.pickle` or `cloudpickle `__. That implies that all the reference to open files should be dropped. For diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 06a3c2cb22d..d5d13e6a32b 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -13,6 +13,8 @@ v2025.07.2 (unreleased) New Features ~~~~~~~~~~~~ +- Added new asynchronous loading methods :py:meth:`~xarray.Dataset.load_async`, :py:meth:`~xarray.DataArray.load_async`, :py:meth:`~xarray.Variable.load_async`. + (:issue:`10326`, :pull:`10327`) By `Tom Nicholas `_. Breaking changes ~~~~~~~~~~~~~~~~ @@ -36,12 +38,10 @@ Deprecations Bug fixes ~~~~~~~~~ - - Fix Pydap Datatree backend testing. Testing now compares elements of (unordered) two sets (before, lists) (:pull:`10525`). By `Miguel Jimenez-Urias `_. - Fix ``KeyError`` when passing a ``dim`` argument different from the default to ``convert_calendar`` (:pull:`10544`). By `Eric Jansen `_. - - Fix transpose of boolean arrays read from disk. (:issue:`10536`) By `Deepak Cherian `_. - Fix detection of the ``h5netcdf`` backend. Xarray now selects ``h5netcdf`` if the default ``netCDF4`` engine is not available (:issue:`10401`, :pull:`10557`). diff --git a/pyproject.toml b/pyproject.toml index bc899596b4c..7426ff05518 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -80,6 +80,7 @@ dev = [ "pytest-mypy-plugins", "pytest-timeout", "pytest-xdist", + "pytest-asyncio", "ruff>=0.8.0", "sphinx", "sphinx_autosummary_accessors", diff --git a/xarray/backends/common.py b/xarray/backends/common.py index 8b56c8a2bf9..a9f21e9a3bd 100644 --- a/xarray/backends/common.py +++ b/xarray/backends/common.py @@ -270,10 +270,17 @@ def robust_getitem(array, key, catch=Exception, max_retries=6, initial_delay=500 class BackendArray(NdimSizeLenMixin, indexing.ExplicitlyIndexed): __slots__ = () + async def async_getitem(key: indexing.ExplicitIndexer) -> np.typing.ArrayLike: + raise NotImplementedError("Backend does not not support asynchronous loading") + def get_duck_array(self, dtype: np.typing.DTypeLike = None): key = indexing.BasicIndexer((slice(None),) * self.ndim) return self[key] # type: ignore[index] + async def async_get_duck_array(self, dtype: np.typing.DTypeLike = None): + key = indexing.BasicIndexer((slice(None),) * self.ndim) + return await self.async_getitem(key) # type: ignore[index] + class AbstractDataStore: __slots__ = () diff --git a/xarray/backends/zarr.py b/xarray/backends/zarr.py index 8b26a6b40ec..af5e395cd72 100644 --- a/xarray/backends/zarr.py +++ b/xarray/backends/zarr.py @@ -180,12 +180,23 @@ def encode_zarr_attr_value(value): return encoded +def has_zarr_async_index() -> bool: + try: + import zarr + + return hasattr(zarr.AsyncArray, "oindex") + except (ImportError, AttributeError): + return False + + class ZarrArrayWrapper(BackendArray): __slots__ = ("_array", "dtype", "shape") def __init__(self, zarr_array): # some callers attempt to evaluate an array if an `array` property exists on the object. # we prefix with _ to avoid this inference. + + # TODO type hint this? self._array = zarr_array self.shape = self._array.shape @@ -213,6 +224,33 @@ def _vindex(self, key): def _getitem(self, key): return self._array[key] + async def _async_getitem(self, key): + if not _zarr_v3(): + raise NotImplementedError( + "For lazy basic async indexing with zarr, zarr-python=>v3.0.0 is required" + ) + + async_array = self._array._async_array + return await async_array.getitem(key) + + async def _async_oindex(self, key): + if not has_zarr_async_index(): + raise NotImplementedError( + "For lazy orthogonal async indexing with zarr, zarr-python=>v3.1.2 is required" + ) + + async_array = self._array._async_array + return await async_array.oindex.getitem(key) + + async def _async_vindex(self, key): + if not has_zarr_async_index(): + raise NotImplementedError( + "For lazy vectorized async indexing with zarr, zarr-python=>v3.1.2 is required" + ) + + async_array = self._array._async_array + return await async_array.vindex.getitem(key) + def __getitem__(self, key): array = self._array if isinstance(key, indexing.BasicIndexer): @@ -228,6 +266,18 @@ def __getitem__(self, key): # if self.ndim == 0: # could possibly have a work-around for 0d data here + async def async_getitem(self, key): + array = self._array + if isinstance(key, indexing.BasicIndexer): + method = self._async_getitem + elif isinstance(key, indexing.VectorizedIndexer): + method = self._async_vindex + elif isinstance(key, indexing.OuterIndexer): + method = self._async_oindex + return await indexing.async_explicit_indexing_adapter( + key, array.shape, indexing.IndexingSupport.VECTORIZED, method + ) + def _determine_zarr_chunks(enc_chunks, var_chunks, ndim, name): """ diff --git a/xarray/coding/common.py b/xarray/coding/common.py index 0e8d7e1955e..79e5e7502b3 100644 --- a/xarray/coding/common.py +++ b/xarray/coding/common.py @@ -79,6 +79,9 @@ def __getitem__(self, key): def get_duck_array(self): return self.func(self.array.get_duck_array()) + async def async_get_duck_array(self): + return self.func(await self.array.async_get_duck_array()) + def __repr__(self) -> str: return f"{type(self).__name__}({self.array!r}, func={self.func!r}, dtype={self.dtype!r})" diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index 73b0eb19a64..47ffd4e1520 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -1160,6 +1160,14 @@ def load(self, **kwargs) -> Self: self._coords = new._coords return self + async def load_async(self, **kwargs) -> Self: + temp_ds = self._to_temp_dataset() + ds = await temp_ds.load_async(**kwargs) + new = self._from_temp_dataset(ds) + self._variable = new._variable + self._coords = new._coords + return self + def compute(self, **kwargs) -> Self: """Manually trigger loading of this array's data from disk or a remote source into memory and return a new array. diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index acc7d1f17f6..b1cfafe8f34 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -1,5 +1,6 @@ from __future__ import annotations +import asyncio import copy import datetime import math @@ -538,24 +539,50 @@ def load(self, **kwargs) -> Self: dask.compute """ # access .data to coerce everything to numpy or dask arrays - lazy_data = { + chunked_data = { k: v._data for k, v in self.variables.items() if is_chunked_array(v._data) } - if lazy_data: - chunkmanager = get_chunked_array_type(*lazy_data.values()) + if chunked_data: + chunkmanager = get_chunked_array_type(*chunked_data.values()) # evaluate all the chunked arrays simultaneously evaluated_data: tuple[np.ndarray[Any, Any], ...] = chunkmanager.compute( - *lazy_data.values(), **kwargs + *chunked_data.values(), **kwargs ) - for k, data in zip(lazy_data, evaluated_data, strict=False): + for k, data in zip(chunked_data, evaluated_data, strict=False): self.variables[k].data = data # load everything else sequentially - for k, v in self.variables.items(): - if k not in lazy_data: - v.load() + [v.load() for k, v in self.variables.items() if k not in chunked_data] + + return self + + async def load_async(self, **kwargs) -> Self: + # TODO refactor this to pull out the common chunked_data codepath + + # this blocks on chunked arrays but not on lazily indexed arrays + + # access .data to coerce everything to numpy or dask arrays + chunked_data = { + k: v._data for k, v in self.variables.items() if is_chunked_array(v._data) + } + if chunked_data: + chunkmanager = get_chunked_array_type(*chunked_data.values()) + + # evaluate all the chunked arrays simultaneously + evaluated_data: tuple[np.ndarray[Any, Any], ...] = chunkmanager.compute( + *chunked_data.values(), **kwargs + ) + + for k, data in zip(chunked_data, evaluated_data, strict=False): + self.variables[k].data = data + + # load everything else concurrently + coros = [ + v.load_async() for k, v in self.variables.items() if k not in chunked_data + ] + await asyncio.gather(*coros) return self diff --git a/xarray/core/indexing.py b/xarray/core/indexing.py index c98175578f8..4a41fa4f269 100644 --- a/xarray/core/indexing.py +++ b/xarray/core/indexing.py @@ -516,13 +516,31 @@ def get_duck_array(self): return self.array -class ExplicitlyIndexedNDArrayMixin(NDArrayMixin, ExplicitlyIndexed): - __slots__ = () +class IndexingAdapter: + """Marker class for indexing adapters. + + These classes translate between Xarray's indexing semantics and the underlying array's + indexing semantics. + """ def get_duck_array(self): key = BasicIndexer((slice(None),) * self.ndim) return self[key] + async def async_get_duck_array(self): + """These classes are applied to in-memory arrays, so specific async support isn't needed.""" + return self.get_duck_array() + + +class ExplicitlyIndexedNDArrayMixin(NDArrayMixin, ExplicitlyIndexed): + __slots__ = () + + def get_duck_array(self): + raise NotImplementedError + + async def async_get_duck_array(self): + raise NotImplementedError + def _oindex_get(self, indexer: OuterIndexer): raise NotImplementedError( f"{self.__class__.__name__}._oindex_get method should be overridden" @@ -646,19 +664,25 @@ def shape(self) -> _Shape: return self._shape def get_duck_array(self): - if isinstance(self.array, ExplicitlyIndexedNDArrayMixin): - array = apply_indexer(self.array, self.key) - else: - # If the array is not an ExplicitlyIndexedNDArrayMixin, - # it may wrap a BackendArray so use its __getitem__ + from xarray.backends.common import BackendArray + + if isinstance(self.array, BackendArray): array = self.array[self.key] + else: + array = apply_indexer(self.array, self.key) + if isinstance(array, ExplicitlyIndexed): + array = array.get_duck_array() + return _wrap_numpy_scalars(array) - # self.array[self.key] is now a numpy array when - # self.array is a BackendArray subclass - # and self.key is BasicIndexer((slice(None, None, None),)) - # so we need the explicit check for ExplicitlyIndexed - if isinstance(array, ExplicitlyIndexed): - array = array.get_duck_array() + async def async_get_duck_array(self): + from xarray.backends.common import BackendArray + + if isinstance(self.array, BackendArray): + array = await self.array.async_getitem(self.key) + else: + array = apply_indexer(self.array, self.key) + if isinstance(array, ExplicitlyIndexed): + array = await array.async_get_duck_array() return _wrap_numpy_scalars(array) def transpose(self, order): @@ -722,18 +746,25 @@ def shape(self) -> _Shape: return np.broadcast(*self.key.tuple).shape def get_duck_array(self): - if isinstance(self.array, ExplicitlyIndexedNDArrayMixin): + from xarray.backends.common import BackendArray + + if isinstance(self.array, BackendArray): + array = self.array[self.key] + else: array = apply_indexer(self.array, self.key) + if isinstance(array, ExplicitlyIndexed): + array = array.get_duck_array() + return _wrap_numpy_scalars(array) + + async def async_get_duck_array(self): + from xarray.backends.common import BackendArray + + if isinstance(self.array, BackendArray): + array = await self.array.async_getitem(self.key) else: - # If the array is not an ExplicitlyIndexedNDArrayMixin, - # it may wrap a BackendArray so use its __getitem__ - array = self.array[self.key] - # self.array[self.key] is now a numpy array when - # self.array is a BackendArray subclass - # and self.key is BasicIndexer((slice(None, None, None),)) - # so we need the explicit check for ExplicitlyIndexed - if isinstance(array, ExplicitlyIndexed): - array = array.get_duck_array() + array = apply_indexer(self.array, self.key) + if isinstance(array, ExplicitlyIndexed): + array = await array.async_get_duck_array() return _wrap_numpy_scalars(array) def _updated_key(self, new_key: ExplicitIndexer): @@ -798,6 +829,9 @@ def _ensure_copied(self): def get_duck_array(self): return self.array.get_duck_array() + async def async_get_duck_array(self): + return await self.array.async_get_duck_array() + def _oindex_get(self, indexer: OuterIndexer): return type(self)(_wrap_numpy_scalars(self.array.oindex[indexer])) @@ -838,12 +872,17 @@ class MemoryCachedArray(ExplicitlyIndexedNDArrayMixin): def __init__(self, array): self.array = _wrap_numpy_scalars(as_indexable(array)) - def _ensure_cached(self): - self.array = as_indexable(self.array.get_duck_array()) - def get_duck_array(self): - self._ensure_cached() - return self.array.get_duck_array() + duck_array = self.array.get_duck_array() + # ensure the array object is cached in-memory + self.array = as_indexable(duck_array) + return duck_array + + async def async_get_duck_array(self): + duck_array = await self.array.async_get_duck_array() + # ensure the array object is cached in-memory + self.array = as_indexable(duck_array) + return duck_array def _oindex_get(self, indexer: OuterIndexer): return type(self)(_wrap_numpy_scalars(self.array.oindex[indexer])) @@ -1028,6 +1067,21 @@ def explicit_indexing_adapter( return result +async def async_explicit_indexing_adapter( + key: ExplicitIndexer, + shape: _Shape, + indexing_support: IndexingSupport, + raw_indexing_method: Callable[..., Any], +) -> Any: + raw_key, numpy_indices = decompose_indexer(key, shape, indexing_support) + result = await raw_indexing_method(raw_key.tuple) + if numpy_indices.tuple: + # index the loaded duck array + indexable = as_indexable(result) + result = apply_indexer(indexable, numpy_indices) + return result + + def apply_indexer(indexable, indexer: ExplicitIndexer): """Apply an indexer to an indexable object.""" if isinstance(indexer, VectorizedIndexer): @@ -1527,7 +1581,7 @@ def is_fancy_indexer(indexer: Any) -> bool: return True -class NumpyIndexingAdapter(ExplicitlyIndexedNDArrayMixin): +class NumpyIndexingAdapter(IndexingAdapter, ExplicitlyIndexedNDArrayMixin): """Wrap a NumPy array to use explicit indexing.""" __slots__ = ("array",) @@ -1606,7 +1660,7 @@ def __init__(self, array): self.array = array -class ArrayApiIndexingAdapter(ExplicitlyIndexedNDArrayMixin): +class ArrayApiIndexingAdapter(IndexingAdapter, ExplicitlyIndexedNDArrayMixin): """Wrap an array API array to use explicit indexing.""" __slots__ = ("array",) @@ -1671,7 +1725,7 @@ def _assert_not_chunked_indexer(idxr: tuple[Any, ...]) -> None: ) -class DaskIndexingAdapter(ExplicitlyIndexedNDArrayMixin): +class DaskIndexingAdapter(IndexingAdapter, ExplicitlyIndexedNDArrayMixin): """Wrap a dask array to support explicit indexing.""" __slots__ = ("array",) @@ -1747,7 +1801,7 @@ def transpose(self, order): return self.array.transpose(order) -class PandasIndexingAdapter(ExplicitlyIndexedNDArrayMixin): +class PandasIndexingAdapter(IndexingAdapter, ExplicitlyIndexedNDArrayMixin): """Wrap a pandas.Index to preserve dtypes and handle explicit indexing.""" __slots__ = ("_dtype", "array") @@ -2004,7 +2058,9 @@ def copy(self, deep: bool = True) -> Self: return type(self)(array, self._dtype, self.level) -class CoordinateTransformIndexingAdapter(ExplicitlyIndexedNDArrayMixin): +class CoordinateTransformIndexingAdapter( + IndexingAdapter, ExplicitlyIndexedNDArrayMixin +): """Wrap a CoordinateTransform as a lazy coordinate array. Supports explicit indexing (both outer and vectorized). diff --git a/xarray/core/variable.py b/xarray/core/variable.py index 06d7218fe7c..145836da743 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -50,6 +50,7 @@ from xarray.namedarray.core import NamedArray, _raise_if_any_duplicate_dimensions from xarray.namedarray.parallelcompat import get_chunked_array_type from xarray.namedarray.pycompat import ( + async_to_duck_array, integer_types, is_0d_dask_array, is_chunked_array, @@ -995,6 +996,10 @@ def load(self, **kwargs): self._data = to_duck_array(self._data, **kwargs) return self + async def load_async(self, **kwargs): + self._data = await async_to_duck_array(self._data, **kwargs) + return self + def compute(self, **kwargs): """Manually trigger loading of this variable's data from disk or a remote source into memory and return a new variable. The original is diff --git a/xarray/namedarray/pycompat.py b/xarray/namedarray/pycompat.py index 68b6a7853bf..0fe5cfdf3b5 100644 --- a/xarray/namedarray/pycompat.py +++ b/xarray/namedarray/pycompat.py @@ -145,3 +145,21 @@ def to_duck_array(data: Any, **kwargs: dict[str, Any]) -> duckarray[_ShapeType, return data else: return np.asarray(data) # type: ignore[return-value] + + +async def async_to_duck_array( + data: Any, **kwargs: dict[str, Any] +) -> duckarray[_ShapeType, _DType]: + from xarray.core.indexing import ( + ExplicitlyIndexed, + ImplicitToExplicitIndexingAdapter, + IndexingAdapter, + ) + + if isinstance(data, IndexingAdapter): + # These wrap in-memory arrays, and async isn't needed + return data.get_duck_array() + elif isinstance(data, ExplicitlyIndexed | ImplicitToExplicitIndexingAdapter): + return await data.async_get_duck_array() # type: ignore[no-untyped-call, no-any-return] + else: + return to_duck_array(data, **kwargs) diff --git a/xarray/tests/__init__.py b/xarray/tests/__init__.py index 787c01eaf62..3b4e49c64d8 100644 --- a/xarray/tests/__init__.py +++ b/xarray/tests/__init__.py @@ -132,6 +132,7 @@ def _importorskip( has_zarr, requires_zarr = _importorskip("zarr") has_zarr_v3, requires_zarr_v3 = _importorskip("zarr", "3.0.0") has_zarr_v3_dtypes, requires_zarr_v3_dtypes = _importorskip("zarr", "3.1.0") +has_zarr_v3_async_oindex, requires_zarr_v3_async_oindex = _importorskip("zarr", "3.1.2") if has_zarr_v3: import zarr @@ -140,10 +141,15 @@ def _importorskip( # installing from git main is giving me a lower version than the # most recently released zarr has_zarr_v3_dtypes = hasattr(zarr.core, "dtype") + has_zarr_v3_async_oindex = hasattr(zarr.AsyncArray, "oindex") requires_zarr_v3_dtypes = pytest.mark.skipif( not has_zarr_v3_dtypes, reason="requires zarr>3.1.0" ) + requires_zarr_v3_async_oindex = pytest.mark.skipif( + not has_zarr_v3_async_oindex, reason="requires zarr>3.1.1" + ) + has_fsspec, requires_fsspec = _importorskip("fsspec") has_iris, requires_iris = _importorskip("iris") diff --git a/xarray/tests/test_async.py b/xarray/tests/test_async.py new file mode 100644 index 00000000000..0f0c97a0ec8 --- /dev/null +++ b/xarray/tests/test_async.py @@ -0,0 +1,235 @@ +import asyncio +from importlib import import_module +from typing import Any, Literal +from unittest.mock import patch + +import pytest + +import xarray as xr +import xarray.testing as xrt +from xarray.tests import ( + has_zarr, + has_zarr_v3, + has_zarr_v3_async_oindex, + requires_zarr, + requires_zarr_v3, +) +from xarray.tests.test_backends import ZARR_FORMATS +from xarray.tests.test_dataset import create_test_data + +if has_zarr: + import zarr +else: + zarr = None # type: ignore[assignment] + + +@pytest.fixture(scope="module", params=ZARR_FORMATS) +def store(request) -> "zarr.storage.MemoryStore": + memorystore = zarr.storage.MemoryStore({}) + + ds = create_test_data() + print(ds) + ds.to_zarr(memorystore, zarr_format=request.param, consolidated=False) # type: ignore[call-overload] + + return memorystore + + +def get_xr_obj( + store: "zarr.abc.store.Store", cls_name: Literal["Variable", "DataArray", "Dataset"] +): + ds = xr.open_zarr(store, consolidated=False, chunks=None) + + match cls_name: + case "Variable": + return ds["var1"].variable + case "DataArray": + return ds["var1"] + case "Dataset": + return ds + + +def _resolve_class_from_string(class_path: str) -> type[Any]: + """Resolve a string class path like 'zarr.AsyncArray' to the actual class.""" + module_path, class_name = class_path.rsplit(".", 1) + module = import_module(module_path) + return getattr(module, class_name) + + +@pytest.mark.asyncio +class TestAsyncLoad: + @requires_zarr_v3 + async def test_concurrent_load_multiple_variables(self, store) -> None: + target_class = zarr.AsyncArray + method_name = "getitem" + original_method = getattr(target_class, method_name) + + # the indexed coordinate variables is not lazy, so the create_test_dataset has 4 lazy variables in total + N_LAZY_VARS = 4 + + with patch.object( + target_class, method_name, wraps=original_method, autospec=True + ) as mocked_meth: + # blocks upon loading the coordinate variables here + ds = xr.open_zarr(store, consolidated=False, chunks=None) + + # TODO we're not actually testing that these indexing methods are not blocking... + result_ds = await ds.load_async() + + mocked_meth.assert_called() + assert mocked_meth.call_count >= N_LAZY_VARS + mocked_meth.assert_awaited() + + xrt.assert_identical(result_ds, ds.load()) + + @requires_zarr_v3 + @pytest.mark.parametrize("cls_name", ["Variable", "DataArray", "Dataset"]) + async def test_concurrent_load_multiple_objects(self, store, cls_name) -> None: + N_OBJECTS = 5 + + target_class = zarr.AsyncArray + method_name = "getitem" + original_method = getattr(target_class, method_name) + + with patch.object( + target_class, method_name, wraps=original_method, autospec=True + ) as mocked_meth: + xr_obj = get_xr_obj(store, cls_name) + + # TODO we're not actually testing that these indexing methods are not blocking... + coros = [xr_obj.load_async() for _ in range(N_OBJECTS)] + results = await asyncio.gather(*coros) + + mocked_meth.assert_called() + assert mocked_meth.call_count >= N_OBJECTS + mocked_meth.assert_awaited() + + for result in results: + xrt.assert_identical(result, xr_obj.load()) + + @requires_zarr_v3 + @pytest.mark.parametrize("cls_name", ["Variable", "DataArray", "Dataset"]) + @pytest.mark.parametrize( + "indexer, method, target_zarr_class", + [ + ({}, "sel", "zarr.AsyncArray"), + ({}, "isel", "zarr.AsyncArray"), + ({"dim2": 1.0}, "sel", "zarr.AsyncArray"), + ({"dim2": 2}, "isel", "zarr.AsyncArray"), + ({"dim2": slice(1.0, 3.0)}, "sel", "zarr.AsyncArray"), + ({"dim2": slice(1, 3)}, "isel", "zarr.AsyncArray"), + ( + {"dim2": [1.0, 3.0]}, + "sel", + "zarr.core.indexing.AsyncOIndex", + ), + ({"dim2": [1, 3]}, "isel", "zarr.core.indexing.AsyncOIndex"), + ( + { + "dim1": xr.Variable(data=[2, 3], dims="points"), + "dim2": xr.Variable(data=[1.0, 2.0], dims="points"), + }, + "sel", + "zarr.core.indexing.AsyncVIndex", + ), + ( + { + "dim1": xr.Variable(data=[2, 3], dims="points"), + "dim2": xr.Variable(data=[1, 3], dims="points"), + }, + "isel", + "zarr.core.indexing.AsyncVIndex", + ), + ], + ids=[ + "no-indexing-sel", + "no-indexing-isel", + "basic-int-sel", + "basic-int-isel", + "basic-slice-sel", + "basic-slice-isel", + "outer-sel", + "outer-isel", + "vectorized-sel", + "vectorized-isel", + ], + ) + async def test_indexing( + self, + store, + cls_name, + method, + indexer, + target_zarr_class, + ) -> None: + if not has_zarr_v3_async_oindex and target_zarr_class in ( + "zarr.core.indexing.AsyncOIndex", + "zarr.core.indexing.AsyncVIndex", + ): + pytest.skip( + "current version of zarr does not support orthogonal or vectorized async indexing" + ) + + if cls_name == "Variable" and method == "sel": + pytest.skip("Variable doesn't have a .sel method") + + # Each type of indexing ends up calling a different zarr indexing method + # They all use a method named .getitem, but on a different internal zarr class + target_class = _resolve_class_from_string(target_zarr_class) + method_name = "getitem" + original_method = getattr(target_class, method_name) + + with patch.object( + target_class, method_name, wraps=original_method, autospec=True + ) as mocked_meth: + xr_obj = get_xr_obj(store, cls_name) + + # TODO we're not actually testing that these indexing methods are not blocking... + result = await getattr(xr_obj, method)(**indexer).load_async() + + mocked_meth.assert_called() + mocked_meth.assert_awaited() + assert mocked_meth.call_count > 0 + + expected = getattr(xr_obj, method)(**indexer).load() + xrt.assert_identical(result, expected) + + @requires_zarr + @pytest.mark.parametrize( + ("indexer", "expected_err_msg"), + [ + pytest.param( + {"dim2": 2}, + "basic async indexing", + marks=pytest.mark.skipif( + has_zarr_v3, + reason="current version of zarr has basic async indexing", + ), + ), # tests basic indexing + pytest.param( + {"dim2": [1, 3]}, + "orthogonal async indexing", + marks=pytest.mark.skipif( + has_zarr_v3_async_oindex, + reason="current version of zarr has async orthogonal indexing", + ), + ), # tests oindexing + pytest.param( + { + "dim1": xr.Variable(data=[2, 3], dims="points"), + "dim2": xr.Variable(data=[1, 3], dims="points"), + }, + "vectorized async indexing", + marks=pytest.mark.skipif( + has_zarr_v3_async_oindex, + reason="current version of zarr has async vectorized indexing", + ), + ), # tests vindexing + ], + ) + async def test_raise_on_older_zarr_version(self, store, indexer, expected_err_msg): + """Test that trying to use async load with insufficiently new version of zarr raises a clear error""" + + ds = xr.open_zarr(store, consolidated=False, chunks=None) + + with pytest.raises(NotImplementedError, match=expected_err_msg): + await ds.isel(**indexer).load_async() diff --git a/xarray/tests/test_indexing.py b/xarray/tests/test_indexing.py index 6dd75b58c6a..010987337a6 100644 --- a/xarray/tests/test_indexing.py +++ b/xarray/tests/test_indexing.py @@ -490,6 +490,25 @@ def test_sub_array(self) -> None: assert isinstance(child.array, indexing.NumpyIndexingAdapter) assert isinstance(wrapped.array, indexing.LazilyIndexedArray) + @pytest.mark.asyncio + async def test_async_wrapper(self) -> None: + original = indexing.LazilyIndexedArray(np.arange(10)) + wrapped = indexing.MemoryCachedArray(original) + await wrapped.async_get_duck_array() + assert_array_equal(wrapped, np.arange(10)) + assert isinstance(wrapped.array, indexing.NumpyIndexingAdapter) + + @pytest.mark.asyncio + async def test_async_sub_array(self) -> None: + original = indexing.LazilyIndexedArray(np.arange(10)) + wrapped = indexing.MemoryCachedArray(original) + child = wrapped[B[:5]] + assert isinstance(child, indexing.MemoryCachedArray) + await child.async_get_duck_array() + assert_array_equal(child, np.arange(5)) + assert isinstance(child.array, indexing.NumpyIndexingAdapter) + assert isinstance(wrapped.array, indexing.LazilyIndexedArray) + def test_setitem(self) -> None: original = np.arange(10) wrapped = indexing.MemoryCachedArray(original)