Skip to content

Commit f4e39da

Browse files
committed
(fix): some types
1 parent e73b82f commit f4e39da

File tree

2 files changed

+23
-13
lines changed

2 files changed

+23
-13
lines changed

xarray/core/dtypes.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from __future__ import annotations
22

33
import functools
4-
from typing import Any
4+
from typing import TYPE_CHECKING
55

66
import numpy as np
77
from pandas.api.types import is_extension_array_dtype
@@ -10,6 +10,11 @@
1010
from xarray.compat.npcompat import HAS_STRING_DTYPE
1111
from xarray.core import utils
1212

13+
if TYPE_CHECKING:
14+
from typing import Any
15+
16+
from pandas.api.extensions import ExtensionDtype
17+
1318
# Use as a sentinel value to indicate a dtype appropriate NA value.
1419
NA = utils.ReprObject("<NA>")
1520

@@ -48,7 +53,7 @@ def __eq__(self, other):
4853
)
4954

5055

51-
def maybe_promote(dtype: np.dtype) -> tuple[np.dtype, Any]:
56+
def maybe_promote(dtype: np.dtype | ExtensionDtype) -> tuple[np.dtype, Any]:
5257
"""Simpler equivalent of pandas.core.common._maybe_promote
5358
5459
Parameters

xarray/core/extension_array.py

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ def __extension_duck_array__astype(
6565
subok: bool = True,
6666
copy: bool = True,
6767
device: str | None = None,
68-
) -> T_ExtensionArray:
68+
) -> ExtensionArray:
6969
if (
7070
not (
7171
is_extension_array_dtype(array_or_scalar) or is_extension_array_dtype(dtype)
@@ -82,7 +82,7 @@ def __extension_duck_array__astype(
8282
@implements(np.asarray)
8383
def __extension_duck_array__asarray(
8484
array_or_scalar: np.typing.ArrayLike, dtype: DTypeLikeSave = None
85-
) -> T_ExtensionArray:
85+
) -> ExtensionArray:
8686
if not is_extension_array_dtype(dtype):
8787
return NotImplemented
8888

@@ -91,9 +91,9 @@ def __extension_duck_array__asarray(
9191

9292
def as_extension_array(
9393
array_or_scalar: np.typing.ArrayLike, dtype: ExtensionDtype, copy: bool = False
94-
) -> T_ExtensionArray:
94+
) -> ExtensionArray:
9595
if is_scalar(array_or_scalar):
96-
return dtype.construct_array_type()._from_sequence(
96+
return dtype.construct_array_type()._from_sequence( # type: ignore[attr-defined]
9797
[array_or_scalar], dtype=dtype
9898
)
9999
else:
@@ -104,14 +104,17 @@ def as_extension_array(
104104
def __extension_duck_array__result_type(
105105
*arrays_and_dtypes: np.typing.ArrayLike | np.typing.DTypeLike,
106106
) -> DtypeObj:
107-
extension_arrays_and_dtypes = [
108-
x for x in arrays_and_dtypes if is_extension_array_dtype(x)
107+
extension_arrays_and_dtypes: list[ExtensionDtype | ExtensionArray] = [
108+
x
109+
for x in arrays_and_dtypes
110+
if is_extension_array_dtype(x) # type: ignore[arg-type, misc]
109111
]
110112
if not extension_arrays_and_dtypes:
111113
return NotImplemented
112114

113115
ea_dtypes: list[ExtensionDtype] = [
114-
getattr(x, "dtype", x) for x in extension_arrays_and_dtypes
116+
getattr(x, "dtype", cast(ExtensionDtype, x))
117+
for x in extension_arrays_and_dtypes
115118
]
116119
scalars: list[Scalar] = [
117120
x for x in arrays_and_dtypes if is_scalar(x) and x not in {pd.NA, np.nan}
@@ -122,15 +125,17 @@ def __extension_duck_array__result_type(
122125
other_stuff = [
123126
x
124127
for x in arrays_and_dtypes
125-
if not is_extension_array_dtype(x) and not is_scalar(x)
128+
if not is_extension_array_dtype(x) and not is_scalar(x) # type: ignore[arg-type, misc]
126129
]
127130
# We implement one special case: when possible, preserve Categoricals (avoid promoting
128131
# to object) by merging the categories of all given Categoricals + scalars + NA.
129132
# Ideally this could be upstreamed into pandas find_result_type / find_common_type.
130133
if not other_stuff and all(
131134
isinstance(x, pd.CategoricalDtype) and not x.ordered for x in ea_dtypes
132135
):
133-
return union_unordered_categorical_and_scalar(ea_dtypes, scalars)
136+
return union_unordered_categorical_and_scalar(
137+
cast(list[pd.CategoricalDtype], ea_dtypes), scalars
138+
)
134139
if not other_stuff and all(
135140
isinstance(x, type(ea_type := ea_dtypes[0])) for x in ea_dtypes
136141
):
@@ -146,7 +151,7 @@ def union_unordered_categorical_and_scalar(
146151
scalars = [x for x in scalars if x is not pd.CategoricalDtype.na_value]
147152
all_categories = set().union(*(x.categories for x in categorical_dtypes))
148153
all_categories = all_categories.union(scalars)
149-
return pd.CategoricalDtype(categories=all_categories)
154+
return pd.CategoricalDtype(categories=list(all_categories))
150155

151156

152157
@implements(np.broadcast_to)
@@ -174,7 +179,7 @@ def __extension_duck_array__where(
174179
x: T_ExtensionArray,
175180
y: T_ExtensionArray | np.ArrayLike,
176181
) -> T_ExtensionArray:
177-
return cast(T_ExtensionArray, pd.Series(x).where(condition, y).array)
182+
return cast(T_ExtensionArray, pd.Series(x).where(condition, y).array) # type: ignore[arg-type]
178183

179184

180185
def _replace_duck(args, replacer: Callable[[PandasExtensionArray]]) -> list:

0 commit comments

Comments
 (0)