Skip to content

Commit 9c9fe55

Browse files
committed
api: add support for diag(tensor) and diag(vector)
1 parent 29fc0c9 commit 9c9fe55

File tree

4 files changed

+45
-14
lines changed

4 files changed

+45
-14
lines changed

devito/finite_differences/operators.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -165,13 +165,19 @@ def diag(func, size=None):
165165
size of the diagonal matrix (size x size).
166166
Defaults to the number of spatial dimensions when unspecified
167167
"""
168+
from devito.types.tensor import TensorFunction, TensorTimeFunction
169+
if isinstance(func, TensorFunction):
170+
if func.is_TensorValued:
171+
return func._new(*func.shape, lambda i, j: func[i, i] if i == j else 0)
172+
else:
173+
n = func.shape[0]
174+
return func._new(n, n, lambda i, j: func[i] if i == j else 0)
175+
168176
dim = size or len(func.dimensions)
169177
dim = dim-1 if func.is_TimeDependent else dim
170178
to = getattr(func, 'time_order', 0)
171179

172-
from devito.types.tensor import TensorFunction, TensorTimeFunction
173180
tens_func = TensorTimeFunction if func.is_TimeDependent else TensorFunction
174-
175181
comps = [[func if i == j else 0 for i in range(dim)] for j in range(dim)]
176182
return tens_func(name='diag', grid=func.grid, space_order=func.space_order,
177183
components=comps, time_order=to, diagonal=True)

devito/types/basic.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1492,7 +1492,9 @@ def name(self):
14921492
def _rebuild(self, *args, **kwargs):
14931493
# Plain `func` call (row, col, comps)
14941494
if not kwargs.keys() & self.__rkwargs__:
1495-
assert len(args) == 3
1495+
if len(args) != 3:
1496+
raise ValueError("Invalid number of arguments, expected nrow, ncol, "
1497+
"list of components")
14961498
return self._new(*args, **kwargs)
14971499
# We need to rebuild the components with the new name then
14981500
# rebuild the matrix

devito/types/dense.py

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1089,7 +1089,7 @@ def __staggered_setup__(cls, dimensions, staggered=None, **kwargs):
10891089
processed.append(sympy.S.NegativeOne)
10901090
else:
10911091
processed.append(sympy.S.Zero)
1092-
return tuple(processed)
1092+
return Staggering(*processed, getters=dimensions)
10931093

10941094
@classmethod
10951095
def __indices_setup__(cls, *args, **kwargs):
@@ -1109,26 +1109,20 @@ def __indices_setup__(cls, *args, **kwargs):
11091109
staggered_indices = tuple(args)
11101110
else:
11111111
if not staggered:
1112-
staggered_indices = (d for d in dimensions)
1112+
staggered_indices = dimensions
11131113
else:
1114-
# Staggered indices
11151114
staggered_indices = (d + i * d.spacing / 2
11161115
for d, i in zip(dimensions, staggered))
11171116
return tuple(dimensions), tuple(staggered_indices)
11181117

11191118
@property
11201119
def staggered(self):
11211120
"""The staggered indices of the object."""
1122-
if self._staggered:
1123-
return Staggering(*self._staggered, getters=self.dimensions)
1124-
else:
1125-
return Staggering(getters=self.dimensions)
1121+
return self._staggered
11261122

11271123
@property
11281124
def is_Staggered(self):
1129-
if not self.staggered:
1130-
return False
1131-
return True
1125+
return bool(self.staggered)
11321126

11331127
@classmethod
11341128
def __shape_setup__(cls, **kwargs):

tests/test_tensors.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,9 @@
55
import pytest
66

77
from devito import VectorFunction, TensorFunction, VectorTimeFunction, TensorTimeFunction
8-
from devito import Grid, Function, TimeFunction, Dimension, Eq, div, grad, curl, laplace
8+
from devito import (
9+
Grid, Function, TimeFunction, Dimension, Eq, div, grad, curl, laplace, diag
10+
)
911
from devito.symbolics import retrieve_derivatives
1012
from devito.types import NODE
1113

@@ -465,3 +467,30 @@ def test_rebuild(func1):
465467
assert j.name == i.name
466468
assert j.grid == i.grid
467469
assert j.dimensions == tuple(new_dims)
470+
471+
472+
@pytest.mark.parametrize('func1', [Function, TimeFunction,
473+
TensorFunction, TensorTimeFunction,
474+
VectorFunction, VectorTimeFunction])
475+
def test_diag(func1):
476+
grid = Grid(tuple([5]*3))
477+
f1 = func1(name="f1", grid=grid)
478+
479+
f2 = diag(f1)
480+
assert isinstance(f2, TensorFunction)
481+
if f1.is_TimeDependent:
482+
assert f2.is_TimeDependent
483+
print(f2)
484+
assert f2.shape == (3, 3)
485+
# Vector input
486+
if isinstance(f1, VectorFunction):
487+
assert all(f2[i, i] == f1[i] for i in range(3))
488+
assert all(f2[i, j] == 0 for i in range(3) for j in range(3) if i != j)
489+
# Tensor input
490+
elif isinstance(f1, TensorFunction):
491+
assert all(f2[i, i] == f1[i, i] for i in range(3))
492+
assert all(f2[i, j] == 0 for i in range(3) for j in range(3) if i != j)
493+
# Function input
494+
else:
495+
assert all(f2[i, j] == 0 for i in range(3) for j in range(3) if i != j)
496+
assert all(f2[i, i] == f1 for i in range(3))

0 commit comments

Comments
 (0)