Skip to content

Commit 5b07296

Browse files
committed
Use cleaner implementation
1 parent 8afa1ef commit 5b07296

File tree

4 files changed

+142
-139
lines changed

4 files changed

+142
-139
lines changed

yt/data_objects/selection_objects/cut_region.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
YTSelectionContainer3D,
99
)
1010
from yt.data_objects.static_output import Dataset
11-
from yt.fields.derived_field import DerivedField
11+
from yt.fields.derived_field import DerivedFieldCombination
1212
from yt.funcs import iter_fields, validate_object
1313
from yt.geometry.selection_routines import points_in_cells
1414
from yt.utilities.exceptions import YTIllDefinedCutRegion
@@ -57,7 +57,8 @@ def __init__(
5757
validate_object(data_source, YTSelectionContainer)
5858
conditionals = list(always_iterable(conditionals))
5959
for condition in conditionals:
60-
validate_object(condition, (str, DerivedField))
60+
validate_object(condition, (str, DerivedFieldCombination))
61+
6162
validate_object(ds, Dataset)
6263
validate_object(field_parameters, dict)
6364
validate_object(base_object, YTSelectionContainer)
@@ -83,8 +84,8 @@ def __init__(
8384
def _check_filter_fields(self):
8485
fields = []
8586
for cond in self.conditionals:
86-
if isinstance(cond, DerivedField):
87-
fields.append(cond.name)
87+
if isinstance(cond, DerivedFieldCombination):
88+
fields.extend(cond.getDependentFields())
8889
continue
8990

9091
for field in re.findall(r"\[([A-Za-z0-9_,.'\"\(\)]+)\]", cond):
@@ -130,8 +131,8 @@ def blocks(self):
130131
m = m.copy()
131132
with obj._field_parameter_state(self.field_parameters):
132133
for cond in self.conditionals:
133-
if isinstance(cond, DerivedField):
134-
ss = cond(obj)
134+
if isinstance(cond, DerivedFieldCombination):
135+
ss = cond(None, obj)
135136
else:
136137
ss = eval(cond)
137138
m &= ss
@@ -152,8 +153,8 @@ def _cond_ind(self):
152153
locals["obj"] = obj
153154
with obj._field_parameter_state(self.field_parameters):
154155
for cond in self.conditionals:
155-
if isinstance(cond, DerivedField):
156-
res = cond(obj)
156+
if isinstance(cond, DerivedFieldCombination):
157+
res = cond(None, obj)
157158
else:
158159
res = eval(cond, locals)
159160
if ind is None:

yt/data_objects/static_output.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,11 @@
3535
from yt.data_objects.particle_filters import ParticleFilter, filter_registry
3636
from yt.data_objects.region_expression import RegionExpression
3737
from yt.data_objects.unions import ParticleUnion
38-
from yt.fields.derived_field import DerivedField, ValidateSpatial
38+
from yt.fields.derived_field import (
39+
DerivedField,
40+
DerivedFieldCombination,
41+
ValidateSpatial,
42+
)
3943
from yt.fields.field_type_container import FieldTypeContainer
4044
from yt.fields.fluid_fields import setup_gradient_fields
4145
from yt.funcs import iter_fields, mylog, set_intersection, setdefaultattr
@@ -883,6 +887,9 @@ def add_particle_filter(self, filter):
883887
used = self._setup_filtered_type(f)
884888
if used:
885889
filter = f
890+
elif isinstance(filter, DerivedFieldCombination):
891+
filter_registry[filter.name] = filter
892+
used = self._setup_filtered_type(filter)
886893
else:
887894
used = self._setup_filtered_type(filter)
888895
if not used:
@@ -1757,15 +1764,12 @@ def add_field(
17571764
"""
17581765
from yt.fields.field_functions import validate_field_function
17591766

1760-
if not isinstance(function, DerivedField):
1767+
if not isinstance(function, DerivedFieldCombination):
17611768
if sampling_type is None:
17621769
raise ValueError("You must specify a sampling_type for the field.")
17631770
validate_field_function(function)
17641771
else:
17651772
sampling_type = function.sampling_type
1766-
kwargs.setdefault("units", function.units)
1767-
1768-
function = function._function
17691773

17701774
self.index
17711775
if force_override and name in self.index.field_list:

yt/fields/derived_field.py

Lines changed: 122 additions & 124 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
1+
import abc
12
import contextlib
23
import inspect
34
import operator
45
import re
56
from collections.abc import Callable, Iterable
6-
from typing import Optional, Union
7+
from functools import reduce
8+
from typing import Optional
79

810
from more_itertools import always_iterable
911

@@ -59,7 +61,125 @@ def _DeprecatedFieldFunc(field, data):
5961
return _DeprecatedFieldFunc
6062

6163

62-
class DerivedField:
64+
class DerivedFieldBase(abc.ABC):
65+
@abc.abstractmethod
66+
def __call__(self, field, data):
67+
pass
68+
69+
@abc.abstractmethod
70+
def __repr__(self) -> str:
71+
pass
72+
73+
# Multiplication (left and right side)
74+
def __mul__(self, other) -> "DerivedFieldCombination":
75+
return DerivedFieldCombination([self, other], op=operator.mul)
76+
77+
def __rmul__(self, other) -> "DerivedFieldCombination":
78+
return DerivedFieldCombination([self, other], op=operator.mul)
79+
80+
# Division (left side)
81+
def __truediv__(self, other) -> "DerivedFieldCombination":
82+
return DerivedFieldCombination([self, other], op=operator.truediv)
83+
84+
def __rtruediv__(self, other) -> "DerivedFieldCombination":
85+
return DerivedFieldCombination([other, self], op=operator.truediv)
86+
87+
# Addition (left and right side)
88+
def __add__(self, other) -> "DerivedFieldCombination":
89+
return DerivedFieldCombination([self, other], op=operator.add)
90+
91+
def __radd__(self, other) -> "DerivedFieldCombination":
92+
return DerivedFieldCombination([self, other], op=operator.add)
93+
94+
# Subtraction (left and right side)
95+
def __sub__(self, other) -> "DerivedFieldCombination":
96+
return DerivedFieldCombination([self, other], op=operator.sub)
97+
98+
def __rsub__(self, other) -> "DerivedFieldCombination":
99+
return DerivedFieldCombination([other, self], op=operator.sub)
100+
101+
# Unary minus
102+
def __neg__(self) -> "DerivedFieldCombination":
103+
return DerivedFieldCombination([self], op=operator.neg)
104+
105+
# Comparison operators
106+
def __leq__(self, other) -> "DerivedFieldCombination":
107+
return DerivedFieldCombination([self, other], op=operator.le)
108+
109+
def __lt__(self, other) -> "DerivedFieldCombination":
110+
return DerivedFieldCombination([self, other], op=operator.lt)
111+
112+
def __geq__(self, other) -> "DerivedFieldCombination":
113+
return DerivedFieldCombination([self, other], op=operator.ge)
114+
115+
def __gt__(self, other) -> "DerivedFieldCombination":
116+
return DerivedFieldCombination([self, other], op=operator.gt)
117+
118+
# def __eq__(self, other) -> "DerivedFieldCombination":
119+
# return DerivedFieldCombination([self, other], op=operator.eq)
120+
121+
def __ne__(self, other) -> "DerivedFieldCombination":
122+
return DerivedFieldCombination([self, other], op=operator.ne)
123+
124+
125+
class DerivedFieldCombination(DerivedFieldBase):
126+
sampling_type: str | None
127+
terms: list
128+
op: Callable | None
129+
130+
def __init__(self, terms: list, op=None):
131+
if not terms:
132+
raise ValueError("DerivedFieldCombination requires at least one term.")
133+
134+
# Make sure all terms have the same sampling type
135+
sampling_types = set()
136+
for term in terms:
137+
if isinstance(term, DerivedField):
138+
sampling_types.add(term.sampling_type)
139+
140+
if len(sampling_types) > 1:
141+
raise ValueError(
142+
"All terms in a DerivedFieldCombination must "
143+
"have the same sampling type."
144+
)
145+
self.sampling_type = sampling_types.pop() if sampling_types else None
146+
self.terms = terms
147+
self.op = op
148+
149+
def __call__(self, field, data):
150+
"""
151+
Return the value of the field in a given data object.
152+
"""
153+
qties = []
154+
for term in self.terms:
155+
if isinstance(term, DerivedField):
156+
qties.append(data[term.name])
157+
elif isinstance(term, DerivedFieldCombination):
158+
qties.append(term(field, data))
159+
else:
160+
qties.append(term)
161+
162+
if len(qties) == 1:
163+
return self.op(qties[0])
164+
else:
165+
return reduce(self.op, qties)
166+
167+
def __repr__(self):
168+
return f"DerivedFieldCombination(terms={self.terms!r}, op={self.op!r})"
169+
170+
def getDependentFields(self):
171+
fields = []
172+
for term in self.terms:
173+
if isinstance(term, DerivedField):
174+
fields.append(term.name)
175+
elif isinstance(term, DerivedFieldCombination):
176+
fields.extend(term.getDependentFields())
177+
else:
178+
continue
179+
return fields
180+
181+
182+
class DerivedField(DerivedFieldBase):
63183
"""
64184
This is the base class used to describe a cell-by-cell derived field.
65185
@@ -499,128 +619,6 @@ def __copy__(self):
499619
nodal_flag=self.nodal_flag,
500620
)
501621

502-
def _operator(
503-
self, other: Union["DerivedField", float], op: Callable
504-
) -> "DerivedField":
505-
my_units = self.ds.get_unit_from_registry(self.units)
506-
if isinstance(other, DerivedField):
507-
if self.sampling_type != other.sampling_type:
508-
raise TypeError(
509-
f"Cannot {op} fields with different sampling types: "
510-
f"{self.sampling_type} and {other.sampling_type}"
511-
)
512-
513-
def wrapped(field, data):
514-
return op(self(data), other(data))
515-
516-
other_name = other.name[1]
517-
other_units = self.ds.get_unit_from_registry(other.units)
518-
519-
else:
520-
# Special case when passing (value, "unit") tuple
521-
if isinstance(other, tuple) and len(other) == 2:
522-
other = self.ds.quan(*other)
523-
524-
def wrapped(field, data):
525-
return op(self(data), other)
526-
527-
other_name = str(other)
528-
other_units = getattr(other, "units", self.ds.get_unit_from_registry("1"))
529-
530-
if op in (operator.add, operator.sub, operator.eq):
531-
assert my_units.same_dimensions_as(other_units)
532-
new_units = my_units
533-
elif op in (operator.mul, operator.truediv):
534-
new_units = op(my_units, other_units)
535-
elif op in (operator.le, operator.lt, operator.ge, operator.gt, operator.ne):
536-
# Comparison yield unitless fields
537-
new_units = Unit("1")
538-
else:
539-
raise TypeError(f"Unsupported operator {op} for DerivedField")
540-
541-
return DerivedField(
542-
name=(self.name[0], f"{self.name[1]}_{op.__name__}_{other_name}"),
543-
sampling_type=self.sampling_type,
544-
function=wrapped,
545-
units=new_units,
546-
ds=self.ds,
547-
)
548-
549-
# Multiplication (left and right side)
550-
def __mul__(self, other: Union["DerivedField", float]) -> "DerivedField":
551-
return self._operator(other, op=operator.mul)
552-
553-
def __rmul__(self, other: Union["DerivedField", float]) -> "DerivedField":
554-
return self._operator(other, op=operator.mul)
555-
556-
# Division (left side)
557-
def __truediv__(self, other: Union["DerivedField", float]) -> "DerivedField":
558-
return self._operator(other, op=operator.truediv)
559-
560-
# Addition (left and right side)
561-
def __add__(self, other: Union["DerivedField", float]) -> "DerivedField":
562-
return self._operator(other, op=operator.add)
563-
564-
def __radd__(self, other: Union["DerivedField", float]) -> "DerivedField":
565-
return self._operator(other, op=operator.add)
566-
567-
# Subtraction (left and right side)
568-
def __sub__(self, other: Union["DerivedField", float]) -> "DerivedField":
569-
return self._operator(other, op=operator.sub)
570-
571-
def __rsub__(self, other: Union["DerivedField", float]) -> "DerivedField":
572-
return self._operator(-other, op=operator.add)
573-
574-
# Unary minus
575-
def __neg__(self) -> "DerivedField":
576-
def wrapped(field, data):
577-
return -self(data)
578-
579-
return DerivedField(
580-
name=(self.name[0], f"neg_{self.name[1]}"),
581-
sampling_type=self.sampling_type,
582-
function=wrapped,
583-
units=self.units,
584-
ds=self.ds,
585-
)
586-
587-
# Division (right side, a bit more complex)
588-
def __rtruediv__(self, other: Union["DerivedField", float]) -> "DerivedField":
589-
units = self.ds.get_unit_from_registry(self.units)
590-
591-
def wrapped(field, data):
592-
return 1 / self(data)
593-
594-
inverse_self = DerivedField(
595-
name=(self.name[0], f"inverse_{self.name[1]}"),
596-
sampling_type=self.sampling_type,
597-
function=wrapped,
598-
units=units**-1,
599-
ds=self.ds,
600-
)
601-
602-
return inverse_self * other
603-
604-
# Comparison operators
605-
def __leq__(self, other: Union["DerivedField", float]) -> "DerivedField":
606-
return self._operator(other, op=operator.le)
607-
608-
def __lt__(self, other: Union["DerivedField", float]) -> "DerivedField":
609-
return self._operator(other, op=operator.lt)
610-
611-
def __geq__(self, other: Union["DerivedField", float]) -> "DerivedField":
612-
return self._operator(other, op=operator.ge)
613-
614-
def __gt__(self, other: Union["DerivedField", float]) -> "DerivedField":
615-
return self._operator(other, op=operator.gt)
616-
617-
# Somehow, makes yt not work?
618-
# def __eq__(self, other: Union["DerivedField", float]) -> "DerivedField":
619-
# return self._operator(other, op=operator.eq)
620-
621-
def __ne__(self, other: Union["DerivedField", float]) -> "DerivedField":
622-
return self._operator(other, op=operator.ne)
623-
624622

625623
class FieldValidator:
626624
"""

yt/fields/field_type_container.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from functools import cached_property
99

1010
from yt._maintenance.ipython_compat import IPYWIDGETS_ENABLED
11-
from yt.fields.derived_field import DerivedField
11+
from yt.fields.derived_field import DerivedField, DerivedFieldCombination
1212

1313

1414
def _fill_values(values):
@@ -93,7 +93,7 @@ def __getattr__(self, attr):
9393
return ds.field_info[ft, attr]
9494

9595
def __setattr__(self, attr, value):
96-
if isinstance(value, DerivedField):
96+
if isinstance(value, DerivedFieldCombination):
9797
self.ds.add_field((self.field_type, attr), value)
9898
else:
9999
super().__setattr__(attr, value)

0 commit comments

Comments
 (0)