Skip to content

Commit fefcfba

Browse files
authored
Merge pull request #2583 from devitocodes/tens-rebuild
api: Fix staggering setup and tensor rebuilding
2 parents 7a6b725 + 9c9fe55 commit fefcfba

File tree

9 files changed

+386
-245
lines changed

9 files changed

+386
-245
lines changed

devito/finite_differences/derivative.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -405,7 +405,7 @@ def _eval_at(self, func):
405405
return self
406406
# For basic equation of the form f = Derivative(g, ...) we can just
407407
# compare staggering
408-
if self.expr.staggered == func.staggered:
408+
if self.expr.staggered == func.staggered and self.expr.is_Function:
409409
return self
410410

411411
x0 = func.indices_ref.getters

devito/finite_differences/operators.py

+8-2
Original file line numberDiff line numberDiff line change
@@ -165,13 +165,19 @@ def diag(func, size=None):
165165
size of the diagonal matrix (size x size).
166166
Defaults to the number of spatial dimensions when unspecified
167167
"""
168+
from devito.types.tensor import TensorFunction, TensorTimeFunction
169+
if isinstance(func, TensorFunction):
170+
if func.is_TensorValued:
171+
return func._new(*func.shape, lambda i, j: func[i, i] if i == j else 0)
172+
else:
173+
n = func.shape[0]
174+
return func._new(n, n, lambda i, j: func[i] if i == j else 0)
175+
168176
dim = size or len(func.dimensions)
169177
dim = dim-1 if func.is_TimeDependent else dim
170178
to = getattr(func, 'time_order', 0)
171179

172-
from devito.types.tensor import TensorFunction, TensorTimeFunction
173180
tens_func = TensorTimeFunction if func.is_TimeDependent else TensorFunction
174-
175181
comps = [[func if i == j else 0 for i in range(dim)] for j in range(dim)]
176182
return tens_func(name='diag', grid=func.grid, space_order=func.space_order,
177183
components=comps, time_order=to, diagonal=True)

devito/finite_differences/rsfd.py

+5-9
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
from functools import wraps
22

3-
from devito.types import NODE
43
from devito.types.dimension import StencilDimension
54
from .differentiable import Weights, DiffDerivative
65
from .tools import generate_indices, fd_weights_registry
@@ -101,12 +100,7 @@ def check_staggering(func):
101100
def wrapper(expr, dim, x0=None, expand=True):
102101
grid = expr.grid
103102
x0 = {k: v for k, v in x0.items() if k.is_Space}
104-
if expr.staggered is NODE or expr.staggered is None:
105-
cond = x0 == {} or x0 == all_staggered(grid) or x0 == grid_node(grid)
106-
elif expr.staggered == grid.dimensions:
107-
cond = x0 == {} or x0 == all_staggered(grid) or x0 == grid_node(grid)
108-
else:
109-
cond = False
103+
cond = x0 == {} or x0 == all_staggered(grid) or x0 == grid_node(grid)
110104
if cond:
111105
return func(expr, dim, x0=x0, expand=expand)
112106
else:
@@ -117,7 +111,8 @@ def wrapper(expr, dim, x0=None, expand=True):
117111
@check_staggering
118112
def d45(expr, dim, x0=None, expand=True):
119113
"""
120-
RSFD approximation of the derivative of `expr` along `dim` at point `x0`.
114+
Rotated staggered grid finite-differences (RSFD) discretization
115+
of the derivative of `expr` along `dim` at point `x0`.
121116
122117
Parameters
123118
----------
@@ -132,7 +127,8 @@ def d45(expr, dim, x0=None, expand=True):
132127
"""
133128
# Make sure the grid supports RSFD
134129
if expr.grid.dim not in [2, 3]:
135-
raise ValueError('RSFD only supported in 2D and 3D')
130+
raise ValueError('Rotated staggered grid finite-differences (RSFD)'
131+
' only supported in 2D and 3D')
136132

137133
# Diagonals weights
138134
w = dir_weights[(dim.name, expr.grid.dim)]

0 commit comments

Comments
 (0)