Skip to content

Commit abad670

Browse files
authored
Add Dataset.dtypes property (#6706)
* add Dataset.dtypes property * add Dataset.dtypes to whats-new * add Dataset.dtypes to api * fix typo * fix mypy issue * dtypes property for DataArrayCoordinates, DataVariables and DatasetCoordinates * update whats new
1 parent a1b0523 commit abad670

File tree

7 files changed

+148
-32
lines changed

7 files changed

+148
-32
lines changed

doc/api.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ Attributes
6161

6262
Dataset.dims
6363
Dataset.sizes
64+
Dataset.dtypes
6465
Dataset.data_vars
6566
Dataset.coords
6667
Dataset.attrs

doc/whats-new.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,10 @@ v2022.06.0 (unreleased)
2222
New Features
2323
~~~~~~~~~~~~
2424

25+
- Add :py:attr:`Dataset.dtypes`, :py:attr:`DatasetCoordinates.dtypes`,
26+
:py:attr:`DataArrayCoordinates.dtypes` properties: Mapping from variable names to dtypes.
27+
(:pull:`6706`)
28+
By `Michael Niklas <https://github.com/headtr1ck>`_.
2529

2630
Deprecations
2731
~~~~~~~~~~~~

xarray/core/coordinates.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,10 @@ def _names(self) -> set[Hashable]:
3838
def dims(self) -> Mapping[Hashable, int] | tuple[Hashable, ...]:
3939
raise NotImplementedError()
4040

41+
@property
42+
def dtypes(self) -> Frozen[Hashable, np.dtype]:
43+
raise NotImplementedError()
44+
4145
@property
4246
def indexes(self) -> Indexes[pd.Index]:
4347
return self._data.indexes # type: ignore[attr-defined]
@@ -242,6 +246,24 @@ def _names(self) -> set[Hashable]:
242246
def dims(self) -> Mapping[Hashable, int]:
243247
return self._data.dims
244248

249+
@property
250+
def dtypes(self) -> Frozen[Hashable, np.dtype]:
251+
"""Mapping from coordinate names to dtypes.
252+
253+
Cannot be modified directly, but is updated when adding new variables.
254+
255+
See Also
256+
--------
257+
Dataset.dtypes
258+
"""
259+
return Frozen(
260+
{
261+
n: v.dtype
262+
for n, v in self._data._variables.items()
263+
if n in self._data._coord_names
264+
}
265+
)
266+
245267
@property
246268
def variables(self) -> Mapping[Hashable, Variable]:
247269
return Frozen(
@@ -313,6 +335,18 @@ def __init__(self, dataarray: DataArray):
313335
def dims(self) -> tuple[Hashable, ...]:
314336
return self._data.dims
315337

338+
@property
339+
def dtypes(self) -> Frozen[Hashable, np.dtype]:
340+
"""Mapping from coordinate names to dtypes.
341+
342+
Cannot be modified directly, but is updated when adding new variables.
343+
344+
See Also
345+
--------
346+
DataArray.dtype
347+
"""
348+
return Frozen({n: v.dtype for n, v in self._data._coords.items()})
349+
316350
@property
317351
def _names(self) -> set[Hashable]:
318352
return set(self._data._coords)

xarray/core/dataarray.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@
102102

103103
def _infer_coords_and_dims(
104104
shape, coords, dims
105-
) -> tuple[dict[Any, Variable], tuple[Hashable, ...]]:
105+
) -> tuple[dict[Hashable, Variable], tuple[Hashable, ...]]:
106106
"""All the logic for creating a new DataArray"""
107107

108108
if (
@@ -140,7 +140,7 @@ def _infer_coords_and_dims(
140140
if not isinstance(d, str):
141141
raise TypeError(f"dimension {d} is not a string")
142142

143-
new_coords: dict[Any, Variable] = {}
143+
new_coords: dict[Hashable, Variable] = {}
144144

145145
if utils.is_dict_like(coords):
146146
for k, v in coords.items():

xarray/core/dataset.py

Lines changed: 46 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,9 @@
3535
from ..coding.calendar_ops import convert_calendar, interp_calendar
3636
from ..coding.cftimeindex import CFTimeIndex, _parse_array_of_cftime_strings
3737
from ..plot.dataset_plot import _Dataset_PlotMethods
38+
from . import alignment
39+
from . import dtypes as xrdtypes
3840
from . import (
39-
alignment,
40-
dtypes,
4141
duck_array_ops,
4242
formatting,
4343
formatting_html,
@@ -385,6 +385,18 @@ def variables(self) -> Mapping[Hashable, Variable]:
385385
all_variables = self._dataset.variables
386386
return Frozen({k: all_variables[k] for k in self})
387387

388+
@property
389+
def dtypes(self) -> Frozen[Hashable, np.dtype]:
390+
"""Mapping from data variable names to dtypes.
391+
392+
Cannot be modified directly, but is updated when adding new variables.
393+
394+
See Also
395+
--------
396+
Dataset.dtypes
397+
"""
398+
return self._dataset.dtypes
399+
388400
def _ipython_key_completions_(self):
389401
"""Provide method for the key-autocompletions in IPython."""
390402
return [
@@ -677,6 +689,24 @@ def sizes(self) -> Frozen[Hashable, int]:
677689
"""
678690
return self.dims
679691

692+
@property
693+
def dtypes(self) -> Frozen[Hashable, np.dtype]:
694+
"""Mapping from data variable names to dtypes.
695+
696+
Cannot be modified directly, but is updated when adding new variables.
697+
698+
See Also
699+
--------
700+
DataArray.dtype
701+
"""
702+
return Frozen(
703+
{
704+
n: v.dtype
705+
for n, v in self._variables.items()
706+
if n not in self._coord_names
707+
}
708+
)
709+
680710
def load(self: T_Dataset, **kwargs) -> T_Dataset:
681711
"""Manually trigger loading and/or computation of this dataset's data
682712
from disk or a remote source into memory and return this dataset.
@@ -2792,7 +2822,7 @@ def reindex_like(
27922822
method: ReindexMethodOptions = None,
27932823
tolerance: int | float | Iterable[int | float] | None = None,
27942824
copy: bool = True,
2795-
fill_value: Any = dtypes.NA,
2825+
fill_value: Any = xrdtypes.NA,
27962826
) -> T_Dataset:
27972827
"""Conform this object onto the indexes of another object, filling in
27982828
missing values with ``fill_value``. The default fill value is NaN.
@@ -2858,7 +2888,7 @@ def reindex(
28582888
method: ReindexMethodOptions = None,
28592889
tolerance: int | float | Iterable[int | float] | None = None,
28602890
copy: bool = True,
2861-
fill_value: Any = dtypes.NA,
2891+
fill_value: Any = xrdtypes.NA,
28622892
**indexers_kwargs: Any,
28632893
) -> T_Dataset:
28642894
"""Conform this object onto a new set of indexes, filling in
@@ -3074,7 +3104,7 @@ def _reindex(
30743104
method: str = None,
30753105
tolerance: int | float | Iterable[int | float] | None = None,
30763106
copy: bool = True,
3077-
fill_value: Any = dtypes.NA,
3107+
fill_value: Any = xrdtypes.NA,
30783108
sparse: bool = False,
30793109
**indexers_kwargs: Any,
30803110
) -> T_Dataset:
@@ -4532,7 +4562,7 @@ def _unstack_full_reindex(
45324562
def unstack(
45334563
self: T_Dataset,
45344564
dim: Hashable | Iterable[Hashable] | None = None,
4535-
fill_value: Any = dtypes.NA,
4565+
fill_value: Any = xrdtypes.NA,
45364566
sparse: bool = False,
45374567
) -> T_Dataset:
45384568
"""
@@ -4677,7 +4707,7 @@ def merge(
46774707
overwrite_vars: Hashable | Iterable[Hashable] = frozenset(),
46784708
compat: CompatOptions = "no_conflicts",
46794709
join: JoinOptions = "outer",
4680-
fill_value: Any = dtypes.NA,
4710+
fill_value: Any = xrdtypes.NA,
46814711
combine_attrs: CombineAttrsOptions = "override",
46824712
) -> T_Dataset:
46834713
"""Merge the arrays of two datasets into a single dataset.
@@ -5886,7 +5916,7 @@ def _set_sparse_data_from_dataframe(
58865916
# missing values and needs a fill_value. For consistency, don't
58875917
# special case the rare exceptions (e.g., dtype=int without a
58885918
# MultiIndex).
5889-
dtype, fill_value = dtypes.maybe_promote(values.dtype)
5919+
dtype, fill_value = xrdtypes.maybe_promote(values.dtype)
58905920
values = np.asarray(values, dtype=dtype)
58915921

58925922
data = COO(
@@ -5924,7 +5954,7 @@ def _set_numpy_data_from_dataframe(
59245954
# fill in missing values:
59255955
# https://stackoverflow.com/a/35049899/809705
59265956
if missing_values:
5927-
dtype, fill_value = dtypes.maybe_promote(values.dtype)
5957+
dtype, fill_value = xrdtypes.maybe_promote(values.dtype)
59285958
data = np.full(shape, fill_value, dtype)
59295959
else:
59305960
# If there are no missing values, keep the existing dtype
@@ -6415,7 +6445,7 @@ def diff(
64156445
def shift(
64166446
self: T_Dataset,
64176447
shifts: Mapping[Any, int] | None = None,
6418-
fill_value: Any = dtypes.NA,
6448+
fill_value: Any = xrdtypes.NA,
64196449
**shifts_kwargs: int,
64206450
) -> T_Dataset:
64216451

@@ -6470,7 +6500,7 @@ def shift(
64706500
for name, var in self.variables.items():
64716501
if name in self.data_vars:
64726502
fill_value_ = (
6473-
fill_value.get(name, dtypes.NA)
6503+
fill_value.get(name, xrdtypes.NA)
64746504
if isinstance(fill_value, dict)
64756505
else fill_value
64766506
)
@@ -6931,7 +6961,9 @@ def differentiate(
69316961
dim = coord_var.dims[0]
69326962
if _contains_datetime_like_objects(coord_var):
69336963
if coord_var.dtype.kind in "mM" and datetime_unit is None:
6934-
datetime_unit, _ = np.datetime_data(coord_var.dtype)
6964+
datetime_unit = cast(
6965+
"DatetimeUnitOptions", np.datetime_data(coord_var.dtype)[0]
6966+
)
69356967
elif datetime_unit is None:
69366968
datetime_unit = "s" # Default to seconds for cftime objects
69376969
coord_var = coord_var._to_numeric(datetime_unit=datetime_unit)
@@ -7744,7 +7776,7 @@ def idxmin(
77447776
self: T_Dataset,
77457777
dim: Hashable | None = None,
77467778
skipna: bool | None = None,
7747-
fill_value: Any = dtypes.NA,
7779+
fill_value: Any = xrdtypes.NA,
77487780
keep_attrs: bool | None = None,
77497781
) -> T_Dataset:
77507782
"""Return the coordinate label of the minimum value along a dimension.
@@ -7841,7 +7873,7 @@ def idxmax(
78417873
self: T_Dataset,
78427874
dim: Hashable | None = None,
78437875
skipna: bool | None = None,
7844-
fill_value: Any = dtypes.NA,
7876+
fill_value: Any = xrdtypes.NA,
78457877
keep_attrs: bool | None = None,
78467878
) -> T_Dataset:
78477879
"""Return the coordinate label of the maximum value along a dimension.

xarray/tests/test_dataarray.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1327,9 +1327,11 @@ def test_coords(self) -> None:
13271327
]
13281328
da = DataArray(np.random.randn(2, 3), coords, name="foo")
13291329

1330-
assert 2 == len(da.coords)
1330+
# len
1331+
assert len(da.coords) == 2
13311332

1332-
assert ["x", "y"] == list(da.coords)
1333+
# iter
1334+
assert list(da.coords) == ["x", "y"]
13331335

13341336
assert coords[0].identical(da.coords["x"])
13351337
assert coords[1].identical(da.coords["y"])
@@ -1343,6 +1345,7 @@ def test_coords(self) -> None:
13431345
with pytest.raises(KeyError):
13441346
da.coords["foo"]
13451347

1348+
# repr
13461349
expected_repr = dedent(
13471350
"""\
13481351
Coordinates:
@@ -1352,6 +1355,9 @@ def test_coords(self) -> None:
13521355
actual = repr(da.coords)
13531356
assert expected_repr == actual
13541357

1358+
# dtypes
1359+
assert da.coords.dtypes == {"x": np.dtype("int64"), "y": np.dtype("int64")}
1360+
13551361
del da.coords["x"]
13561362
da._indexes = filter_indexes_from_coords(da.xindexes, set(da.coords))
13571363
expected = DataArray(da.values, {"y": [0, 1, 2]}, dims=["x", "y"], name="foo")

0 commit comments

Comments
 (0)