Skip to content

Commit 6828f53

Browse files
committed
api: cache subsampling factors to avoid duplicates
1 parent fefcfba commit 6828f53

File tree

2 files changed

+28
-1
lines changed

2 files changed

+28
-1
lines changed

devito/types/dimension.py

+9-1
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from devito.tools import Pickable, is_integer, memoized_meth
1313
from devito.types.args import ArgProvider
1414
from devito.types.basic import Symbol, DataSymbol, Scalar
15+
from devito.types.caching import Cached
1516
from devito.types.constant import Constant
1617

1718

@@ -822,6 +823,12 @@ def bound_symbols(self):
822823
return self.parent.bound_symbols
823824

824825

826+
class SubsamplingFactor(Constant, Cached):
827+
828+
__hash__ = sympy.Symbol.__hash__
829+
_cache_key = Symbol._cache_key
830+
831+
825832
class ConditionalDimension(DerivedDimension):
826833

827834
"""
@@ -913,7 +920,8 @@ def __init_finalize__(self, name, parent=None, factor=None, condition=None,
913920
if factor is None or factor == 1:
914921
self._factor = None
915922
elif is_integer(factor):
916-
self._factor = Constant(name="%sf" % name, value=factor, dtype=np.int32)
923+
self._factor = SubsamplingFactor(name=f"{name}f", value=factor,
924+
dtype=np.int32)
917925
elif factor.is_Constant and is_integer(factor.data):
918926
self._factor = factor
919927
else:

tests/test_dimension.py

+19
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from itertools import product
2+
from copy import deepcopy
23

34
import numpy as np
45
from sympy import And, Or
@@ -1896,6 +1897,24 @@ def test_cond_notime(self):
18961897
op(time_m=1, time_M=nt-1, dt=1)
18971898
assert norm(g, order=1) == norm(sum(usaved, dims=time_under), order=1)
18981899

1900+
def test_cond_copy(self):
1901+
grid = Grid((11, 11, 11))
1902+
time = grid.time_dim
1903+
1904+
cd = ConditionalDimension(name='tsub', parent=time, factor=5)
1905+
u = TimeFunction(name='u', grid=grid, space_order=4, time_order=2, save=Buffer(2))
1906+
u1 = TimeFunction(name='u1', grid=grid, space_order=0,
1907+
time_order=0, save=5, time_dim=cd)
1908+
u2 = TimeFunction(name='u2', grid=grid, space_order=0,
1909+
time_order=0, save=5, time_dim=cd)
1910+
1911+
# Mimic what happens when an operator is copied
1912+
u12 = deepcopy(u1)
1913+
u22 = deepcopy(u2)
1914+
1915+
op = Operator([Eq(u.forward, u.laplace), Eq(u12, u), Eq(u22, u)])
1916+
assert len([p for p in op.parameters if p.name == 'tsubf']) == 1
1917+
18991918

19001919
class TestCustomDimension:
19011920

0 commit comments

Comments
 (0)