Skip to content

Commit 6c7db5d

Browse files
authored
Merge pull request #792 from DHI/math
Support multiplication and division of datasets.
2 parents c743c89 + a4a0bcd commit 6c7db5d

File tree

5 files changed

+117
-120
lines changed

5 files changed

+117
-120
lines changed

docs/user-guide/dataarray.qmd

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,12 @@ da_ratio = da1 / da2
170170
da_ratio.plot(title="", label="Ratio", vmin=0.8, vmax=1.2, levels=9, cmap="coolwarm")
171171
```
172172

173+
## Unit handling
174+
175+
Multiplication and division of two physical quantities would normally change the unit of the result, but in the case of DataArrays, the type and unit of the result will be the ones of the first operand.
176+
177+
178+
173179
Other methods that also return a DataArray:
174180

175181
* [`interp_like`](`mikeio.DataArray.interp_like`) - Spatio (temporal) interpolation (see example [Dfsu interpolation](../examples/dfsu/spatial_interpolation.qmd))

docs/user-guide/dataset.qmd

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -166,11 +166,13 @@ including different ways of *selecting* data:
166166
* ds + value
167167
* ds - value
168168
* ds * value
169-
170-
and + and - between two Datasets (if number of items and shapes conform):
169+
* ds / value
170+
and between two Datasets (if number of items and shapes conform):
171171

172172
* ds1 + ds2
173173
* ds1 - ds2
174+
* ds1 * ds2
175+
* ds1 / ds2
174176

175177
Other methods that also return a Dataset:
176178

mikeio/dataset/_dataarray.py

Lines changed: 10 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1781,33 +1781,31 @@ def __radd__(self, other: "DataArray" | float) -> "DataArray":
17811781
return self.__add__(other)
17821782

17831783
def __add__(self, other: "DataArray" | float) -> "DataArray":
1784-
return self._apply_math_operation(other, np.add, txt="+")
1784+
return self._apply_math_operation(other, np.add)
17851785

17861786
def __rsub__(self, other: "DataArray" | float) -> "DataArray":
17871787
return other + self.__neg__()
17881788

17891789
def __sub__(self, other: "DataArray" | float) -> "DataArray":
1790-
return self._apply_math_operation(other, np.subtract, txt="-")
1790+
return self._apply_math_operation(other, np.subtract)
17911791

17921792
def __rmul__(self, other: "DataArray" | float) -> "DataArray":
17931793
return self.__mul__(other)
17941794

17951795
def __mul__(self, other: "DataArray" | float) -> "DataArray":
1796-
return self._apply_math_operation(
1797-
other, np.multiply, txt="x"
1798-
) # x in place of *
1796+
return self._apply_math_operation(other, np.multiply)
17991797

18001798
def __pow__(self, other: float) -> "DataArray":
1801-
return self._apply_math_operation(other, np.power, txt="**")
1799+
return self._apply_math_operation(other, np.power)
18021800

18031801
def __truediv__(self, other: "DataArray" | float) -> "DataArray":
1804-
return self._apply_math_operation(other, np.divide, txt="/")
1802+
return self._apply_math_operation(other, np.divide)
18051803

18061804
def __floordiv__(self, other: "DataArray" | float) -> "DataArray":
1807-
return self._apply_math_operation(other, np.floor_divide, txt="//")
1805+
return self._apply_math_operation(other, np.floor_divide)
18081806

18091807
def __mod__(self, other: float) -> "DataArray":
1810-
return self._apply_math_operation(other, np.mod, txt="%")
1808+
return self._apply_math_operation(other, np.mod)
18111809

18121810
def __neg__(self) -> "DataArray":
18131811
return self._apply_unary_math_operation(np.negative)
@@ -1830,7 +1828,9 @@ def _apply_unary_math_operation(self, func: Callable) -> "DataArray":
18301828
return new_da
18311829

18321830
def _apply_math_operation(
1833-
self, other: "DataArray" | float, func: Callable, *, txt: str
1831+
self,
1832+
other: "DataArray" | float,
1833+
func: Callable,
18341834
) -> "DataArray":
18351835
"""Apply a binary math operation with a scalar, an array or another DataArray."""
18361836
try:
@@ -1839,39 +1839,11 @@ def _apply_math_operation(
18391839
except TypeError:
18401840
raise TypeError("Math operation could not be applied to DataArray")
18411841

1842-
# TODO: check if geometry etc match if other is DataArray?
1843-
18441842
new_da = self.copy() # TODO: alternatively: create new dataset (will validate)
18451843
new_da.values = data
18461844

1847-
if not self._keep_EUM_after_math_operation(other, func):
1848-
other_name = other.name if hasattr(other, "name") else "array"
1849-
new_da.item = ItemInfo(
1850-
f"{self.name} {txt} {other_name}", itemtype=EUMType.Undefined
1851-
)
1852-
18531845
return new_da
18541846

1855-
def _keep_EUM_after_math_operation(
1856-
self, other: "DataArray" | float, func: Callable
1857-
) -> bool:
1858-
"""Does the math operation falsify the EUM?"""
1859-
if hasattr(other, "shape") and hasattr(other, "ndim"):
1860-
# other is array-like, so maybe we cannot keep EUM
1861-
if func == np.subtract or func == np.sum:
1862-
# +/-: we may want to keep EUM
1863-
if isinstance(other, DataArray):
1864-
if self.type == other.type and self.unit == other.unit:
1865-
return True
1866-
else:
1867-
return False
1868-
else:
1869-
return True # assume okay, since no EUM
1870-
return False
1871-
1872-
# other is likely scalar, okay to keep EUM
1873-
return True
1874-
18751847
# ============= Logical indexing ===========
18761848

18771849
def __lt__(self, other) -> "DataArray": # type: ignore

mikeio/dataset/_dataset.py

Lines changed: 50 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -1300,7 +1300,7 @@ def _concat_time(
13001300
copy: bool = True,
13011301
keep: Literal["last", "first"] = "last",
13021302
) -> "Dataset":
1303-
self._check_all_items_match(other)
1303+
self._check_n_items(other)
13041304
# assuming time is always first dimension we can skip / keep it by bool
13051305
start_dim = int("time" in self.dims)
13061306
if not np.all(
@@ -1348,25 +1348,19 @@ def _concat_time(
13481348
newdata, time=newtime, items=ds.items, geometry=ds.geometry, zn=zn
13491349
)
13501350

1351-
def _check_all_items_match(self, other: "Dataset") -> None:
1351+
def _check_n_items(self, other: "Dataset") -> None:
13521352
if self.n_items != other.n_items:
13531353
raise ValueError(
13541354
f"Number of items must match ({self.n_items} and {other.n_items})"
13551355
)
13561356

1357-
for this, that in zip(self.items, other.items):
1358-
if this.name != that.name:
1359-
raise ValueError(
1360-
f"Item names must match. Item: {this.name} != {that.name}"
1361-
)
1362-
if this.type != that.type:
1363-
raise ValueError(
1364-
f"Item types must match. Item: {this.type} != {that.type}"
1365-
)
1366-
if this.unit != that.unit:
1367-
raise ValueError(
1368-
f"Item units must match. Item: {this.unit} != {that.unit}"
1369-
)
1357+
def _check_datasets_match(self, other: "Dataset") -> None:
1358+
self._check_n_items(other)
1359+
1360+
if not np.all(self.time == other.time):
1361+
raise ValueError("All timesteps must match")
1362+
if self.shape != other.shape:
1363+
raise ValueError("shape must match")
13701364

13711365
# ============ aggregate =============
13721366

@@ -1765,83 +1759,63 @@ def __radd__(self, other: "Dataset" | float) -> "Dataset":
17651759

17661760
def __add__(self, other: "Dataset" | float) -> "Dataset":
17671761
if isinstance(other, self.__class__):
1768-
return self._add_dataset(other)
1762+
return self._binary_op(other, operator="+")
17691763
else:
1770-
# float-like
1771-
return self._add_value(other) # type: ignore
1764+
return self._scalar_op(other, operator="+") # type: ignore
17721765

17731766
def __rsub__(self, other: "Dataset" | float) -> "Dataset":
1774-
ds = self.__mul__(-1.0)
1775-
return other + ds
1767+
ds = self._scalar_op(-1.0, operator="*")
1768+
return ds._scalar_op(other, operator="+") # type: ignore
17761769

17771770
def __sub__(self, other: "Dataset" | float) -> "Dataset":
17781771
if isinstance(other, self.__class__):
1779-
return self._add_dataset(other, sign=-1.0)
1772+
return self._binary_op(other, operator="-")
17801773
else:
1781-
return self._add_value(-other) # type: ignore
1774+
return self._scalar_op(-other, operator="+") # type: ignore
17821775

17831776
def __rmul__(self, other: "Dataset" | float) -> "Dataset":
17841777
return self.__mul__(other)
17851778

17861779
def __mul__(self, other: "Dataset" | float) -> "Dataset":
17871780
if isinstance(other, self.__class__):
1788-
raise ValueError("Multiplication is not possible for two Datasets")
1781+
return self._binary_op(other, operator="*")
17891782
else:
1790-
return self._multiply_value(other) # type: ignore
1791-
1792-
def _add_dataset(self, other: "Dataset", sign: float = 1.0) -> "Dataset":
1793-
self._check_datasets_match(other)
1794-
try:
1795-
data = [
1796-
self[x].to_numpy() + sign * other[y].to_numpy()
1797-
for x, y in zip(self.items, other.items)
1798-
]
1799-
except TypeError:
1800-
raise TypeError("Could not add data in Dataset")
1801-
newds = self.copy()
1802-
for new, old in zip(newds, data):
1803-
new.values = old
1804-
return newds
1805-
1806-
def _check_datasets_match(self, other: "Dataset") -> None:
1807-
self._check_all_items_match(other)
1808-
1809-
if not np.all(self.time == other.time):
1810-
raise ValueError("All timesteps must match")
1811-
if self.shape != other.shape:
1812-
raise ValueError("shape must match")
1783+
return self._scalar_op(other, operator="*") # type: ignore
18131784

1814-
def _add_value(self, value: float) -> "Dataset":
1815-
try:
1816-
data = [value + self[x].to_numpy() for x in self.items]
1817-
except TypeError:
1818-
raise TypeError(f"{value} could not be added to Dataset")
1819-
items = deepcopy(self.items)
1820-
time = self.time.copy()
1821-
return Dataset(
1822-
data,
1823-
time=time,
1824-
items=items,
1825-
geometry=self.geometry,
1826-
zn=self._zn,
1827-
validate=False,
1828-
)
1785+
def __truediv__(self, other: "Dataset" | float) -> "Dataset":
1786+
if isinstance(other, self.__class__):
1787+
return self._binary_op(other, operator="/")
1788+
else:
1789+
return self._scalar_op(other, operator="/") # type: ignore
18291790

1830-
def _multiply_value(self, value: float) -> "Dataset":
1831-
try:
1832-
data = [value * self[x].to_numpy() for x in self.items]
1833-
except TypeError:
1834-
raise TypeError(f"{value} could not be multiplied to Dataset")
1835-
items = deepcopy(self.items)
1836-
time = self.time.copy()
1837-
return Dataset(
1838-
data,
1839-
time=time,
1840-
items=items,
1841-
geometry=self.geometry,
1842-
zn=self._zn,
1843-
validate=False,
1844-
)
1791+
def _binary_op(self, other: "Dataset", operator: str) -> "Dataset":
1792+
self._check_datasets_match(other)
1793+
match operator:
1794+
case "+":
1795+
data = [x + y for x, y in zip(self, other)]
1796+
case "-":
1797+
data = [x - y for x, y in zip(self, other)]
1798+
case "*":
1799+
data = [x * y for x, y in zip(self, other)]
1800+
case "/":
1801+
data = [x / y for x, y in zip(self, other)]
1802+
case _:
1803+
raise ValueError(f"Unsupported operator: {operator}")
1804+
return Dataset(data)
1805+
1806+
def _scalar_op(self, value: float, operator: str) -> "Dataset":
1807+
match operator:
1808+
case "+":
1809+
data = [x + value for x in self]
1810+
case "-":
1811+
data = [x - value for x in self]
1812+
case "*":
1813+
data = [x * value for x in self]
1814+
case "/":
1815+
data = [x / value for x in self]
1816+
case _:
1817+
raise ValueError(f"Unsupported operator: {operator}")
1818+
return Dataset(data)
18451819

18461820
# ===============================================
18471821

tests/test_dataset.py

Lines changed: 47 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1182,10 +1182,14 @@ def test_add_dataset(ds1, ds2):
11821182

11831183
ds2b = ds2.copy()
11841184
ds2b[0].item = ItemInfo(EUMType.Wind_Velocity)
1185-
with pytest.raises(ValueError):
1186-
# item type does not match
1187-
ds1 + ds2b
1188-
1185+
# item type does not match, but we don't care about the item type, item is defined by the first dataset
1186+
ds3 = ds2b + ds1
1187+
assert ds3.items[0].type == EUMType.Wind_Velocity
1188+
assert ds3.items[0].name == ds2b.items[0].name
1189+
1190+
ds4 = ds1 + ds2b
1191+
assert ds4.items[0].type == EUMType.Undefined
1192+
assert ds4.items[0].name == ds1.items[0].name
11891193
ds2c = ds2.copy()
11901194
tt = ds2c.time.to_numpy()
11911195
tt[-1] = tt[-1] + np.timedelta64(1, "s")
@@ -1201,6 +1205,45 @@ def test_sub_dataset(ds1, ds2):
12011205
assert np.all(ds3[1].to_numpy() == 1.8)
12021206

12031207

1208+
def test_multiply_dataset(ds1, ds2):
1209+
dsa = mikeio.Dataset(
1210+
{
1211+
"Foo": mikeio.DataArray(
1212+
[1, 2, 3], item=mikeio.ItemInfo("Foo", EUMType.Water_Level)
1213+
)
1214+
}
1215+
)
1216+
dsb = mikeio.Dataset({"Foo": mikeio.DataArray([4, 5, 6])})
1217+
dsr = dsa * dsb
1218+
assert np.all(dsr.Foo.to_numpy() == np.array([4, 10, 18]))
1219+
assert dsr.Foo.type == EUMType.Water_Level
1220+
1221+
1222+
def test_multiply_number_of_items_datasets_must_match():
1223+
dsa = mikeio.Dataset(
1224+
{"Foo": mikeio.DataArray([1, 2, 3]), "Bar": mikeio.DataArray([1, 2, 3])}
1225+
)
1226+
dsb = mikeio.Dataset({"Bar": mikeio.DataArray([4, 5, 6])})
1227+
with pytest.raises(ValueError, match="Number of items"):
1228+
dsa * dsb
1229+
1230+
1231+
def test_divide_dataset(ds1, ds2):
1232+
ds_nom = mikeio.Dataset({"Foo": mikeio.DataArray([1, 2, 3])})
1233+
ds_denom = mikeio.Dataset({"Foo": mikeio.DataArray([4, 5, 6])})
1234+
ds3 = ds_nom / ds_denom
1235+
assert np.all(ds3[0].to_numpy() == np.array([0.25, 0.4, 0.5]))
1236+
1237+
1238+
def test_divide_number_of_items_datasets_must_match():
1239+
dsa = mikeio.Dataset(
1240+
{"Foo": mikeio.DataArray([1, 2, 3]), "Bar": mikeio.DataArray([1, 2, 3])}
1241+
)
1242+
dsb = mikeio.Dataset({"Bar": mikeio.DataArray([4, 5, 6])})
1243+
with pytest.raises(ValueError, match="Number of items"):
1244+
dsa / dsb
1245+
1246+
12041247
def test_non_equidistant():
12051248
nt = 4
12061249
d = np.random.uniform(size=nt)

0 commit comments

Comments
 (0)