Skip to content

(fix): no fill_value on reindex #10304

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 19 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 16 additions & 5 deletions properties/test_pandas_roundtrip.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,10 +134,21 @@ def test_roundtrip_pandas_dataframe_datetime(df) -> None:
xr.testing.assert_identical(dataset, roundtripped.to_xarray())


def test_roundtrip_1d_pandas_extension_array() -> None:
df = pd.DataFrame({"cat": pd.Categorical(["a", "b", "c"])})
arr = xr.Dataset.from_dataframe(df)["cat"]
@pytest.mark.parametrize(
"extension_array",
[
pd.Categorical(["a", "b", "c"]),
pd.array([1, 2, 3], dtype="int64"),
pd.array(["a", "b", "c"], dtype="string"),
pd.arrays.IntervalArray(
[pd.Interval(0, 1), pd.Interval(1, 5), pd.Interval(2, 6)]
),
],
)
def test_roundtrip_1d_pandas_extension_array(extension_array) -> None:
df = pd.DataFrame({"arr": extension_array})
arr = xr.Dataset.from_dataframe(df)["arr"]
roundtripped = arr.to_pandas()
assert (df["cat"] == roundtripped).all()
assert df["cat"].dtype == roundtripped.dtype
assert (df["arr"] == roundtripped).all()
assert df["arr"].dtype == roundtripped.dtype
xr.testing.assert_identical(arr, roundtripped.to_xarray())
5 changes: 4 additions & 1 deletion xarray/core/dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from typing import Any

import numpy as np
import pandas as pd
from pandas.api.types import is_extension_array_dtype

from xarray.compat import array_api_compat, npcompat
Expand Down Expand Up @@ -63,7 +64,9 @@ def maybe_promote(dtype: np.dtype) -> tuple[np.dtype, Any]:
# N.B. these casting rules should match pandas
dtype_: np.typing.DTypeLike
fill_value: Any
if HAS_STRING_DTYPE and np.issubdtype(dtype, np.dtypes.StringDType()):
if pd.api.types.is_extension_array_dtype(dtype):
return dtype, pd.NA
elif HAS_STRING_DTYPE and np.issubdtype(dtype, np.dtypes.StringDType()):
# for now, we always promote string dtypes to object for consistency with existing behavior
# TODO: refactor this once we have a better way to handle numpy vlen-string dtypes
dtype_ = object
Expand Down
30 changes: 23 additions & 7 deletions xarray/core/duck_array_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,7 @@

if xp == np:
# numpy currently doesn't have a astype:
return data.astype(dtype, **kwargs)

Check warning on line 254 in xarray/core/duck_array_ops.py

View workflow job for this annotation

GitHub Actions / macos-latest py3.10

invalid value encountered in cast

Check warning on line 254 in xarray/core/duck_array_ops.py

View workflow job for this annotation

GitHub Actions / macos-latest py3.10

invalid value encountered in cast

Check warning on line 254 in xarray/core/duck_array_ops.py

View workflow job for this annotation

GitHub Actions / windows-latest py3.10

invalid value encountered in cast

Check warning on line 254 in xarray/core/duck_array_ops.py

View workflow job for this annotation

GitHub Actions / windows-latest py3.10

invalid value encountered in cast
return xp.astype(data, dtype, **kwargs)


