diff --git a/xarray/core/indexes.py b/xarray/core/indexes.py
index 9684f371e00..ec2f0f66115 100644
--- a/xarray/core/indexes.py
+++ b/xarray/core/indexes.py
@@ -632,9 +632,6 @@ def get_indexer_nd(index: pd.Index, labels, method=None, tolerance=None) -> np.n
     return indexer
 
 
-T_PandasIndex = TypeVar("T_PandasIndex", bound="PandasIndex")
-
-
 class PandasIndex(Index):
     """Wrap a pandas.Index as an xarray compatible index."""
 
@@ -929,9 +926,7 @@ def rename(self, name_dict, dims_dict):
         new_dim = dims_dict.get(self.dim, self.dim)
         return self._replace(index, dim=new_dim)
 
-    def _copy(
-        self: T_PandasIndex, deep: bool = True, memo: dict[int, Any] | None = None
-    ) -> T_PandasIndex:
+    def _copy(self, deep: bool = True, memo: dict[int, Any] | None = None) -> Self:
         if deep:
             # pandas is not using the memo
             index = self.index.copy(deep=True)
diff --git a/xarray/core/indexing.py b/xarray/core/indexing.py
index 8e4458fb88f..9df8917075c 100644
--- a/xarray/core/indexing.py
+++ b/xarray/core/indexing.py
@@ -26,6 +26,7 @@
     get_valid_numpy_dtype,
     is_duck_array,
     is_duck_dask_array,
+    is_full_slice,
     is_scalar,
     is_valid_numpy_dtype,
     to_0d_array,
@@ -1889,7 +1890,7 @@ def __getitem__(
     ) -> PandasIndexingAdapter | np.ndarray:
         return self._index_get(indexer, "__getitem__")
 
-    def transpose(self, order) -> pd.Index:
+    def transpose(self, order) -> Self | pd.Index:
         return self.array  # self.array should be always one-dimensional
 
     def _repr_inline_(self, max_width: int) -> str:
@@ -2005,6 +2006,103 @@ def copy(self, deep: bool = True) -> Self:
         return type(self)(array, self._dtype, self.level)
 
 
