Skip to content

Commit b7b59d9

Browse files
committed
Infer units for derived fields when quickly defining them
1 parent 405a156 commit b7b59d9

File tree

2 files changed

+37
-1
lines changed

2 files changed

+37
-1
lines changed

yt/fields/derived_field.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -189,6 +189,42 @@ def getDependentFields(self):
189189
def name(self):
190190
return f"{self!r}"
191191

192+
@property
193+
def units(self):
194+
def helper(term):
195+
# Get a Unit object, using the unit registry if possible
196+
if hasattr(term, "ds"):
197+
registry = term.ds.unit_registry
198+
else:
199+
registry = None
200+
201+
unit = getattr(term, "units", "1")
202+
203+
return Unit(unit, registry=registry)
204+
205+
match self.op:
206+
case (
207+
operator.eq
208+
| operator.ne
209+
| operator.le
210+
| operator.lt
211+
| operator.ge
212+
| operator.gt
213+
):
214+
# Boolean operator return a dimensionless quantity
215+
return Unit("1")
216+
case operator.add | operator.sub:
217+
# 'Unit's cannot be added/subtracted direclty but they have to
218+
# have the same dimensions
219+
units = [helper(_) for _ in self.terms]
220+
if units and not all(units[0].same_dimensions_as(u) for u in units):
221+
raise ValueError("Incompatible units")
222+
return units[0]
223+
case _:
224+
# Other operators are mul, truediv and neg
225+
units = [helper(_) for _ in self.terms]
226+
return reduce(self.op, units)
227+
192228

193229
class DerivedField(DerivedFieldBase):
194230
"""

yt/fields/field_type_container.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ def __getattr__(self, attr):
9494

9595
def __setattr__(self, attr, value):
9696
if isinstance(value, DerivedFieldCombination):
97-
self.ds.add_field((self.field_type, attr), value)
97+
self.ds.add_field((self.field_type, attr), value, units=value.units)
9898
else:
9999
super().__setattr__(attr, value)
100100

0 commit comments

Comments
 (0)