Skip to content

Commit 38c673b

Browse files
committed
api: revamp subsampling factor
1 parent 3e730e0 commit 38c673b

File tree

5 files changed

+34
-30
lines changed

5 files changed

+34
-30
lines changed

devito/ir/equations/equation.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -220,7 +220,7 @@ def __new__(cls, *args, **kwargs):
220220
index = d.index
221221
if d.condition is not None and d in expr.free_symbols:
222222
index = index - relational_min(d.condition, d.parent)
223-
expr = uxreplace(expr, {d: IntDiv(index, d.factor)})
223+
expr = uxreplace(expr, {d: IntDiv(index, d.symbolic_factor)})
224224

225225
conditionals = frozendict(conditionals)
226226

devito/ir/support/guards.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ class GuardFactor(Guard, CondEq, Pickable):
4747
def __new__(cls, d, **kwargs):
4848
assert d.is_Conditional
4949

50-
obj = super().__new__(cls, d.parent % d.factor, 0)
50+
obj = super().__new__(cls, d.parent % d.symbolic_factor, 0)
5151
obj.d = d
5252

5353
return obj
@@ -129,7 +129,7 @@ def __new__(cls, d, direction, **kwargs):
129129
p1 = d.root.symbolic_max
130130

131131
if d.is_Conditional:
132-
v = d.factor
132+
v = d.symbolic_factor
133133
# Round `p0 + 1` up to the nearest multiple of `v`
134134
p0 = Mul((((p0 + 1) + v - 1) / v), v, evaluate=False)
135135
else:
@@ -140,7 +140,7 @@ def __new__(cls, d, direction, **kwargs):
140140
p1 = d.root
141141

142142
if d.is_Conditional:
143-
v = d.factor
143+
v = d.symbolic_factor
144144
# Round `p1 - 1` down to the nearest sub-multiple of `v`
145145
# NOTE: we use ABS to make sure we handle negative values properly.
146146
# Once `p1 - 1` is negative (e.g. `iteration=time - 1` and `time=0`),

devito/symbolics/extended_sympy.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ class IntDiv(sympy.Expr):
8484
def __new__(cls, lhs, rhs, params=None):
8585
if rhs == 0:
8686
raise ValueError("Cannot divide by 0")
87-
elif rhs == 1:
87+
elif rhs == 1 or rhs is None:
8888
return lhs
8989

9090
if not is_integer(rhs):

devito/types/dimension.py

+28-24
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,9 @@
99
from devito.data import LEFT, RIGHT
1010
from devito.exceptions import InvalidArgument
1111
from devito.logger import debug
12-
from devito.tools import Pickable, is_integer, memoized_meth
12+
from devito.tools import Pickable, is_integer, is_number, memoized_meth
1313
from devito.types.args import ArgProvider
1414
from devito.types.basic import Symbol, DataSymbol, Scalar
15-
from devito.types.caching import Cached
16-
from devito.types.constant import Constant
1715

1816

1917
__all__ = ['Dimension', 'SpaceDimension', 'TimeDimension', 'DefaultDimension',
@@ -823,10 +821,8 @@ def bound_symbols(self):
823821
return self.parent.bound_symbols
824822

825823

826-
class SubsamplingFactor(Constant, Cached):
827-
828-
__hash__ = sympy.Symbol.__hash__
829-
_cache_key = Symbol._cache_key
824+
class SubsamplingFactor(Scalar):
825+
pass
830826

831827

832828
class ConditionalDimension(DerivedDimension):
@@ -905,40 +901,52 @@ class ConditionalDimension(DerivedDimension):
905901
is_NonlinearDerived = True
906902
is_Conditional = True
907903

908-
__rkwargs__ = DerivedDimension.__rkwargs__ + ('factor', 'condition', 'indirect')
904+
__rkwargs__ = DerivedDimension.__rkwargs__ + \
905+
('symbolic_factor', 'factor', 'condition', 'indirect')
909906

910907
def __init_finalize__(self, name, parent=None, factor=None, condition=None,
911-
indirect=False, **kwargs):
908+
indirect=False, symbolic_factor=None, **kwargs):
912909
# `parent=None` degenerates to a ConditionalDimension outside of
913910
# any iteration space
914911
if parent is None:
915912
parent = BOTTOM
916913

917914
super().__init_finalize__(name, parent)
918915

919-
# Always make the factor symbolic to allow overrides with different factor.
920-
if factor is None or factor == 1:
916+
# Process subsampling factor
917+
fname = f"{name}f"
918+
if factor is None:
921919
self._factor = None
922-
elif is_integer(factor):
923-
self._factor = SubsamplingFactor(name=f"{name}f", value=factor,
924-
dtype=np.int32)
925-
elif factor.is_Constant and is_integer(factor.data):
926-
self._factor = factor
920+
elif is_number(factor):
921+
self._factor = int(factor)
922+
elif factor.is_Constant:
923+
self._factor = factor.data
924+
fname = factor.name
927925
else:
928926
raise ValueError("factor must be an integer or integer Constant")
929927

928+
if self._factor is not None:
929+
# Always make the factor symbolic to allow overrides with different factor.
930+
self._symbolic_factor = symbolic_factor or \
931+
SubsamplingFactor(name=fname, dtype=np.int32, is_const=True)
932+
else:
933+
self._symbolic_factor = None
934+
930935
self._condition = condition
931936
self._indirect = indirect
932937

933938
@property
934939
def spacing(self):
935-
s = self._factor.data if self._factor is not None else 1
936-
return s * self.parent.spacing
940+
return self.factor * self.parent.spacing
937941

938942
@property
939943
def factor(self):
940944
return self._factor if self._factor is not None else 1
941945

946+
@property
947+
def symbolic_factor(self):
948+
return self._symbolic_factor
949+
942950
@property
943951
def condition(self):
944952
return self._condition
@@ -960,7 +968,7 @@ def free_symbols(self):
960968

961969
def _arg_values(self, interval, grid=None, **kwargs):
962970
# Parent dimension define the interval
963-
fact = self._factor.data if self._factor is not None else 1
971+
fact = self.factor
964972
toint = lambda x: math.ceil(x / fact)
965973
vals = {}
966974
try:
@@ -984,12 +992,8 @@ def _arg_defaults(self, _min=None, size=None, alias=None):
984992
dim = alias or self
985993
if dim.condition is not None or size is None or dim._factor is None:
986994
return defaults
987-
try:
988-
# Is it a symbolic factor?
989-
factor = defaults[dim._factor.name] = self._factor.data
990-
except AttributeError:
991-
factor = dim._factor
992995

996+
factor = defaults[dim.symbolic_factor.name] = self.factor
993997
defaults[dim.parent.max_name] = range(0, factor*size - 1)
994998

995999
return defaults

devito/types/grid.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -305,7 +305,7 @@ def spacing_map(self):
305305
# Special case subsampling: `Grid.dimensions` -> (xb, yb, zb)`
306306
# where `xb, yb, zb` are ConditionalDimensions whose parents
307307
# are SpaceDimensions
308-
mapper[d.root.spacing] = s/self.dtype(d.factor.data)
308+
mapper[d.root.spacing] = s/self.dtype(d.factor)
309309
elif d.is_Space:
310310
# Typical case: `Grid.dimensions` -> (x, y, z)` where `x, y, z` are
311311
# the SpaceDimensions

0 commit comments

Comments
 (0)