|
| 1 | +import abc |
1 | 2 | import contextlib
|
2 | 3 | import inspect
|
3 | 4 | import operator
|
4 | 5 | import re
|
5 | 6 | from collections.abc import Callable, Iterable
|
6 |
| -from typing import Optional, Union |
| 7 | +from functools import reduce |
| 8 | +from typing import Optional |
7 | 9 |
|
8 | 10 | from more_itertools import always_iterable
|
9 | 11 |
|
@@ -59,7 +61,125 @@ def _DeprecatedFieldFunc(field, data):
|
59 | 61 | return _DeprecatedFieldFunc
|
60 | 62 |
|
61 | 63 |
|
62 |
| -class DerivedField: |
| 64 | +class DerivedFieldBase(abc.ABC): |
| 65 | + @abc.abstractmethod |
| 66 | + def __call__(self, field, data): |
| 67 | + pass |
| 68 | + |
| 69 | + @abc.abstractmethod |
| 70 | + def __repr__(self) -> str: |
| 71 | + pass |
| 72 | + |
| 73 | + # Multiplication (left and right side) |
| 74 | + def __mul__(self, other) -> "DerivedFieldCombination": |
| 75 | + return DerivedFieldCombination([self, other], op=operator.mul) |
| 76 | + |
| 77 | + def __rmul__(self, other) -> "DerivedFieldCombination": |
| 78 | + return DerivedFieldCombination([self, other], op=operator.mul) |
| 79 | + |
| 80 | + # Division (left side) |
| 81 | + def __truediv__(self, other) -> "DerivedFieldCombination": |
| 82 | + return DerivedFieldCombination([self, other], op=operator.truediv) |
| 83 | + |
| 84 | + def __rtruediv__(self, other) -> "DerivedFieldCombination": |
| 85 | + return DerivedFieldCombination([other, self], op=operator.truediv) |
| 86 | + |
| 87 | + # Addition (left and right side) |
| 88 | + def __add__(self, other) -> "DerivedFieldCombination": |
| 89 | + return DerivedFieldCombination([self, other], op=operator.add) |
| 90 | + |
| 91 | + def __radd__(self, other) -> "DerivedFieldCombination": |
| 92 | + return DerivedFieldCombination([self, other], op=operator.add) |
| 93 | + |
| 94 | + # Subtraction (left and right side) |
| 95 | + def __sub__(self, other) -> "DerivedFieldCombination": |
| 96 | + return DerivedFieldCombination([self, other], op=operator.sub) |
| 97 | + |
| 98 | + def __rsub__(self, other) -> "DerivedFieldCombination": |
| 99 | + return DerivedFieldCombination([other, self], op=operator.sub) |
| 100 | + |
| 101 | + # Unary minus |
| 102 | + def __neg__(self) -> "DerivedFieldCombination": |
| 103 | + return DerivedFieldCombination([self], op=operator.neg) |
| 104 | + |
| 105 | + # Comparison operators |
| 106 | + def __leq__(self, other) -> "DerivedFieldCombination": |
| 107 | + return DerivedFieldCombination([self, other], op=operator.le) |
| 108 | + |
| 109 | + def __lt__(self, other) -> "DerivedFieldCombination": |
| 110 | + return DerivedFieldCombination([self, other], op=operator.lt) |
| 111 | + |
| 112 | + def __geq__(self, other) -> "DerivedFieldCombination": |
| 113 | + return DerivedFieldCombination([self, other], op=operator.ge) |
| 114 | + |
| 115 | + def __gt__(self, other) -> "DerivedFieldCombination": |
| 116 | + return DerivedFieldCombination([self, other], op=operator.gt) |
| 117 | + |
| 118 | + # def __eq__(self, other) -> "DerivedFieldCombination": |
| 119 | + # return DerivedFieldCombination([self, other], op=operator.eq) |
| 120 | + |
| 121 | + def __ne__(self, other) -> "DerivedFieldCombination": |
| 122 | + return DerivedFieldCombination([self, other], op=operator.ne) |
| 123 | + |
| 124 | + |
| 125 | +class DerivedFieldCombination(DerivedFieldBase): |
| 126 | + sampling_type: str | None |
| 127 | + terms: list |
| 128 | + op: Callable | None |
| 129 | + |
| 130 | + def __init__(self, terms: list, op=None): |
| 131 | + if not terms: |
| 132 | + raise ValueError("DerivedFieldCombination requires at least one term.") |
| 133 | + |
| 134 | + # Make sure all terms have the same sampling type |
| 135 | + sampling_types = set() |
| 136 | + for term in terms: |
| 137 | + if isinstance(term, DerivedField): |
| 138 | + sampling_types.add(term.sampling_type) |
| 139 | + |
| 140 | + if len(sampling_types) > 1: |
| 141 | + raise ValueError( |
| 142 | + "All terms in a DerivedFieldCombination must " |
| 143 | + "have the same sampling type." |
| 144 | + ) |
| 145 | + self.sampling_type = sampling_types.pop() if sampling_types else None |
| 146 | + self.terms = terms |
| 147 | + self.op = op |
| 148 | + |
| 149 | + def __call__(self, field, data): |
| 150 | + """ |
| 151 | + Return the value of the field in a given data object. |
| 152 | + """ |
| 153 | + qties = [] |
| 154 | + for term in self.terms: |
| 155 | + if isinstance(term, DerivedField): |
| 156 | + qties.append(data[term.name]) |
| 157 | + elif isinstance(term, DerivedFieldCombination): |
| 158 | + qties.append(term(field, data)) |
| 159 | + else: |
| 160 | + qties.append(term) |
| 161 | + |
| 162 | + if len(qties) == 1: |
| 163 | + return self.op(qties[0]) |
| 164 | + else: |
| 165 | + return reduce(self.op, qties) |
| 166 | + |
| 167 | + def __repr__(self): |
| 168 | + return f"DerivedFieldCombination(terms={self.terms!r}, op={self.op!r})" |
| 169 | + |
| 170 | + def getDependentFields(self): |
| 171 | + fields = [] |
| 172 | + for term in self.terms: |
| 173 | + if isinstance(term, DerivedField): |
| 174 | + fields.append(term.name) |
| 175 | + elif isinstance(term, DerivedFieldCombination): |
| 176 | + fields.extend(term.getDependentFields()) |
| 177 | + else: |
| 178 | + continue |
| 179 | + return fields |
| 180 | + |
| 181 | + |
| 182 | +class DerivedField(DerivedFieldBase): |
63 | 183 | """
|
64 | 184 | This is the base class used to describe a cell-by-cell derived field.
|
65 | 185 |
|
@@ -499,128 +619,6 @@ def __copy__(self):
|
499 | 619 | nodal_flag=self.nodal_flag,
|
500 | 620 | )
|
501 | 621 |
|
502 |
| - def _operator( |
503 |
| - self, other: Union["DerivedField", float], op: Callable |
504 |
| - ) -> "DerivedField": |
505 |
| - my_units = self.ds.get_unit_from_registry(self.units) |
506 |
| - if isinstance(other, DerivedField): |
507 |
| - if self.sampling_type != other.sampling_type: |
508 |
| - raise TypeError( |
509 |
| - f"Cannot {op} fields with different sampling types: " |
510 |
| - f"{self.sampling_type} and {other.sampling_type}" |
511 |
| - ) |
512 |
| - |
513 |
| - def wrapped(field, data): |
514 |
| - return op(self(data), other(data)) |
515 |
| - |
516 |
| - other_name = other.name[1] |
517 |
| - other_units = self.ds.get_unit_from_registry(other.units) |
518 |
| - |
519 |
| - else: |
520 |
| - # Special case when passing (value, "unit") tuple |
521 |
| - if isinstance(other, tuple) and len(other) == 2: |
522 |
| - other = self.ds.quan(*other) |
523 |
| - |
524 |
| - def wrapped(field, data): |
525 |
| - return op(self(data), other) |
526 |
| - |
527 |
| - other_name = str(other) |
528 |
| - other_units = getattr(other, "units", self.ds.get_unit_from_registry("1")) |
529 |
| - |
530 |
| - if op in (operator.add, operator.sub, operator.eq): |
531 |
| - assert my_units.same_dimensions_as(other_units) |
532 |
| - new_units = my_units |
533 |
| - elif op in (operator.mul, operator.truediv): |
534 |
| - new_units = op(my_units, other_units) |
535 |
| - elif op in (operator.le, operator.lt, operator.ge, operator.gt, operator.ne): |
536 |
| - # Comparison yield unitless fields |
537 |
| - new_units = Unit("1") |
538 |
| - else: |
539 |
| - raise TypeError(f"Unsupported operator {op} for DerivedField") |
540 |
| - |
541 |
| - return DerivedField( |
542 |
| - name=(self.name[0], f"{self.name[1]}_{op.__name__}_{other_name}"), |
543 |
| - sampling_type=self.sampling_type, |
544 |
| - function=wrapped, |
545 |
| - units=new_units, |
546 |
| - ds=self.ds, |
547 |
| - ) |
548 |
| - |
549 |
| - # Multiplication (left and right side) |
550 |
| - def __mul__(self, other: Union["DerivedField", float]) -> "DerivedField": |
551 |
| - return self._operator(other, op=operator.mul) |
552 |
| - |
553 |
| - def __rmul__(self, other: Union["DerivedField", float]) -> "DerivedField": |
554 |
| - return self._operator(other, op=operator.mul) |
555 |
| - |
556 |
| - # Division (left side) |
557 |
| - def __truediv__(self, other: Union["DerivedField", float]) -> "DerivedField": |
558 |
| - return self._operator(other, op=operator.truediv) |
559 |
| - |
560 |
| - # Addition (left and right side) |
561 |
| - def __add__(self, other: Union["DerivedField", float]) -> "DerivedField": |
562 |
| - return self._operator(other, op=operator.add) |
563 |
| - |
564 |
| - def __radd__(self, other: Union["DerivedField", float]) -> "DerivedField": |
565 |
| - return self._operator(other, op=operator.add) |
566 |
| - |
567 |
| - # Subtraction (left and right side) |
568 |
| - def __sub__(self, other: Union["DerivedField", float]) -> "DerivedField": |
569 |
| - return self._operator(other, op=operator.sub) |
570 |
| - |
571 |
| - def __rsub__(self, other: Union["DerivedField", float]) -> "DerivedField": |
572 |
| - return self._operator(-other, op=operator.add) |
573 |
| - |
574 |
| - # Unary minus |
575 |
| - def __neg__(self) -> "DerivedField": |
576 |
| - def wrapped(field, data): |
577 |
| - return -self(data) |
578 |
| - |
579 |
| - return DerivedField( |
580 |
| - name=(self.name[0], f"neg_{self.name[1]}"), |
581 |
| - sampling_type=self.sampling_type, |
582 |
| - function=wrapped, |
583 |
| - units=self.units, |
584 |
| - ds=self.ds, |
585 |
| - ) |
586 |
| - |
587 |
| - # Division (right side, a bit more complex) |
588 |
| - def __rtruediv__(self, other: Union["DerivedField", float]) -> "DerivedField": |
589 |
| - units = self.ds.get_unit_from_registry(self.units) |
590 |
| - |
591 |
| - def wrapped(field, data): |
592 |
| - return 1 / self(data) |
593 |
| - |
594 |
| - inverse_self = DerivedField( |
595 |
| - name=(self.name[0], f"inverse_{self.name[1]}"), |
596 |
| - sampling_type=self.sampling_type, |
597 |
| - function=wrapped, |
598 |
| - units=units**-1, |
599 |
| - ds=self.ds, |
600 |
| - ) |
601 |
| - |
602 |
| - return inverse_self * other |
603 |
| - |
604 |
| - # Comparison operators |
605 |
| - def __leq__(self, other: Union["DerivedField", float]) -> "DerivedField": |
606 |
| - return self._operator(other, op=operator.le) |
607 |
| - |
608 |
| - def __lt__(self, other: Union["DerivedField", float]) -> "DerivedField": |
609 |
| - return self._operator(other, op=operator.lt) |
610 |
| - |
611 |
| - def __geq__(self, other: Union["DerivedField", float]) -> "DerivedField": |
612 |
| - return self._operator(other, op=operator.ge) |
613 |
| - |
614 |
| - def __gt__(self, other: Union["DerivedField", float]) -> "DerivedField": |
615 |
| - return self._operator(other, op=operator.gt) |
616 |
| - |
617 |
| - # Somehow, makes yt not work? |
618 |
| - # def __eq__(self, other: Union["DerivedField", float]) -> "DerivedField": |
619 |
| - # return self._operator(other, op=operator.eq) |
620 |
| - |
621 |
| - def __ne__(self, other: Union["DerivedField", float]) -> "DerivedField": |
622 |
| - return self._operator(other, op=operator.ne) |
623 |
| - |
624 | 622 |
|
625 | 623 | class FieldValidator:
|
626 | 624 | """
|
|
0 commit comments