+class PandasIntervalIndexingAdapter(PandasIndexingAdapter):
+    """Wraps a pandas.IntervalIndex as a 2-dimensional coordinate array.
+
+    When the array is not transposed, left and right interval boundaries are on
+    the 2nd axis, i.e., shape is (N, 2).
+
+    """
+
+    __slots__ = ("_bounds_axis", "_dtype", "array")
+
+    array: pd.IntervalIndex
+    _dtype: np.dtype | pd.api.extensions.ExtensionDtype
+    _bounds_axis: int
+
+    def __init__(
+        self,
+        array: pd.IntervalIndex,
+        dtype: DTypeLike | pd.api.extensions.ExtensionDtype | None = None,
+        transpose: bool = False,
+    ):
+        super().__init__(array, dtype=dtype)
+
+        # axis holding the (left, right) boundaries: first if transposed, else last
+        if transpose:
+            self._bounds_axis = 0
+        else:
+            self._bounds_axis = -1
+
+    @property
+    def shape(self) -> _Shape:
+        """Array shape: ``(N, 2)``, or ``(2, N)`` when transposed."""
+        if self._bounds_axis == 0:
+            return (2, len(self.array))
+        else:
+            return (len(self.array), 2)
+
+    def __array__(
+        self,
+        dtype: np.typing.DTypeLike | None = None,
+        /,
+        *,
+        copy: bool | None = None,
+    ) -> np.ndarray:
+        """Stack left and right boundaries into a 2-d numpy array."""
+        dtype = self._get_numpy_dtype(dtype)
+
+        # note: the ``copy`` argument is not used here
+        return np.stack(
+            [self.array.left, self.array.right], axis=self._bounds_axis, dtype=dtype
+        )
+
+    def get_duck_array(self) -> np.ndarray:
+        return np.asarray(self)
+
+    def _index_get(
+        self, indexer: ExplicitIndexer, func_name: str
+    ) -> PandasIndexingAdapter | np.ndarray:
+        """Index the wrapped intervals, keeping them wrapped when possible."""
+        key: tuple | Any = indexer.tuple
+
+        if len(key) == 1:
+            # unpack key so it can index a pandas.Index object (pandas.Index
+            # objects don't like tuples)
+            (key,) = key
+        elif len(key) == 2 and is_full_slice(key[self._bounds_axis]):
+            # OK to index the pandas.IntervalIndex and keep it wrapped
+            # (drop the bounds axis key)
+            key = key[self._bounds_axis + 1]
+
+        # if length-2 or multidimensional key, convert the index to numpy array
+        # and index the latter
+        if (isinstance(key, tuple) and len(key) == 2) or getattr(key, "ndim", 0) > 1:
+            indexable = NumpyIndexingAdapter(np.asarray(self))
+            return getattr(indexable, func_name)(indexer)
+
+        # otherwise index the pandas IntervalIndex then re-wrap or convert the result
+        result = self.array[key]
+
+        if isinstance(result, pd.IntervalIndex):
+            # keep the current orientation (bounds axis) of the re-wrapped result
+            transpose = self._bounds_axis == 0
+            return type(self)(result, dtype=self.dtype, transpose=transpose)
+        elif isinstance(result, pd.Interval):
+            dtype = self._get_numpy_dtype()
+            return np.array([result.left, result.right], dtype=dtype)
+        else:
+            return self._convert_scalar(result)
+
+    def transpose(self, order: Iterable[int]) -> Self:
+        """Return a new adapter with its two axes permuted according to ``order``."""
+        # ``order`` is relative to the current layout: (1, 0) flips it, which
+        # restores the (N, 2) layout if the array is already transposed
+        swap_axes = tuple(order) == (1, 0)
+        transpose = (self._bounds_axis == 0) != swap_axes
+        return type(self)(self.array, dtype=self.dtype, transpose=transpose)
+
+
 class CoordinateTransformIndexingAdapter(ExplicitlyIndexedNDArrayMixin):
     """Wrap a CoordinateTransform as a lazy coordinate array.
 
diff --git a/xarray/indexes/__init__.py b/xarray/indexes/__init__.py
index 2cba69607f3..03e5c11b21d 100644
--- a/xarray/indexes/__init__.py
+++ b/xarray/indexes/__init__.py
@@ -10,10 +10,12 @@
     PandasIndex,
     PandasMultiIndex,
 )
+from xarray.indexes.cf_interval_index import CFIntervalIndex
 from xarray.indexes.nd_point_index import NDPointIndex
 from xarray.indexes.range_index import RangeIndex
 
 __all__ = [
+    "CFIntervalIndex",
     "CoordinateTransform",
     "CoordinateTransformIndex",
     "Index",
diff --git a/xarray/indexes/cf_interval_index.py b/xarray/indexes/cf_interval_index.py
new file mode 100644
index 00000000000..48c8e208b93
--- /dev/null
+++ b/xarray/indexes/cf_interval_index.py
@@ -0,0 +1,365 @@
+from __future__ import annotations
+
+from collections.abc import Hashable, Iterable, Mapping, Sequence
+from typing import TYPE_CHECKING, Any, cast
+
+import numpy as np
+import pandas as pd
+
+from xarray.core.indexes import Index, PandasIndex
+from xarray.core.indexing import IndexSelResult, PandasIntervalIndexingAdapter
+from xarray.core.utils import is_full_slice
+from xarray.core.variable import Variable
+
+if TYPE_CHECKING:
+    from xarray.core.types import Self
+
+
+def check_mid_in_interval(mid_index: pd.Index, bounds_index: pd.IntervalIndex) -> None:
+    """Check that each central value falls in the interval at the same position.
+
+    Raises a ``ValueError`` when looking up ``mid_index`` in ``bounds_index``
+    does not return the positions ``0 .. N-1`` in order.
+    """
+    actual_indexer = bounds_index.get_indexer(mid_index)
+    expected_indexer = np.arange(mid_index.size)
+    if not np.array_equal(actual_indexer, expected_indexer):
+        raise ValueError("not all central values are in their corresponding interval")
+
+
+class CFIntervalIndex(Index):
+    """Xarray index of CF-like 1-dimensional intervals.
+
+    This index is associated with two coordinate variables like in the Climate
+    and Forecast (CF) conventions:
+
+    - a 1-dimensional coordinate where each label represents an interval that is
+      materialized by a central value (commonly the average of its left and right
+      boundaries)
+
+    - a 2-dimensional coordinate that represents the left and right boundaries
+      of each interval. One of the two dimensions is shared with the
+      aforementioned coordinate and the other one has length 2
+
+    Interval boundaries are wrapped in a :py:class:`pandas.IntervalIndex` and
+    central values are wrapped in a separate :py:class:`pandas.Index`.
+
+    """
+
+    _mid_index: PandasIndex
+    _bounds_index: PandasIndex
+    _bounds_dim: str
+
+    def __init__(
+        self,
+        mid_index: PandasIndex,
+        bounds_index: PandasIndex,
+        bounds_dim: str | None = None,
+    ):
+        """Wrap the central-value and boundary indexes.
+
+        ``bounds_index`` must wrap a :py:class:`pandas.IntervalIndex` and use
+        the same dimension as ``mid_index``. ``bounds_dim`` names the other
+        dimension of the boundary coordinate (default: ``"bounds"``).
+        """
+        assert isinstance(bounds_index.index, pd.IntervalIndex)
+        assert mid_index.dim == bounds_index.dim
+
+        self._mid_index = mid_index
+        self._bounds_index = bounds_index
+
+        if bounds_dim is None:
+            bounds_dim = "bounds"
+        self._bounds_dim = bounds_dim
+
+    @classmethod
+    def from_variables(
+        cls,
+        variables: Mapping[Any, Variable],
+        *,
+        options: Mapping[str, Any],
+    ) -> Self:
+        """Build the index from a pair of coordinate variables.
+
+        One variable must be 1-dimensional (central values) and the other
+        2-dimensional (boundaries) with one dimension of size 2. ``options``
+        is currently unused.
+        """
+        if len(variables) == 2:
+            mid_var: Variable | None = None
+            bounds_var: Variable | None = None
+
+            for name, var in variables.items():
+                if var.ndim == 1:
+                    mid_name = name
+                    mid_var = var
+                elif var.ndim == 2:
+                    bounds_name = name
+                    bounds_var = var
+
+            if mid_var is None or bounds_var is None:
+                raise ValueError(
+                    "invalid coordinates given to CFIntervalIndex. When two coordinates are given, "
+                    "one must be 1-dimensional (central values) and the other must be "
+                    "2-dimensional (boundaries). Actual coordinate variables:\n"
+                    + "\n".join(str(var) for var in variables.values())
+                )
+
+            if mid_var.dims[0] == bounds_var.dims[0]:
+                dim, bounds_dim = bounds_var.dims
+            elif mid_var.dims[0] == bounds_var.dims[1]:
+                bounds_dim, dim = bounds_var.dims
+            else:
+                raise ValueError(
+                    "dimension names mismatch between "
+                    f"the central coordinate {mid_name!r} {mid_var.dims!r} and "
+                    f"the boundary coordinate {bounds_name!r} {bounds_var.dims!r} "
+                    "given to CFIntervalIndex"
+                )
+
+            if bounds_var.sizes[bounds_dim] != 2:
+                raise ValueError(
+                    "invalid shape for the boundary coordinate given to CFIntervalIndex "
+                    f"(expected dimension {bounds_dim!r} of size 2)"
+                )
+
+            pd_mid_index = pd.Index(mid_var.values, name=mid_name)
+            mid_index = PandasIndex(pd_mid_index, dim, coord_dtype=mid_var.dtype)
+
+            left, right = bounds_var.transpose(..., dim).values.tolist()
+            # TODO: make closed configurable
+            pd_bounds_index = pd.IntervalIndex.from_arrays(
+                left, right, name=bounds_name
+            )
+            bounds_index = PandasIndex(
+                pd_bounds_index, dim, coord_dtype=bounds_var.dtype
+            )
+
+            check_mid_in_interval(pd_mid_index, pd_bounds_index)
+
+        elif len(variables) == 1:
+            # TODO: allow setting the index from one variable? Perhaps in this fallback order:
+            # - check if the coordinate wraps a pd.IntervalIndex
+            # - look after the CF `bounds` attribute
+            # - guess bounds like cf_xarray's add_bounds
+            raise ValueError(
+                "Setting a CFIntervalIndex from one coordinate is not yet supported"
+            )
+        else:
+            raise ValueError("Too many coordinate variables given to CFIntervalIndex")
+
+        return cls(mid_index, bounds_index, bounds_dim=str(bounds_dim))
+
+    @classmethod
+    def concat(
+        cls,
+        indexes: Sequence[CFIntervalIndex],
+        dim: Hashable,
+        positions: Iterable[Iterable[int]] | None = None,
+    ) -> CFIntervalIndex:
+        """Concatenate the central-value and boundary indexes along ``dim``.
+
+        Raises a ``ValueError`` if the indexes do not all have the same
+        ``bounds_dim``.
+        """
+        new_mid_index = PandasIndex.concat(
+            [idx.mid_index for idx in indexes], dim, positions=positions
+        )
+        new_bounds_index = PandasIndex.concat(
+            [idx.bounds_index for idx in indexes], dim, positions=positions
+        )
+
+        if indexes:
+            bounds_dim = indexes[0].bounds_dim
+            # TODO: check whether this may actually happen or concat fails early during alignment
+            if any(idx._bounds_dim != bounds_dim for idx in indexes):
+                raise ValueError(
+                    f"Cannot concatenate along dimension {dim!r} indexes with different "
+                    "boundary coordinate or dimension names"
+                )
+        else:
+            bounds_dim = "bounds"
+
+        return cls(new_mid_index, new_bounds_index, str(bounds_dim))
+
+    @property
+    def mid_index(self) -> PandasIndex:
+        """Index wrapping the central values."""
+        return self._mid_index
+
+    @property
+    def bounds_index(self) -> PandasIndex:
+        """Index wrapping the interval boundaries."""
+        return self._bounds_index
+
+    @property
+    def dim(self) -> Hashable:
+        """Dimension of the central-value index (shared with the boundaries)."""
+        return self.mid_index.dim
+
+    @property
+    def bounds_dim(self) -> Hashable:
+        """Name of the size-2 dimension of the boundary coordinate."""
+        return self._bounds_dim
+
+    def create_variables(
+        self, variables: Mapping[Any, Variable] | None = None
+    ) -> dict[Any, Variable]:
+        """Create the central-value and boundary coordinate variables.
+
+        The boundary variable has dimensions ``(dim, bounds_dim)`` and reuses
+        the attributes and encoding of the matching entry in ``variables``.
+        """
+        new_variables = self.mid_index.create_variables(variables)
+
+        # boundary variable (we cannot just defer to self.bounds_index.create_variables())
+        bounds_pd_index = cast(pd.IntervalIndex, self.bounds_index.index)
+        bounds_varname = bounds_pd_index.name
+        attrs: Mapping[Hashable, Any] | None
+        encoding: Mapping[Hashable, Any] | None
+
+        if variables is not None and bounds_varname in variables:
+            var = variables[bounds_varname]
+            attrs = var.attrs
+            encoding = var.encoding
+        else:
+            attrs = None
+            encoding = None
+
+        # TODO: do we want to preserve the original dimension order for the boundary coordinate?
+        # (using CF-compliant order below)
+        data = PandasIntervalIndexingAdapter(
+            bounds_pd_index, dtype=self.bounds_index.coord_dtype
+        )
+        new_variables[bounds_varname] = Variable(
+            (self.dim, self.bounds_dim), data, attrs=attrs, encoding=encoding
+        )
+
+        return new_variables
+
+    def should_add_coord_to_array(
+        self,
+        name: Hashable,
+        var: Variable,
+        dims: set[Hashable],
+    ) -> bool:
+        """Return True if the shared dimension is among ``dims``."""
+        # add both the central and boundary coordinates if the dimension
+        # that they both share is present in the array dimensions
+        return self.dim in dims
+
+    def equals(self, other: Index) -> bool:
+        """Check that both wrapped indexes equal those of ``other``."""
+        if not isinstance(other, CFIntervalIndex):
+            return False
+
+        return self.mid_index.equals(other.mid_index) and self.bounds_index.equals(
+            other.bounds_index
+        )
+
+    def join(self, other: Self, how: str = "inner") -> Self:
+        """Join both wrapped indexes with those of ``other``.
+
+        Raises a ``ValueError`` if the joined central values are no longer in
+        their corresponding joined intervals.
+        """
+        joined_mid_index = self.mid_index.join(other.mid_index, how=how)
+        joined_bounds_index = self.bounds_index.join(other.bounds_index, how=how)
+
+        # the join returns an index wrapper: check the pandas index it wraps
+        assert isinstance(joined_bounds_index.index, pd.IntervalIndex)
+        check_mid_in_interval(
+            joined_mid_index.index, cast(pd.IntervalIndex, joined_bounds_index.index)
+        )
+
+        return type(self)(joined_mid_index, joined_bounds_index, self._bounds_dim)
+
+    def reindex_like(
+        self, other: Self, method=None, tolerance=None
+    ) -> dict[Hashable, Any]:
+        """Return the indexers that reindex this index like ``other``.
+
+        Raises a ``ValueError`` if central values and intervals would be
+        reindexed differently along the shared dimension.
+        """
+        mid_indexers = self.mid_index.reindex_like(
+            other.mid_index, method=method, tolerance=tolerance
+        )
+        bounds_indexers = self.bounds_index.reindex_like(
+            other.bounds_index, method=method, tolerance=tolerance
+        )
+
+        if not np.array_equal(mid_indexers[self.dim], bounds_indexers[self.dim]):
+            raise ValueError(
+                f"conflicting reindexing of central values and intervals along dimension {self.dim!r}"
+            )
+
+        return mid_indexers
+
+    def sel(self, labels: dict[Any, Any], **kwargs) -> IndexSelResult:
+        """Delegate label-based selection to the boundary (interval) index.
+
+        Raises a ``ValueError`` if ``labels`` uses the boundary coordinate name.
+        """
+        bounds_coord_name = self.bounds_index.index.name
+        if bounds_coord_name in labels:
+            raise ValueError(
+                "CFIntervalIndex doesn't support label-based selection "
+                f"using the boundary coordinate {bounds_coord_name!r}"
+            )
+
+        return self.bounds_index.sel(labels, **kwargs)
+
+    def isel(
+        self, indexers: Mapping[Any, int | slice | np.ndarray | Variable]
+    ) -> Self | None:
+        """Index both wrapped indexes by position.
+
+        Returns None (the index is dropped) if the bounds dimension is indexed
+        with anything but a full slice or if either wrapped index is dropped.
+        """
+        indexers = dict(indexers)
+
+        if self.bounds_dim in indexers:
+            if is_full_slice(indexers[self._bounds_dim]):
+                # prevent errors raised when calling isel on the underlying PandasIndex objects
+                indexers.pop(self.bounds_dim)
+                if self.dim not in indexers:
+                    indexers[self.dim] = slice(None)
+            else:
+                # drop the index when selecting on the bounds dimension
+                return None
+
+        new_mid_index = self.mid_index.isel(indexers)
+        new_bounds_index = self.bounds_index.isel(indexers)
+
+        if new_mid_index is None or new_bounds_index is None:
+            return None
+        else:
+            return type(self)(new_mid_index, new_bounds_index, str(self.bounds_dim))
+
+    def roll(self, shifts: Mapping[Any, int]) -> Self | None:
+        """Roll both wrapped indexes by ``shifts``."""
+        new_mid_index = self.mid_index.roll(shifts)
+        new_bounds_index = self.bounds_index.roll(shifts)
+
+        return type(self)(new_mid_index, new_bounds_index, self._bounds_dim)
+
+    def rename(
+        self,
+        name_dict: Mapping[Any, Hashable],
+        dims_dict: Mapping[Any, Hashable],
+    ) -> Self:
+        """Rename both wrapped indexes and the bounds dimension."""
+        new_mid_index = self.mid_index.rename(name_dict, dims_dict)
+        new_bounds_index = self.bounds_index.rename(name_dict, dims_dict)
+
+        bounds_dim = dims_dict.get(self.bounds_dim, self.bounds_dim)
+
+        return type(self)(new_mid_index, new_bounds_index, str(bounds_dim))
+
+    def __repr__(self) -> str:
+        text = "CFIntervalIndex\n"
+        text += f"- central values:\n{self.mid_index!r}\n"
+        text += f"- boundaries:\n{self.bounds_index!r}\n"
+        return text