Expand All @@ -273,18 +273,35 @@
extension_array_types = [
x.dtype for x in scalars_or_arrays if is_extension_array_dtype(x)
]
non_nans = [x for x in scalars_or_arrays if not isna(x)]
if len(extension_array_types) == len(non_nans) and all(
non_nans_or_scalar = [
x for x in scalars_or_arrays if not (isna(x) or np.isscalar(x))
]
if len(extension_array_types) == len(non_nans_or_scalar) and all(
isinstance(x, type(extension_array_types[0])) for x in extension_array_types
):
return [
# Get the extension array class of the first element, guaranteed to be the same
# as the others thanks to the above check.
extension_array_class = type(
non_nans_or_scalar[0].array
if isinstance(non_nans_or_scalar[0], PandasExtensionArray)
else non_nans_or_scalar[0]
)
# Cast scalars/nans to extension array class
arrays_with_nan_to_sequence = [
x
if not isna(x)
else PandasExtensionArray(
type(non_nans[0].array)._from_sequence([x], dtype=non_nans[0].dtype)
if not (isna(x) or np.isscalar(x))
else extension_array_class._from_sequence(
[x], dtype=non_nans_or_scalar[0].dtype
)
for x in scalars_or_arrays
]
# Wrap the output if necessary
return [
PandasExtensionArray(x)
if not isinstance(x, PandasExtensionArray)
else x
for x in arrays_with_nan_to_sequence
]
raise ValueError(
f"Cannot cast values to shared type, found values: {scalars_or_arrays}"
)
Expand Down Expand Up @@ -407,7 +424,6 @@
condition = asarray(condition, dtype=dtype, xp=xp)
else:
condition = astype(condition, dtype=dtype, xp=xp)

return xp.where(condition, *as_shared_dtype([x, y], xp=xp))


Expand Down
10 changes: 10 additions & 0 deletions xarray/tests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,6 +363,16 @@ def create_test_data(
)
),
)
obj["var5"] = (
"dim1",
pd.array(
rs.integers(1, 10, size=dim_sizes[0]).tolist(), dtype=pd.Int64Dtype()
),
)
obj["var6"] = (
"dim1",
pd.array(list(string.ascii_lowercase[: dim_sizes[0]]), dtype="string"),
)
if dim_sizes == _DEFAULT_TEST_DIM_SIZES:
numbers_values = np.array([0, 1, 2, 0, 0, 1, 1, 2, 2, 3], dtype="int64")
else:
Expand Down
21 changes: 11 additions & 10 deletions xarray/tests/test_concat.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,19 +154,20 @@ def test_concat_missing_var() -> None:
assert_identical(actual, expected)


def test_concat_categorical() -> None:
def test_concat_extension_array() -> None:
data1 = create_test_data(use_extension_array=True)
data2 = create_test_data(use_extension_array=True)
concatenated = concat([data1, data2], dim="dim1")
assert (
concatenated["var4"]
== type(data2["var4"].variable.data)._concat_same_type(
[
data1["var4"].variable.data,
data2["var4"].variable.data,
]
)
).all()
for var in ["var4", "var5"]:
assert (
concatenated[var]
== type(data2[var].variable.data)._concat_same_type(
[
data1[var].variable.data,
data2[var].variable.data,
]
)
).all()


def test_concat_missing_multiple_consecutive_var() -> None:
Expand Down
18 changes: 18 additions & 0 deletions xarray/tests/test_dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -3075,6 +3075,24 @@ def test_propagate_attrs(self, func) -> None:
with set_options(keep_attrs=True):
assert func(da).attrs == da.attrs

def test_fillna_extension_array_int(self) -> None:
srs: pd.Series = pd.Series(
index=np.array([1, 2, 3]), data=pd.array([pd.NA, 1, 1])
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we aren't ready to support these yet. Can we instead autoconvert anything with pd.NA to float and convert to the numpy dtype otherwise?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@dcherian I think we should keep the two issues separate. It sounds like @richard-berg will have a PR that cleans this behavior up (i.e., convert float/int/string types from NumpyExtensionArray, maintain the rest or something close). For now I would change the tests so that his life is easier, but at the end of the day, I think this PR is needed regardless since there is nothing specifically referencing numeric/string types. You can reproduce the issue in #10301 with any kind of extension array type:

import pandas as pd, numpy as np, xarray as xr
cat = pd.Series(pd.Categorical(["a", "b", "c", "b"]))
cat.to_xarray().reindex(index=[1, -1])
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[3], line 1
----> 1 cat.to_xarray().reindex(index=[1, -1])

File ~/Projects/Theis/anndata/venv/lib/python3.12/site-packages/xarray/core/dataarray.py:2218, in DataArray.reindex(self, indexers, method, tolerance, copy, fill_value, **indexers_kwargs)
   2145 """Conform this object onto the indexes of another object, filling in
   2146 missing values with ``fill_value``. The default fill value is NaN.
   2147 
   (...)   2215 align
   2216 """
   2217 indexers = utils.either_dict_or_kwargs(indexers, indexers_kwargs, "reindex")
-> 2218 return alignment.reindex(
   2219     self,
   2220     indexers=indexers,
   2221     method=method,
   2222     tolerance=tolerance,
   2223     copy=copy,
   2224     fill_value=fill_value,
   2225 )

File ~/Projects/Theis/anndata/venv/lib/python3.12/site-packages/xarray/structure/alignment.py:984, in reindex(obj, indexers, method, tolerance, copy, fill_value, sparse, exclude_vars)
    965 # TODO: (benbovy - explicit indexes): uncomment?
    966 # --> from reindex docstrings: "any mismatched dimension is simply ignored"
    967 # bad_keys = [k for k in indexers if k not in obj._indexes and k not in obj.dims]
   (...)    971 #         "or unindexed dimension in the object to reindex"
    972 #     )
    974 aligner = Aligner(
    975     (obj,),
    976     indexes=indexers,
   (...)    982     exclude_vars=exclude_vars,
    983 )
--> 984 aligner.align()
    985 return aligner.results[0]

File ~/Projects/Theis/anndata/venv/lib/python3.12/site-packages/xarray/structure/alignment.py:567, in Aligner.align(self)
    565     self.results = self.objects
    566 else:
--> 567     self.reindex_all()

File ~/Projects/Theis/anndata/venv/lib/python3.12/site-packages/xarray/structure/alignment.py:543, in Aligner.reindex_all(self)
    542 def reindex_all(self) -> None:
--> 543     self.results = tuple(
    544         self._reindex_one(obj, matching_indexes)
    545         for obj, matching_indexes in zip(
    546             self.objects, self.objects_matching_indexes, strict=True
    547         )
    548     )

File ~/Projects/Theis/anndata/venv/lib/python3.12/site-packages/xarray/structure/alignment.py:544, in <genexpr>(.0)
    542 def reindex_all(self) -> None:
    543     self.results = tuple(
--> 544         self._reindex_one(obj, matching_indexes)
    545         for obj, matching_indexes in zip(
    546             self.objects, self.objects_matching_indexes, strict=True
    547         )
    548     )

File ~/Projects/Theis/anndata/venv/lib/python3.12/site-packages/xarray/structure/alignment.py:532, in Aligner._reindex_one(self, obj, matching_indexes)
    529 new_indexes, new_variables = self._get_indexes_and_vars(obj, matching_indexes)
    530 dim_pos_indexers = self._get_dim_pos_indexers(matching_indexes)
--> 532 return obj._reindex_callback(
    533     self,
    534     dim_pos_indexers,
    535     new_variables,
    536     new_indexes,
    537     self.fill_value,
    538     self.exclude_dims,
    539     self.exclude_vars,
    540 )

File ~/Projects/Theis/anndata/venv/lib/python3.12/site-packages/xarray/core/dataarray.py:1934, in DataArray._reindex_callback(self, aligner, dim_pos_indexers, variables, indexes, fill_value, exclude_dims, exclude_vars)
   1931         fill_value[_THIS_ARRAY] = value
   1933 ds = self._to_temp_dataset()
-> 1934 reindexed = ds._reindex_callback(
   1935     aligner,
   1936     dim_pos_indexers,
   1937     variables,
   1938     indexes,
   1939     fill_value,
   1940     exclude_dims,
   1941     exclude_vars,
   1942 )
   1944 da = self._from_temp_dataset(reindexed)
   1945 da.encoding = self.encoding

File ~/Projects/Theis/anndata/venv/lib/python3.12/site-packages/xarray/core/dataset.py:3277, in Dataset._reindex_callback(self, aligner, dim_pos_indexers, variables, indexes, fill_value, exclude_dims, exclude_vars)
   3271 else:
   3272     to_reindex = {
   3273         k: v
   3274         for k, v in self.variables.items()
   3275         if k not in variables and k not in exclude_vars
   3276     }
-> 3277     reindexed_vars = alignment.reindex_variables(
   3278         to_reindex,
   3279         dim_pos_indexers,
   3280         copy=aligner.copy,
   3281         fill_value=fill_value,
   3282         sparse=aligner.sparse,
   3283     )
   3284     new_variables.update(reindexed_vars)
   3285     new_coord_names = self._coord_names | set(new_indexes)

File ~/Projects/Theis/anndata/venv/lib/python3.12/site-packages/xarray/structure/alignment.py:83, in reindex_variables(variables, dim_pos_indexers, copy, fill_value, sparse)
     80 needs_masking = any(d in masked_dims for d in var.dims)
     82 if needs_masking:
---> 83     new_var = var._getitem_with_mask(indxr, fill_value=fill_value_)
     84 elif all(is_full_slice(k) for k in indxr):
     85     # no reindexing necessary
     86     # here we need to manually deal with copying data, since
     87     # we neither created a new ndarray nor used fancy indexing
     88     new_var = var.copy(deep=copy)

File ~/Projects/Theis/anndata/venv/lib/python3.12/site-packages/xarray/core/variable.py:798, in Variable._getitem_with_mask(self, key, fill_value)
    789 # TODO(shoyer): expose this method in public API somewhere (isel?) and
    790 # use it for reindex.
    791 # TODO(shoyer): add a sanity check that all other integers are
   (...)    794 # that is actually indexed rather than mapping it to the last value
    795 # along each axis.
    797 if fill_value is dtypes.NA:
--> 798     fill_value = dtypes.get_fill_value(self.dtype)
    800 dims, indexer, new_order = self._broadcast_indexes(key)
    802 if self.size:

File ~/Projects/Theis/anndata/venv/lib/python3.12/site-packages/xarray/core/dtypes.py:112, in get_fill_value(dtype)
    101 def get_fill_value(dtype):
    102     """Return an appropriate fill value for this dtype.
    103 
    104     Parameters
   (...)    110     fill_value : Missing value corresponding to this dtype.
    111     """
--> 112     _, fill_value = maybe_promote(dtype)
    113     return fill_value

File ~/Projects/Theis/anndata/venv/lib/python3.12/site-packages/xarray/core/dtypes.py:66, in maybe_promote(dtype)
     64 dtype_: np.typing.DTypeLike
     65 fill_value: Any
---> 66 if HAS_STRING_DTYPE and np.issubdtype(dtype, np.dtypes.StringDType()):
     67     # for now, we always promote string dtypes to object for consistency with existing behavior
     68     # TODO: refactor this once we have a better way to handle numpy vlen-string dtypes
     69     dtype_ = object
     70     fill_value = np.nan

File ~/Projects/Theis/anndata/venv/lib/python3.12/site-packages/numpy/_core/numerictypes.py:530, in issubdtype(arg1, arg2)
    473 r"""
    474 Returns True if first argument is a typecode lower/equal in type hierarchy.
    475 
   (...)    527 
    528 """
    529 if not issubclass_(arg1, generic):
--> 530     arg1 = dtype(arg1).type
    531 if not issubclass_(arg2, generic):
    532     arg2 = dtype(arg2).type

TypeError: Cannot interpret 'CategoricalDtype(categories=['a', 'b', 'c'], ordered=False, categories_dtype=object)' as a data type

If @richard-berg does not get permission to handle this sort of thing, I can make a PR and hopefully he can review.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added a small fix here to use arrow types as the "other" extension array type to ensure some level of better coverage, which indeed revealed a small bug. I will update the reindexing tests as well to use them.

Copy link
Contributor Author

@ilan-gold ilan-gold May 12, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The two fixes:

  1. https://github.com/pydata/xarray/pull/10304/files#diff-c803294f5216cbbdffa30f0b0c9f16a7e39855d4dd309c88d654bc317a78adc0R54-R64 + https://github.com/pydata/xarray/pull/10304/files#diff-5803161603b7d9c554e24313b463cc6f826d644128cb9e58755de4e2a0ac7467R179-R202 We didn't have a reshape function registered so it was falling back to numpy which was problematic for repr once it was added here because repr didn't know about extension arrays. That has been resolved.
  2. The arrow dtype indexing implementation lacked handling of tuples: https://github.com/pydata/xarray/pull/10304/files#diff-c803294f5216cbbdffa30f0b0c9f16a7e39855d4dd309c88d654bc317a78adc0R124-R127

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This issue is a lot deeper than I realized: pandas-dev/pandas#61433, actually for the indexing issue with arrow.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In any case, it looks like this doesn't have an impact here, as it turns out. It still seems like we're getting away with murder a bit because it is possible to yield 0-dimensional arrow arrays from some data types and not others. When it is possible, that gels with what xarray expects, but when it's not possible, we have problems. We probably need to look into a solution for this. Maybe making __getitem__ polymorphic, but fixing this issue + adding the test cases opened up a few different cans of worms here, so I will be able to track that issue and then can work on a fix here separately.

The broken test is similarly related to unexpected behavior around arrow dtypes again...will look into that as well

)
da = srs.to_xarray()
filled = da.fillna(0)
assert filled.dtype == pd.Int64Dtype()
assert (filled.values == np.array([0, 1, 1])).all()

def test_dropna_extension_array_int(self) -> None:
    """dropna on a DataArray backed by a pandas ``Int64`` extension array.

    Dropping the missing entry must preserve the nullable ``Int64``
    extension dtype rather than falling back to a NumPy float dtype.
    """
    srs: pd.Series = pd.Series(
        index=np.array([1, 2, 3]), data=pd.array([pd.NA, 1, 1])
    )
    da = srs.to_xarray()
    # Name the dropna result ``dropped`` (was misleadingly ``filled``),
    # matching the Dataset-level test of the same name.
    dropped = da.dropna("index")
    assert dropped.dtype == pd.Int64Dtype()
    assert (dropped.values == np.array([1, 1])).all()

def test_fillna(self) -> None:
a = DataArray([np.nan, 1, np.nan, 3], coords={"x": range(4)}, dims="x")
actual = a.fillna(-1)
Expand Down
53 changes: 40 additions & 13 deletions xarray/tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,8 @@ def test_repr(self) -> None:
var2 (dim1, dim2) float64 576B 0.953 1.52 1.704 ... 0.1347 -0.6423
var3 (dim3, dim1) float64 640B 0.4107 0.9941 0.1665 ... 0.716 1.555
var4 (dim1) category 32B 'b' 'c' 'b' 'a' 'c' 'a' 'c' 'a'
var5 (dim1) Int64 72B 5 9 7 2 6 2 8 1
var6 (dim1) string 64B 'a' 'b' 'c' 'd' 'e' 'f' 'g' 'h'
Attributes:
foo: bar""".format(
data["dim3"].dtype,
Expand Down Expand Up @@ -1827,25 +1829,36 @@ def test_categorical_index_reindex(self) -> None:
actual = ds.reindex(cat=["foo"])["cat"].values
assert (actual == np.array(["foo"])).all()

@pytest.mark.parametrize("fill_value", [np.nan, pd.NA])
def test_extensionarray_negative_reindex(self, fill_value) -> None:
cat = pd.Categorical(
["foo", "bar", "baz"],
categories=["foo", "bar", "baz", "qux", "quux", "corge"],
)
@pytest.mark.parametrize("fill_value", [np.nan, pd.NA, None])
@pytest.mark.parametrize(
"extension_array",
[
pd.Categorical(
["foo", "bar", "baz"],
categories=["foo", "bar", "baz", "qux"],
),
pd.array([1, 2, 3], dtype=pd.Int32Dtype()),
pd.array(["a", "b", "c"], dtype="string"),
],
)
def test_extensionarray_negative_reindex(self, fill_value, extension_array) -> None:
ds = xr.Dataset(
{"cat": ("index", cat)},
{"arr": ("index", extension_array)},
coords={"index": ("index", np.arange(3))},
)
kwargs = {}
if fill_value is not None:
kwargs["fill_value"] = fill_value
reindexed_cat = cast(
pd.api.extensions.ExtensionArray,
(
ds.reindex(index=[-1, 1, 1], fill_value=fill_value)["cat"]
.to_pandas()
.values
),
(ds.reindex(index=[-1, 1, 1], **kwargs)["arr"].to_pandas().values),
)
assert reindexed_cat.equals(pd.array([pd.NA, "bar", "bar"], dtype=cat.dtype)) # type: ignore[attr-defined]
assert reindexed_cat.equals(
pd.array(
[pd.NA, extension_array[1], extension_array[1]],
dtype=extension_array.dtype,
)
) # type: ignore[attr-defined]

def test_extension_array_reindex_same(self) -> None:
series = pd.Series([1, 2, pd.NA, 3], dtype=pd.Int32Dtype())
Expand Down Expand Up @@ -5443,6 +5456,20 @@ def test_dropna(self) -> None:
with pytest.raises(TypeError, match=r"must specify how or thresh"):
ds.dropna("a", how=None) # type: ignore[arg-type]

def test_fillna_extension_array_int(self) -> None:
    """fillna on a Dataset variable backed by a pandas ``Int64`` extension array.

    Filling the missing value must keep the nullable ``Int64`` extension
    dtype instead of promoting to a NumPy float dtype.
    """
    # The source object is a DataFrame (was misleadingly named ``srs``).
    df = pd.DataFrame({"data": pd.array([pd.NA, 1, 1])}, index=np.array([1, 2, 3]))
    ds = df.to_xarray()
    filled = ds.fillna(0)
    assert filled["data"].dtype == pd.Int64Dtype()
    assert (filled["data"].values == np.array([0, 1, 1])).all()

def test_dropna_extension_array_int(self) -> None:
    """dropna on a Dataset variable backed by a pandas ``Int64`` extension array.

    Dropping the missing entry must preserve the nullable ``Int64``
    extension dtype rather than falling back to a NumPy float dtype.
    """
    # The source object is a DataFrame (was misleadingly named ``srs``).
    df = pd.DataFrame({"data": pd.array([pd.NA, 1, 1])}, index=np.array([1, 2, 3]))
    ds = df.to_xarray()
    dropped = ds.dropna("index")
    assert dropped["data"].dtype == pd.Int64Dtype()
    assert (dropped["data"].values == np.array([1, 1])).all()

def test_fillna(self) -> None:
ds = Dataset({"a": ("x", [np.nan, 1, np.nan, 3])}, {"x": [0, 1, 2, 3]})

Expand Down
Loading