Skip to content

Commit 29fc0c9

Browse files
committed
api: update throughout for staggered setup
1 parent 5dd4f97 commit 29fc0c9

File tree

6 files changed

+68
-42
lines changed

6 files changed

+68
-42
lines changed

devito/finite_differences/derivative.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -405,7 +405,7 @@ def _eval_at(self, func):
405405
return self
406406
# For basic equation of the form f = Derivative(g, ...) we can just
407407
# compare staggering
408-
if self.expr.staggered == func.staggered:
408+
if self.expr.staggered == func.staggered and self.expr.is_Function:
409409
return self
410410

411411
x0 = func.indices_ref.getters

devito/finite_differences/rsfd.py

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
from functools import wraps
22

3-
from devito.types import NODE
43
from devito.types.dimension import StencilDimension
54
from .differentiable import Weights, DiffDerivative
65
from .tools import generate_indices, fd_weights_registry
@@ -101,12 +100,7 @@ def check_staggering(func):
101100
def wrapper(expr, dim, x0=None, expand=True):
102101
grid = expr.grid
103102
x0 = {k: v for k, v in x0.items() if k.is_Space}
104-
if expr.staggered is NODE or expr.staggered is None:
105-
cond = x0 == {} or x0 == all_staggered(grid) or x0 == grid_node(grid)
106-
elif expr.staggered == grid.dimensions:
107-
cond = x0 == {} or x0 == all_staggered(grid) or x0 == grid_node(grid)
108-
else:
109-
cond = False
103+
cond = x0 == {} or x0 == all_staggered(grid) or x0 == grid_node(grid)
110104
if cond:
111105
return func(expr, dim, x0=x0, expand=expand)
112106
else:
@@ -117,7 +111,8 @@ def wrapper(expr, dim, x0=None, expand=True):
117111
@check_staggering
118112
def d45(expr, dim, x0=None, expand=True):
119113
"""
120-
RSFD approximation of the derivative of `expr` along `dim` at point `x0`.
114+
Rotated staggered grid finite-differences (RSFD) discretization
115+
of the derivative of `expr` along `dim` at point `x0`.
121116
122117
Parameters
123118
----------
@@ -132,7 +127,8 @@ def d45(expr, dim, x0=None, expand=True):
132127
"""
133128
# Make sure the grid supports RSFD
134129
if expr.grid.dim not in [2, 3]:
135-
raise ValueError('RSFD only supported in 2D and 3D')
130+
raise ValueError('Rotated staggered grid finite-differences (RSFD)'
131+
' only supported in 2D and 3D')
136132

137133
# Diagonals weights
138134
w = dir_weights[(dim.name, expr.grid.dim)]

devito/types/basic.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -699,7 +699,7 @@ def __new__(cls, *args, **kwargs):
699699
args, kwargs = cls.__args_setup__(*args, **kwargs)
700700

701701
# Extract the `indices`, as perhaps they're explicitly provided
702-
dimensions, indices, staggered = cls.__indices_setup__(*args, **kwargs)
702+
dimensions, indices = cls.__indices_setup__(*args, **kwargs)
703703

704704
# If it's an alias or simply has a different name, ignore `function`.
705705
# These cases imply the construction of a new AbstractFunction off
@@ -743,7 +743,6 @@ def __new__(cls, *args, **kwargs):
743743
# when executing __init_finalize__
744744
newobj._name = name
745745
newobj._dimensions = dimensions
746-
newobj._staggered = staggered
747746
newobj._shape = cls.__shape_setup__(**kwargs)
748747
newobj._dtype = cls.__dtype_setup__(**kwargs)
749748

@@ -926,11 +925,6 @@ def indices(self):
926925
"""The indices of the object."""
927926
return DimensionTuple(*self.args, getters=self.dimensions)
928927

929-
@property
930-
def staggered(self):
931-
"""The staggered indices of the object."""
932-
return DimensionTuple(*self._staggered, getters=self.dimensions)
933-
934928
@property
935929
def indices_ref(self):
936930
"""The reference indices of the object (indices at first creation)."""
@@ -1496,6 +1490,10 @@ def name(self):
14961490
return self.__class__.__name__
14971491

