Skip to content

Commit f59a4fe

Browse files
committed
api: improve backward compatibility of new factor
1 parent 33a0362 commit f59a4fe

File tree

3 files changed

+39
-5
lines changed

3 files changed

+39
-5
lines changed

devito/deprecations.py

+7
Original file line numberDiff line numberDiff line change
@@ -27,5 +27,12 @@ def subdomain_warn(self):
2727
DeprecationWarning, stacklevel=2)
2828
return
2929

30+
@cached_property
31+
def constant_factor_warn(self):
32+
warn("Using a `Constant` as a factor when creating a ConditionalDimension"
33+
" is deprecated. Use an integer instead.",
34+
DeprecationWarning, stacklevel=2)
35+
return
36+
3037

3138
deprecations = DevitoDeprecation()

devito/types/dimension.py

+15-5
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,13 @@
77
import numpy as np
88

99
from devito.data import LEFT, RIGHT
10+
from devito.deprecations import deprecations
1011
from devito.exceptions import InvalidArgument
1112
from devito.logger import debug
1213
from devito.tools import Pickable, is_integer, is_number, memoized_meth
1314
from devito.types.args import ArgProvider
1415
from devito.types.basic import Symbol, DataSymbol, Scalar
16+
from devito.types.constant import Constant
1517

1618

1719
__all__ = ['Dimension', 'SpaceDimension', 'TimeDimension', 'DefaultDimension',
@@ -920,8 +922,9 @@ def __init_finalize__(self, name, parent=None, factor=None, condition=None,
920922
elif is_number(factor):
921923
self._factor = int(factor)
922924
elif factor.is_Constant:
923-
self._factor = factor.data
924-
fname = factor.name
925+
deprecations.constant_factor_warn
926+
self._factor = factor
927+
symbolic_factor = factor
925928
else:
926929
raise ValueError("factor must be an integer or integer Constant")
927930

@@ -937,7 +940,14 @@ def __init_finalize__(self, name, parent=None, factor=None, condition=None,
937940

938941
@property
939942
def spacing(self):
940-
return self.factor * self.parent.spacing
943+
return self.factor_data * self.parent.spacing
944+
945+
@property
946+
def factor_data(self):
947+
if isinstance(self.factor, Constant):
948+
return self.factor.data
949+
else:
950+
return self.factor
941951

942952
@property
943953
def factor(self):
@@ -970,7 +980,7 @@ def _arg_values(self, interval, grid=None, args=None, **kwargs):
970980
if self.symbolic_factor is not None:
971981
fname = self.symbolic_factor.name
972982
args = args or {}
973-
fact = kwargs.get(fname, args.get(fname, self.factor))
983+
fact = kwargs.get(fname, args.get(fname, self.factor_data))
974984
else:
975985
# No factor
976986
return {}
@@ -1000,7 +1010,7 @@ def _arg_defaults(self, _min=None, size=None, alias=None):
10001010
# up to the caller to decide which one to pick upon reduction
10011011
dim = alias or self
10021012
if dim.symbolic_factor is not None:
1003-
factor = defaults[dim.symbolic_factor.name] = self.factor
1013+
factor = defaults[dim.symbolic_factor.name] = self.factor_data
10041014
defaults[dim.parent.max_name] = range(0, factor*size - 1)
10051015

10061016
return defaults

tests/test_dimension.py

+17
Original file line numberDiff line numberDiff line change
@@ -1916,6 +1916,23 @@ def test_cond_copy(self):
19161916
op = Operator([Eq(u.forward, u.laplace), Eq(u12, u), Eq(u22, u)])
19171917
assert len([p for p in op.parameters if p.name == 'tsubf']) == 1
19181918

1919+
def test_const_factor(self):
1920+
grid = Grid(shape=(4, 4))
1921+
time = grid.time_dim
1922+
1923+
f1 = 4
1924+
f2 = Constant(name='f2', dtype=np.int32, value=4)
1925+
t1 = ConditionalDimension('t_sub', parent=time, factor=f1)
1926+
t2 = ConditionalDimension('t_sub2', parent=time, factor=f2)
1927+
1928+
assert isinstance(t1.symbolic_factor, Scalar)
1929+
assert t1.factor == f1
1930+
1931+
assert t2.symbolic_factor.is_Constant
1932+
assert t2.factor == f2
1933+
assert t2.factor.data == f1
1934+
assert t2.spacing == t1.spacing
1935+
19191936

19201937
class TestCustomDimension:
19211938

0 commit comments

Comments
 (0)