Skip to content

Commit 0c61b43

Browse files
committed
compiler: Only relax upper dspace in case of save
1 parent c8cc244 commit 0c61b43

File tree

6 files changed

+66
-39
lines changed

6 files changed

+66
-39
lines changed

devito/ir/clusters/cluster.py

Lines changed: 6 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -357,24 +357,16 @@ def dspace(self):
357357
# Construct the `intervals` of the DataSpace, that is a global,
358358
# Dimension-centric view of the data space
359359
intervals = IntervalGroup.generate('union', *parts.values())
360+
360361
# E.g., `db0 -> time`, but `xi NOT-> x`
361362
intervals = intervals.promote(lambda d: not d.is_Sub)
362363
intervals = intervals.zero(set(intervals.dimensions) - oobs)
363364

364-
# Intersect with intervals from buffered dimensions. Unions of
365-
# buffered dimension intervals may result in shrinking time size
366-
try:
367-
proc = []
368-
for f, v in parts.items():
369-
if f.save:
370-
for i in v:
371-
if i.dim.is_Time:
372-
proc.append(intervals[i.dim].intersection(i))
373-
else:
374-
proc.append(intervals[i.dim])
375-
intervals = IntervalGroup(proc)
376-
except AttributeError:
377-
pass
365+
# Buffered TimeDimensions should not shirnk their upper time offset
366+
for f, v in parts.items():
367+
if f.is_TimeFunction:
368+
if f.save and not f.time_dim.is_Conditional:
369+
intervals = intervals.ceil(v[f.time_dim])
378370

379371
return DataSpace(intervals, parts)
380372

devito/ir/support/space.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -259,6 +259,11 @@ def negate(self):
259259
def zero(self):
260260
return Interval(self.dim, 0, 0, self.stamp)
261261

262+
def ceil(self, o):
263+
if o.is_Null:
264+
return self._rebuild()
265+
return Interval(self.dim, self.lower, o.upper, self.stamp)
266+
262267
def flip(self):
263268
return Interval(self.dim, self.upper, self.lower, self.stamp)
264269

@@ -492,6 +497,11 @@ def zero(self, d=None):
492497

493498
return IntervalGroup(intervals, relations=self.relations, mode=self.mode)
494499

500+
def ceil(self, o=None):
501+
d = self.dimensions if o is None else as_tuple(o.dim)
502+
return IntervalGroup([i.ceil(o) if i.dim in d else i for i in self],
503+
relations=self.relations)
504+
495505
def lift(self, d=None, v=None):
496506
d = set(self.dimensions if d is None else as_tuple(d))
497507
intervals = [i.lift(v) if i.dim._defines & d else i for i in self]

tests/test_buffering.py

Lines changed: 0 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -752,24 +752,3 @@ def test_stencil_issue_1915_v2(subdomain):
752752
op1.apply(time_M=nt-2, u=u1)
753753

754754
assert np.all(u.data == u1.data)
755-
756-
757-
def test_default_timeM():
758-
"""
759-
MFE for issue #2235
760-
"""
761-
grid = Grid(shape=(4, 4))
762-
763-
u = TimeFunction(name='u', grid=grid)
764-
usave = TimeFunction(name='usave', grid=grid, save=5)
765-
766-
eqns = [Eq(u.forward, u + 1),
767-
Eq(usave, u)]
768-
769-
op = Operator(eqns)
770-
771-
assert op.arguments()['time_M'] == 4
772-
773-
op.apply()
774-
775-
assert all(np.all(usave.data[i] == i) for i in range(4))

tests/test_checkpointing.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111

1212
@switchconfig(log_level='WARNING')
13-
def test_segmented_incremment():
13+
def test_segmented_increment():
1414
"""
1515
Test for segmented operator execution of a one-sided first order
1616
function (increment). The corresponding set of stencil offsets in

tests/test_dimension.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -210,6 +210,25 @@ def test_modulo_dims_generation_v2(self):
210210
assert np.all(f.data[3] == 2)
211211
assert np.all(f.data[4] == 4)
212212

213+
def test_default_timeM(self):
214+
"""
215+
MFE for issue #2235
216+
"""
217+
grid = Grid(shape=(4, 4))
218+
219+
u = TimeFunction(name='u', grid=grid)
220+
usave = TimeFunction(name='usave', grid=grid, save=5)
221+
222+
eqns = [Eq(u.forward, u + 1),
223+
Eq(usave, u)]
224+
225+
op = Operator(eqns)
226+
227+
assert op.arguments()['time_M'] == 4
228+
op.apply()
229+
230+
assert all(np.all(usave.data[i] == i) for i in range(4))
231+
213232

214233
class TestSubDimension(object):
215234

@@ -760,7 +779,7 @@ def test_basic(self):
760779

761780
eqns = [Eq(u.forward, u + 1.), Eq(u2.forward, u2 + 1.), Eq(usave, u)]
762781
op = Operator(eqns)
763-
op.apply()
782+
op.apply(time_M=nt-2)
764783
assert np.all(np.allclose(u.data[(nt-1) % 3], nt-1))
765784
assert np.all([np.allclose(u2.data[i], i) for i in range(nt)])
766785
assert np.all([np.allclose(usave.data[i], i*factor)

tests/test_operator.py

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2010,11 +2010,38 @@ def test_indirection(self):
20102010

20112011
op = Operator(eqns)
20122012

2013-
assert op._dspace[time].lower == 1
2013+
assert op._dspace[time].lower == 0
20142014
assert op._dspace[time].upper == 1
20152015
assert op.arguments()['time_M'] == nt - 2
20162016

2017-
op()
2017+
op.apply()
20182018

20192019
assert np.all(f.data[0] == 0.)
20202020
assert np.all(f.data[i] == 3. for i in range(1, 10))
2021+
2022+
def test_indirection_v2(self):
2023+
nt = 10
2024+
grid = Grid(shape=(4, 4))
2025+
time = grid.time_dim
2026+
x, y = grid.dimensions
2027+
2028+
f = TimeFunction(name='f', grid=grid, save=nt)
2029+
g = TimeFunction(name='g', grid=grid)
2030+
2031+
idx = time
2032+
s = Indirection(name='ofs0', mapped=idx)
2033+
2034+
eqns = [
2035+
Eq(s, idx),
2036+
Eq(f[s, x, y], g + 3.)
2037+
]
2038+
2039+
op = Operator(eqns)
2040+
2041+
assert op._dspace[time].lower == 0
2042+
assert op._dspace[time].upper == 0
2043+
assert op.arguments()['time_M'] == nt - 1
2044+
2045+
op.apply()
2046+
2047+
assert np.all(f.data[i] == 3. for i in range(1, 10))

0 commit comments

Comments
 (0)