Skip to content

Commit 77aafba

Browse files
committed
Allow boolean equal to define units
1 parent 80bec7a commit 77aafba

File tree

2 files changed

+38
-4
lines changed

2 files changed

+38
-4
lines changed

yt/data_objects/tests/test_add_field.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -243,6 +243,13 @@ def test_add_field_quick_syntax2():
243243
ds.r["gas", "temperature"] * ds.units.kb / ds.r["gas", "volume"],
244244
)
245245

246+
# Returning boolean
247+
dx_min = ds.r["index", "dx"].min()
248+
ds.fields.gas.smallest_cells = ds.fields.gas.dx == dx_min
249+
np.testing.assert_allclose(
250+
ds.r["gas", "smallest_cells"].value, (dx_min == ds.r["gas", "dx"])
251+
)
252+
246253

247254
@pytest.fixture()
248255
def capturable_logger(caplog):

yt/fields/derived_field.py

Lines changed: 31 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -115,19 +115,23 @@ def __geq__(self, other) -> "DerivedFieldCombination":
115115
def __gt__(self, other) -> "DerivedFieldCombination":
116116
return DerivedFieldCombination([self, other], op=operator.gt)
117117

118-
# def __eq__(self, other) -> "DerivedFieldCombination":
119-
# return DerivedFieldCombination([self, other], op=operator.eq)
118+
def __eq__(self, other) -> "DerivedFieldCombination":
119+
return DerivedFieldCombination([self, other], op=operator.eq)
120120

121121
def __ne__(self, other) -> "DerivedFieldCombination": # type: ignore[override]
122122
return DerivedFieldCombination([self, other], op=operator.ne)
123123

124+
@abc.abstractmethod
125+
def __hash__(self) -> int:
126+
pass
127+
124128

125129
class DerivedFieldCombination(DerivedFieldBase):
126130
sampling_type: str | None
127131
terms: list
128-
op: Callable | None
132+
op: Callable
129133

130-
def __init__(self, terms: list, op=None):
134+
def __init__(self, terms: list, op: Callable):
131135
if not terms:
132136
raise ValueError("DerivedFieldCombination requires at least one term.")
133137

@@ -146,6 +150,9 @@ def __init__(self, terms: list, op=None):
146150
self.terms = terms
147151
self.op = op
148152

153+
def __hash__(self):
154+
return hash((self.sampling_type, tuple(self.terms), self.op))
155+
149156
def __call__(self, field, data):
150157
"""
151158
Return the value of the field in a given data object.
@@ -307,6 +314,26 @@ def __init__(
307314
self._shared_aliases_list = alias._shared_aliases_list
308315
self._shared_aliases_list.append(self)
309316

317+
def __hash__(self):
318+
return hash(
319+
(
320+
self.name,
321+
self.sampling_type,
322+
self._function,
323+
self.units,
324+
self.take_log,
325+
tuple(self.validators),
326+
self.vector_field,
327+
self.display_field,
328+
self.not_in_all,
329+
self.display_name,
330+
self.output_units,
331+
self.dimensions,
332+
self.ds,
333+
tuple(self.nodal_flag),
334+
)
335+
)
336+
310337
def _copy_def(self):
311338
dd = {}
312339
dd["name"] = self.name

0 commit comments

Comments
 (0)