Skip to content

Commit 40bf35b

Browse files
authored
Merge pull request #786 from DHI/isel
Use named arguments to isel/sel
2 parents 6eeeb7a + f9b80a7 commit 40bf35b

File tree

7 files changed

+145
-81
lines changed

7 files changed

+145
-81
lines changed

mikeio/dataset/_dataarray.py

Lines changed: 59 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,8 @@
8181
Grid3D,
8282
]
8383

84+
IndexType = Union[int, slice, Sequence[int], np.ndarray, None]
85+
8486

8587
class _DataArraySpectrumToHm0:
8688
def __init__(self, da: "DataArray") -> None:
@@ -584,9 +586,18 @@ def __setitem__(self, key: Any, value: np.ndarray) -> None:
584586

585587
def isel(
586588
self,
587-
idx: int | Sequence[int] | slice | None = None,
589+
idx: IndexType = None,
590+
*,
591+
time: IndexType = None,
592+
x: IndexType = None,
593+
y: IndexType = None,
594+
z: IndexType = None,
595+
element: IndexType = None,
596+
node: IndexType = None,
597+
layer: IndexType = None,
598+
frequency: IndexType = None,
599+
direction: IndexType = None,
588600
axis: int | str = 0,
589-
**kwargs: Any,
590601
) -> "DataArray":
591602
"""Return a new DataArray whose data is given by
592603
integer indexing along the specified dimension(s).
@@ -617,11 +628,17 @@ def isel(
617628
y index, by default None
618629
z : int, optional
619630
z index, by default None
631+
layer: int, optional
632+
layer index, only used in dfsu 3d
633+
direction: int, optional
634+
direction index, only used in sprectra
635+
frequency: int, optional
636+
frequencey index, only used in spectra
637+
node: int, optional
638+
node index, only used in spectra
620639
element : int, optional
621640
Bounding box of coordinates (left lower and right upper)
622641
to be selected, by default None
623-
**kwargs: Any
624-
Not used
625642
626643
Returns
627644
-------
@@ -654,10 +671,23 @@ def isel(
654671
```
655672
656673
"""
657-
if isinstance(self.geometry, Grid2D) and ("x" in kwargs and "y" in kwargs):
658-
idx_x = kwargs["x"]
659-
idx_y = kwargs["y"]
660-
return self.isel(x=idx_x).isel(y=idx_y)
674+
if isinstance(self.geometry, Grid2D) and (x is not None and y is not None):
675+
return self.isel(x=x).isel(y=y)
676+
kwargs = {
677+
k: v
678+
for k, v in dict(
679+
time=time,
680+
x=x,
681+
y=y,
682+
z=z,
683+
element=element,
684+
node=node,
685+
layer=layer,
686+
frequency=frequency,
687+
direction=direction,
688+
).items()
689+
if v is not None
690+
}
661691
for dim in kwargs:
662692
if dim in self.dims:
663693
axis = dim
@@ -698,7 +728,7 @@ def isel(
698728
spatial_axis = axis - 1 if self.dims[0] == "time" else axis
699729
geometry = self.geometry.isel(idx, axis=spatial_axis)
700730

701-
# TOOD this is ugly
731+
# TODO this is ugly
702732
if isinstance(geometry, _GeometryFMLayered):
703733
node_ids, _ = self.geometry._get_nodes_and_table_for_elements(
704734
idx, node_layers="all"
@@ -741,7 +771,12 @@ def sel(
741771
self,
742772
*,
743773
time: str | pd.DatetimeIndex | "DataArray" | None = None,
744-
**kwargs: Any,
774+
x: float | slice | None = None,
775+
y: float | slice | None = None,
776+
z: float | slice | None = None,
777+
coords: np.ndarray | None = None,
778+
area: tuple[float, float, float, float] | None = None,
779+
layers: int | str | Sequence[int | str] | None = None,
745780
) -> "DataArray":
746781
"""Return a new DataArray whose data is given by
747782
selecting index labels along the specified dimension(s).
@@ -780,8 +815,6 @@ def sel(
780815
layer(s) to be selected: "top", "bottom" or layer number
781816
from bottom 0,1,2,... or from the top -1,-2,... or as
782817
list of these; only for layered dfsu, by default None
783-
**kwargs: Any
784-
Additional keyword arguments
785818
786819
Returns
787820
-------
@@ -823,24 +856,32 @@ def sel(
823856
```
824857
825858
"""
859+
# time is not part of kwargs
860+
kwargs = {
861+
k: v
862+
for k, v in dict(
863+
x=x, y=y, z=z, area=area, coords=coords, layers=layers
864+
).items()
865+
if v is not None
866+
}
826867
if any([isinstance(v, slice) for v in kwargs.values()]):
827-
return self._sel_with_slice(kwargs)
868+
return self._sel_with_slice(kwargs) # type: ignore
828869

829870
da = self
830871

831872
# select in space
832873
if len(kwargs) > 0:
833874
idx = self.geometry.find_index(**kwargs)
875+
876+
# TODO this seems fragile
834877
if isinstance(idx, tuple):
835878
# TODO: support for dfs3
836879
assert len(idx) == 2
837-
t_ax_offset = 1 if self._has_time_axis else 0
838880
ii, jj = idx
839881
if jj is not None:
840-
da = da.isel(idx=jj, axis=(0 + t_ax_offset))
882+
da = da.isel(y=jj)
841883
if ii is not None:
842-
sp_axis = 0 if (jj is not None and len(jj) == 1) else 1
843-
da = da.isel(idx=ii, axis=(sp_axis + t_ax_offset))
884+
da = da.isel(x=ii)
844885
else:
845886
da = da.isel(idx, axis="space")
846887

@@ -1866,7 +1907,7 @@ def to_dfs(self, filename: str | Path, **kwargs: Any) -> None:
18661907
Dfs0 only: set the dfs data type of the written data
18671908
to e.g. np.float64, by default: DfsSimpleType.Float (=np.float32)
18681909
**kwargs: Any
1869-
Additional keyword arguments, e.g. dtype for dfs0
1910+
additional arguments passed to the writing function, e.g. dtype for dfs0
18701911
18711912
"""
18721913
self._to_dataset().to_dfs(filename, **kwargs)

mikeio/dataset/_dataset.py

Lines changed: 65 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,8 @@
4444

4545
from ._data_plot import _DatasetPlotter
4646

47+
from ._dataarray import IndexType
48+
4749

4850
class Dataset:
4951
"""Dataset containing one or more DataArrays with common geometry and time.
@@ -431,7 +433,7 @@ def dropna(self) -> "Dataset":
431433
else:
432434
all_index = list(np.intersect1d(all_index, idx))
433435

434-
return self.isel(all_index, axis=0)
436+
return self.isel(time=all_index)
435437

436438
def flipud(self) -> "Dataset":
437439
"""Flip data upside down (on first non-time axis)."""
@@ -476,7 +478,7 @@ def create_data_array(
476478
def create_empty_data(
477479
n_items: int = 1,
478480
n_timesteps: int = 1,
479-
n_elements: int | None = None,
481+
n_elements: IndexType = None,
480482
shape: tuple[int, ...] | None = None,
481483
) -> list:
482484
data = []
@@ -641,15 +643,15 @@ def __getitem__(self, key: Any) -> DataArray | "Dataset":
641643
time_steps = _get_time_idx_list(self.time, key)
642644
if _n_selected_timesteps(self.time, time_steps) == 0:
643645
raise IndexError("No timesteps found!")
644-
return self.isel(time_steps, axis=0)
646+
return self.isel(time=time_steps)
645647
if isinstance(key, slice):
646648
if self._is_slice_time_slice(key):
647649
try:
648650
s = self.time.slice_indexer(key.start, key.stop)
649651
time_steps = list(range(s.start, s.stop))
650652
except ValueError:
651653
time_steps = list(range(*key.indices(len(self.time))))
652-
return self.isel(time_steps, axis=0)
654+
return self.isel(time=time_steps)
653655

654656
if self._multi_indexing_attempted(key):
655657
raise TypeError(
@@ -757,9 +759,18 @@ def __delitem__(self, key: Hashable | int) -> None:
757759

758760
def isel(
759761
self,
760-
idx: int | Sequence[int] | slice | None = None,
762+
idx: IndexType = None,
763+
*,
764+
time: IndexType = None,
765+
x: IndexType = None,
766+
y: IndexType = None,
767+
z: IndexType = None,
768+
element: IndexType = None,
769+
node: IndexType = None,
770+
layer: IndexType = None,
771+
frequency: IndexType = None,
772+
direction: IndexType = None,
761773
axis: int | str = 0,
762-
**kwargs: Any,
763774
) -> "Dataset":
764775
"""Return a new Dataset whose data is given by
765776
integer indexing along the specified dimension(s).
@@ -789,8 +800,14 @@ def isel(
789800
element : int, optional
790801
Bounding box of coordinates (left lower and right upper)
791802
to be selected, by default None
792-
**kwargs: Any
793-
Not used
803+
layer: int, optional
804+
layer index, only used in dfsu 3d
805+
direction: int, optional
806+
direction index, only used in sprectra
807+
frequency: int, optional
808+
frequencey index, only used in spectra
809+
node: int, optional
810+
node index, only used in spectra
794811
795812
Returns
796813
-------
@@ -809,12 +826,36 @@ def isel(
809826
>>> ds3 = ds2.isel(elements=[100,200])
810827
811828
"""
812-
res = [da.isel(idx=idx, axis=axis, **kwargs) for da in self]
829+
# TODO deprecate idx, axis to prefer x= instead
830+
831+
res = [
832+
da.isel(
833+
idx=idx,
834+
axis=axis,
835+
time=time,
836+
x=x,
837+
y=y,
838+
z=z,
839+
element=element,
840+
node=node,
841+
frequency=frequency,
842+
direction=direction,
843+
layer=layer,
844+
)
845+
for da in self
846+
]
813847
return Dataset(data=res, validate=False)
814848

815849
def sel(
816850
self,
817-
**kwargs: Any,
851+
*,
852+
time: Any = None,
853+
x: float | None = None,
854+
y: float | None = None,
855+
z: float | None = None,
856+
coords: np.ndarray | None = None,
857+
area: tuple[float, float, float, float] | None = None,
858+
layers: int | str | Sequence[int | str] | None = None,
818859
) -> "Dataset":
819860
"""Return a new Dataset whose data is given by
820861
selecting index labels along the specified dimension(s).
@@ -853,8 +894,6 @@ def sel(
853894
layer(s) to be selected: "top", "bottom" or layer number
854895
from bottom 0,1,2,... or from the top -1,-2,... or as
855896
list of these; only for layered dfsu, by default None
856-
**kwargs: Any
857-
Not used
858897
859898
Returns
860899
-------
@@ -878,7 +917,10 @@ def sel(
878917
>>> ds.sel(layers="bottom")
879918
880919
"""
881-
res = [da.sel(**kwargs) for da in self]
920+
res = [
921+
da.sel(time=time, x=x, y=y, z=z, coords=coords, area=area, layers=layers)
922+
for da in self
923+
]
882924
return Dataset(data=res, validate=False)
883925

884926
def interp(
@@ -1849,15 +1891,16 @@ def to_dfs(self, filename: str | Path, **kwargs: Any) -> None:
18491891
"""
18501892
filename = str(filename)
18511893

1894+
# TODO is this a candidate for match/case?
18521895
if isinstance(
18531896
self.geometry, (GeometryPoint2D, GeometryPoint3D, GeometryUndefined)
18541897
):
18551898
if self.ndim == 0: # Not very common, but still...
18561899
self._validate_extension(filename, ".dfs0")
1857-
self._to_dfs0(filename, **kwargs)
1900+
self._to_dfs0(filename=filename, **kwargs)
18581901
elif self.ndim == 1 and self[0]._has_time_axis:
18591902
self._validate_extension(filename, ".dfs0")
1860-
self._to_dfs0(filename, **kwargs)
1903+
self._to_dfs0(filename=filename, **kwargs)
18611904
else:
18621905
raise ValueError("Cannot write Dataset with no geometry to file!")
18631906
elif isinstance(self.geometry, Grid2D):
@@ -1885,12 +1928,15 @@ def _validate_extension(filename: str | Path, valid_extension: str) -> None:
18851928
if ext != valid_extension:
18861929
raise ValueError(f"File extension must be {valid_extension}")
18871930

1888-
def _to_dfs0(self, filename: str | Path, **kwargs: Any) -> None:
1931+
def _to_dfs0(
1932+
self,
1933+
filename: str | Path,
1934+
dtype: DfsSimpleType = DfsSimpleType.Float,
1935+
title: str = "",
1936+
) -> None:
18891937
from ..dfs._dfs0 import _write_dfs0
18901938

1891-
dtype = kwargs.get("dtype", DfsSimpleType.Float)
1892-
1893-
_write_dfs0(filename, self, dtype=dtype)
1939+
_write_dfs0(filename, self, dtype=dtype, title=title)
18941940

18951941
def _to_dfs2(self, filename: str | Path) -> None:
18961942
# assumes Grid2D geometry

mikeio/dfs/_dfs0.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -189,7 +189,7 @@ def read(
189189
_, time_steps = _valid_timesteps(dfs.FileInfo, time)
190190

191191
if time_steps:
192-
ds = ds.isel(time_steps, axis=0)
192+
ds = ds.isel(time=time_steps)
193193

194194
if sel_time_step_str:
195195
parts = sel_time_step_str.split(",")

tests/test_consistency.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ def test_read_dfs1():
3333
def test_dfs1_isel_t():
3434
ds = mikeio.read("tests/testdata/consistency/oresundHD.dfs1")
3535

36-
ds1 = ds.isel([0, 1], axis="t")
36+
ds1 = ds.isel(time=[0, 1])
3737
assert ds1.dims == ("time", "x")
3838
assert isinstance(ds1.geometry, type(ds.geometry))
3939
assert ds1[0].values[0, 8] == pytest.approx(0.203246)
@@ -42,10 +42,10 @@ def test_dfs1_isel_t():
4242
def test_dfs1_isel_x():
4343
ds = mikeio.read("tests/testdata/consistency/oresundHD.dfs1")
4444

45-
ds1 = ds.isel(8, axis="x")
45+
ds1 = ds.isel(x=8)
4646
assert ds1.dims == ("time",)
4747
assert isinstance(ds1.geometry, GeometryUndefined)
48-
assert ds1[0].isel(0, axis="time").values == pytest.approx(0.203246)
48+
assert ds1[0].isel(time=0).values == pytest.approx(0.203246)
4949

5050

5151
def test_dfs1_sel_t():
@@ -63,13 +63,13 @@ def test_dfs1_sel_x():
6363
ds1 = ds.sel(x=7.8)
6464
assert ds1.dims == ("time",)
6565
assert isinstance(ds1.geometry, GeometryUndefined)
66-
assert ds1[0].isel(0, axis="time").values == pytest.approx(0.203246)
66+
assert ds1[0].isel(time=0).values == pytest.approx(0.203246)
6767

6868
da: DataArray = ds[0]
6969
da1 = da.sel(x=7.8)
7070
assert da1.dims == ("time",)
7171
assert isinstance(ds1.geometry, GeometryUndefined)
72-
assert da1.isel(0, axis="time").values == pytest.approx(0.203246)
72+
assert da1.isel(time=0).values == pytest.approx(0.203246)
7373

7474

7575
def test_dfs1_interp_x():
@@ -78,7 +78,7 @@ def test_dfs1_interp_x():
7878
ds1 = ds.interp(x=7.75)
7979
assert ds1.dims == ("time",)
8080
assert isinstance(ds1.geometry, GeometryUndefined)
81-
assert ds1[0].isel(0, axis="time").values == pytest.approx(0.20202248)
81+
assert ds1[0].isel(time=0).values == pytest.approx(0.20202248)
8282

8383

8484
# Nice to have...

0 commit comments

Comments
 (0)