14981492
def _rebuild(self, *args, **kwargs):
1493+
# Plain `func` call (row, col, comps)
1494+
if not kwargs.keys() & self.__rkwargs__:
1495+
assert len(args) == 3
1496+
return self._new(*args, **kwargs)
14991497
# We need to rebuild the components with the new name then
15001498
# rebuild the matrix
15011499
newname = kwargs.pop('name', self.name)

devito/types/dense.py

Lines changed: 46 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
from devito.types.args import ArgProvider
2727
from devito.types.caching import CacheManager
2828
from devito.types.basic import AbstractFunction, Size
29-
from devito.types.utils import Buffer, DimensionTuple, NODE, CELL, host_layer
29+
from devito.types.utils import Buffer, DimensionTuple, NODE, CELL, host_layer, Staggering
3030

3131
__all__ = ['Function', 'TimeFunction', 'SubFunction', 'TempFunction']
3232

@@ -1010,6 +1010,10 @@ def _cache_meta(self):
10101010
def __init_finalize__(self, *args, **kwargs):
10111011
super().__init_finalize__(*args, **kwargs)
10121012

1013+
# Staggering
1014+
self._staggered = self.__staggered_setup__(self.dimensions,
1015+
staggered=kwargs.get('staggered'))
1016+
10131017
# Space order
10141018
space_order = kwargs.get('space_order', 1)
10151019
if isinstance(space_order, int):
@@ -1042,7 +1046,7 @@ def __fd_setup__(self):
10421046

10431047
@cached_property
10441048
def _fd_priority(self):
1045-
return 1 if self.staggered in [NODE, None] else 2
1049+
return 1 if self.staggered.on_node else 2
10461050

10471051
@property
10481052
def is_parameter(self):
@@ -1059,26 +1063,33 @@ def _eval_at(self, func):
10591063
return self
10601064

10611065
@classmethod
1062-
def __staggered_setup__(cls, dimensions, **kwargs):
1066+
def __staggered_setup__(cls, dimensions, staggered=None, **kwargs):
10631067
"""
10641068
Setup staggering-related metadata. This method assigns:
10651069
10661070
* 0 to non-staggered dimensions;
10671071
* 1 to staggered dimensions.
10681072
"""
1069-
stagg = kwargs.get('staggered', None)
1070-
if stagg is CELL:
1071-
staggered = (sympy.S.One for d in dimensions)
1072-
elif stagg in [None, NODE]:
1073-
staggered = (sympy.S.Zero for d in dimensions)
1074-
elif all(is_integer(s) for s in as_tuple(stagg)):
1073+
if not staggered:
1074+
processed = ()
1075+
elif staggered is CELL:
1076+
processed = (sympy.S.One,)*len(dimensions)
1077+
elif staggered is NODE:
1078+
processed = (sympy.S.Zero,)*len(dimensions)
1079+
elif all(is_integer(s) for s in as_tuple(staggered)):
10751080
# Staggering is already a tuple likely from rebuild
1076-
assert len(stagg) == len(dimensions)
1077-
return tuple(stagg)
1081+
assert len(staggered) == len(dimensions)
1082+
processed = staggered
10781083
else:
1079-
staggered = (sympy.S.One if d in as_tuple(stagg) else sympy.S.Zero
1080-
for d in dimensions)
1081-
return tuple(staggered)
1084+
processed = []
1085+
for d in dimensions:
1086+
if d in as_tuple(staggered):
1087+
processed.append(sympy.S.One)
1088+
elif -d in as_tuple(staggered):
1089+
processed.append(sympy.S.NegativeOne)
1090+
else:
1091+
processed.append(sympy.S.Zero)
1092+
return tuple(processed)
10821093

