@@ -1300,7 +1300,7 @@ def _concat_time(
1300
1300
copy : bool = True ,
1301
1301
keep : Literal ["last" , "first" ] = "last" ,
1302
1302
) -> "Dataset" :
1303
- self ._check_all_items_match (other )
1303
+ self ._check_n_items (other )
1304
1304
# assuming time is always first dimension we can skip / keep it by bool
1305
1305
start_dim = int ("time" in self .dims )
1306
1306
if not np .all (
@@ -1348,25 +1348,19 @@ def _concat_time(
1348
1348
newdata , time = newtime , items = ds .items , geometry = ds .geometry , zn = zn
1349
1349
)
1350
1350
1351
- def _check_all_items_match (self , other : "Dataset" ) -> None :
1351
+ def _check_n_items (self , other : "Dataset" ) -> None :
1352
1352
if self .n_items != other .n_items :
1353
1353
raise ValueError (
1354
1354
f"Number of items must match ({ self .n_items } and { other .n_items } )"
1355
1355
)
1356
1356
1357
- for this , that in zip (self .items , other .items ):
1358
- if this .name != that .name :
1359
- raise ValueError (
1360
- f"Item names must match. Item: { this .name } != { that .name } "
1361
- )
1362
- if this .type != that .type :
1363
- raise ValueError (
1364
- f"Item types must match. Item: { this .type } != { that .type } "
1365
- )
1366
- if this .unit != that .unit :
1367
- raise ValueError (
1368
- f"Item units must match. Item: { this .unit } != { that .unit } "
1369
- )
1357
+ def _check_datasets_match (self , other : "Dataset" ) -> None :
1358
+ self ._check_n_items (other )
1359
+
1360
+ if not np .all (self .time == other .time ):
1361
+ raise ValueError ("All timesteps must match" )
1362
+ if self .shape != other .shape :
1363
+ raise ValueError ("shape must match" )
1370
1364
1371
1365
# ============ aggregate =============
1372
1366
@@ -1765,83 +1759,63 @@ def __radd__(self, other: "Dataset" | float) -> "Dataset":
1765
1759
1766
1760
def __add__ (self , other : "Dataset" | float ) -> "Dataset" :
1767
1761
if isinstance (other , self .__class__ ):
1768
- return self ._add_dataset (other )
1762
+ return self ._binary_op (other , operator = "+" )
1769
1763
else :
1770
- # float-like
1771
- return self ._add_value (other ) # type: ignore
1764
+ return self ._scalar_op (other , operator = "+" ) # type: ignore
1772
1765
1773
1766
def __rsub__ (self , other : "Dataset" | float ) -> "Dataset" :
1774
- ds = self .__mul__ (- 1.0 )
1775
- return other + ds
1767
+ ds = self ._scalar_op (- 1.0 , operator = "*" )
1768
+ return ds . _scalar_op ( other , operator = "+" ) # type: ignore
1776
1769
1777
1770
def __sub__ (self , other : "Dataset" | float ) -> "Dataset" :
1778
1771
if isinstance (other , self .__class__ ):
1779
- return self ._add_dataset (other , sign = - 1.0 )
1772
+ return self ._binary_op (other , operator = "-" )
1780
1773
else :
1781
- return self ._add_value (- other ) # type: ignore
1774
+ return self ._scalar_op (- other , operator = "+" ) # type: ignore
1782
1775
1783
1776
def __rmul__ (self , other : "Dataset" | float ) -> "Dataset" :
1784
1777
return self .__mul__ (other )
1785
1778
1786
1779
def __mul__ (self , other : "Dataset" | float ) -> "Dataset" :
1787
1780
if isinstance (other , self .__class__ ):
1788
- raise ValueError ( "Multiplication is not possible for two Datasets " )
1781
+ return self . _binary_op ( other , operator = "* " )
1789
1782
else :
1790
- return self ._multiply_value (other ) # type: ignore
1791
-
1792
- def _add_dataset (self , other : "Dataset" , sign : float = 1.0 ) -> "Dataset" :
1793
- self ._check_datasets_match (other )
1794
- try :
1795
- data = [
1796
- self [x ].to_numpy () + sign * other [y ].to_numpy ()
1797
- for x , y in zip (self .items , other .items )
1798
- ]
1799
- except TypeError :
1800
- raise TypeError ("Could not add data in Dataset" )
1801
- newds = self .copy ()
1802
- for new , old in zip (newds , data ):
1803
- new .values = old
1804
- return newds
1805
-
1806
- def _check_datasets_match (self , other : "Dataset" ) -> None :
1807
- self ._check_all_items_match (other )
1808
-
1809
- if not np .all (self .time == other .time ):
1810
- raise ValueError ("All timesteps must match" )
1811
- if self .shape != other .shape :
1812
- raise ValueError ("shape must match" )
1783
+ return self ._scalar_op (other , operator = "*" ) # type: ignore
1813
1784
1814
- def _add_value (self , value : float ) -> "Dataset" :
1815
- try :
1816
- data = [value + self [x ].to_numpy () for x in self .items ]
1817
- except TypeError :
1818
- raise TypeError (f"{ value } could not be added to Dataset" )
1819
- items = deepcopy (self .items )
1820
- time = self .time .copy ()
1821
- return Dataset (
1822
- data ,
1823
- time = time ,
1824
- items = items ,
1825
- geometry = self .geometry ,
1826
- zn = self ._zn ,
1827
- validate = False ,
1828
- )
1785
+ def __truediv__ (self , other : "Dataset" | float ) -> "Dataset" :
1786
+ if isinstance (other , self .__class__ ):
1787
+ return self ._binary_op (other , operator = "/" )
1788
+ else :
1789
+ return self ._scalar_op (other , operator = "/" ) # type: ignore
1829
1790
1830
- def _multiply_value (self , value : float ) -> "Dataset" :
1831
- try :
1832
- data = [value * self [x ].to_numpy () for x in self .items ]
1833
- except TypeError :
1834
- raise TypeError (f"{ value } could not be multiplied to Dataset" )
1835
- items = deepcopy (self .items )
1836
- time = self .time .copy ()
1837
- return Dataset (
1838
- data ,
1839
- time = time ,
1840
- items = items ,
1841
- geometry = self .geometry ,
1842
- zn = self ._zn ,
1843
- validate = False ,
1844
- )
1791
+ def _binary_op (self , other : "Dataset" , operator : str ) -> "Dataset" :
1792
+ self ._check_datasets_match (other )
1793
+ match operator :
1794
+ case "+" :
1795
+ data = [x + y for x , y in zip (self , other )]
1796
+ case "-" :
1797
+ data = [x - y for x , y in zip (self , other )]
1798
+ case "*" :
1799
+ data = [x * y for x , y in zip (self , other )]
1800
+ case "/" :
1801
+ data = [x / y for x , y in zip (self , other )]
1802
+ case _:
1803
+ raise ValueError (f"Unsupported operator: { operator } " )
1804
+ return Dataset (data )
1805
+
1806
+ def _scalar_op (self , value : float , operator : str ) -> "Dataset" :
1807
+ match operator :
1808
+ case "+" :
1809
+ data = [x + value for x in self ]
1810
+ case "-" :
1811
+ data = [x - value for x in self ]
1812
+ case "*" :
1813
+ data = [x * value for x in self ]
1814
+ case "/" :
1815
+ data = [x / value for x in self ]
1816
+ case _:
1817
+ raise ValueError (f"Unsupported operator: { operator } " )
1818
+ return Dataset (data )
1845
1819
1846
1820
# ===============================================
1847
1821
0 commit comments