Skip to content

Commit 05ce88e

Browse files
committed
api: fix arg processing for subsampling factor
1 parent 38c673b commit 05ce88e

File tree

2 files changed

+19
-12
lines changed

2 files changed

+19
-12
lines changed

devito/types/dimension.py

+15-9
Original file line numberDiff line numberDiff line change
@@ -915,7 +915,7 @@ def __init_finalize__(self, name, parent=None, factor=None, condition=None,
915915

916916
# Process subsampling factor
917917
fname = f"{name}f"
918-
if factor is None:
918+
if factor is None or factor == 1:
919919
self._factor = None
920920
elif is_number(factor):
921921
self._factor = int(factor)
@@ -966,9 +966,14 @@ def free_symbols(self):
966966
pass
967967
return retval
968968

969-
def _arg_values(self, interval, grid=None, **kwargs):
970-
# Parent dimension define the interval
971-
fact = self.factor
969+
def _arg_values(self, interval, grid=None, args=None, **kwargs):
970+
if self.symbolic_factor is not None:
971+
fname = self.symbolic_factor.name
972+
fact = kwargs.get(fname, args.get(fname, self.factor))
973+
else:
974+
# No factor
975+
return {}
976+
972977
toint = lambda x: math.ceil(x / fact)
973978
vals = {}
974979
try:
@@ -981,6 +986,9 @@ def _arg_values(self, interval, grid=None, **kwargs):
981986
except (KeyError, TypeError):
982987
pass
983988

989+
if self.symbolic_factor is not None:
990+
vals[self.symbolic_factor.name] = fact
991+
984992
return vals
985993

986994
def _arg_defaults(self, _min=None, size=None, alias=None):
@@ -990,11 +998,9 @@ def _arg_defaults(self, _min=None, size=None, alias=None):
990998
# `factor` endpoints are legal, so we return them all. It's then
991999
# up to the caller to decide which one to pick upon reduction
9921000
dim = alias or self
993-
if dim.condition is not None or size is None or dim._factor is None:
994-
return defaults
995-
996-
factor = defaults[dim.symbolic_factor.name] = self.factor
997-
defaults[dim.parent.max_name] = range(0, factor*size - 1)
1001+
if dim.symbolic_factor is not None:
1002+
factor = defaults[dim.symbolic_factor.name] = self.factor
1003+
defaults[dim.parent.max_name] = range(0, factor*size - 1)
9981004

9991005
return defaults
10001006

tests/test_dimension.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from devito.ir import SymbolRegistry
1818
from devito.symbolics import indexify, retrieve_functions, IntDiv, INT
1919
from devito.types import Array, StencilDimension, Symbol
20+
from devito.types.basic import Scalar
2021
from devito.types.dimension import AffineIndexAccessFunction, Thickness
2122

2223

@@ -1012,9 +1013,9 @@ def test_issue_1592(self):
10121013
op = Operator(Eq(v.forward, v.dx))
10131014
op.apply(time=6)
10141015
exprs = FindNodes(Expression).visit(op)
1015-
assert exprs[-1].expr.lhs.indices[0] == IntDiv(time, time_sub.factor) + 1
1016-
assert time_sub.factor.data == 2
1017-
assert time_sub.factor.is_Constant
1016+
assert exprs[-1].expr.lhs.indices[0] == IntDiv(time, time_sub.symbolic_factor) + 1
1017+
assert time_sub.factor == 2
1018+
assert isinstance(time_sub.symbolic_factor, Scalar)
10181019

10191020
def test_issue_1753(self):
10201021
grid = Grid(shape=(3, 3, 3))

0 commit comments

Comments
 (0)