10831094
@classmethod
10841095
def __indices_setup__(cls, *args, **kwargs):
@@ -1097,14 +1108,27 @@ def __indices_setup__(cls, *args, **kwargs):
10971108
assert len(args) == len(dimensions)
10981109
staggered_indices = tuple(args)
10991110
else:
1100-
# Staggered indices
1101-
staggered_indices = (d + i * d.spacing / 2
1102-
for d, i in zip(dimensions, staggered))
1103-
return tuple(dimensions), tuple(staggered_indices), staggered
1111+
if not staggered:
1112+
staggered_indices = (d for d in dimensions)
1113+
else:
1114+
# Staggered indices
1115+
staggered_indices = (d + i * d.spacing / 2
1116+
for d, i in zip(dimensions, staggered))
1117+
return tuple(dimensions), tuple(staggered_indices)
1118+
1119+
@property
1120+
def staggered(self):
1121+
"""The staggered indices of the object."""
1122+
if self._staggered:
1123+
return Staggering(*self._staggered, getters=self.dimensions)
1124+
else:
1125+
return Staggering(getters=self.dimensions)
11041126

11051127
@property
11061128
def is_Staggered(self):
1107-
return self.staggered is not None
1129+
if not self.staggered:
1130+
return False
1131+
return True
11081132

11091133
@classmethod
11101134
def __shape_setup__(cls, **kwargs):
@@ -1392,7 +1416,6 @@ def __fd_setup__(self):
13921416
@classmethod
13931417
def __indices_setup__(cls, *args, **kwargs):
13941418
dimensions = kwargs.get('dimensions')
1395-
staggered = kwargs.get('staggered')
13961419

13971420
if dimensions is None:
13981421
save = kwargs.get('save')
@@ -1407,7 +1430,7 @@ def __indices_setup__(cls, *args, **kwargs):
14071430
dimensions.insert(cls._time_position, time_dim)
14081431

14091432
return Function.__indices_setup__(
1410-
*args, dimensions=dimensions, staggered=staggered
1433+
*args, dimensions=dimensions, staggered=kwargs.get('staggered')
14111434
)
14121435

14131436
@classmethod
@@ -1446,7 +1469,7 @@ def __shape_setup__(cls, **kwargs):
14461469

14471470
@cached_property
14481471
def _fd_priority(self):
1449-
return 2.1 if self.staggered in [NODE, None] else 2.2
1472+
return 2.1 if self.staggered.on_node else 2.2
14501473

14511474
@property
14521475
def time_order(self):
@@ -1600,7 +1623,7 @@ def __indices_setup__(cls, **kwargs):
16001623
# Sanity check
16011624
assert not any(d.is_NonlinearDerived for d in dimensions)
16021625

1603-
return dimensions, dimensions, (sympy.S.Zero for _ in dimensions)
1626+
return dimensions, dimensions
16041627

16051628
def __halo_setup__(self, **kwargs):
16061629
pointer_dim = kwargs.get('pointer_dim')

devito/types/tensor.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -69,12 +69,13 @@ class TensorFunction(AbstractTensor):
6969
_class_priority = 10
7070
_op_priority = Differentiable._op_priority + 1.
7171

72+
__rkwargs__ = AbstractTensor.__rkwargs__ + ('dimensions', 'space_order')
73+
7274
def __init_finalize__(self, *args, **kwargs):
7375
super().__init_finalize__(*args, **kwargs)
7476
grid = kwargs.get('grid')
7577
dimensions = kwargs.get('dimensions')
76-
inds, _, _ = Function.__indices_setup__(grid=grid,
77-
dimensions=dimensions)
78+
inds, _ = Function.__indices_setup__(grid=grid, dimensions=dimensions)
7879
self._space_dimensions = inds
7980

8081
@classmethod

devito/types/utils.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from ctypes import POINTER, Structure
2+
from functools import cached_property
23

34
from devito.tools import EnrichedTuple, Tag
45
# Additional Function-related APIs
@@ -31,6 +32,13 @@ def __getitem_hook__(self, dim):
3132
raise KeyError
3233

3334

35+
class Staggering(DimensionTuple):
36+
37+
@cached_property
38+
def on_node(self):
39+
return not self or all(s == 0 for s in self)
40+
41+
3442
class IgnoreDimSort(tuple):
3543
"""A tuple subclass used to wrap the implicit_dims to indicate
3644
that the topological sort of other dimensions should not occur."""

0 commit comments

Comments
 (0)