26
26
from devito .types .args import ArgProvider
27
27
from devito .types .caching import CacheManager
28
28
from devito .types .basic import AbstractFunction , Size
29
- from devito .types .utils import Buffer , DimensionTuple , NODE , CELL , host_layer
29
+ from devito .types .utils import Buffer , DimensionTuple , NODE , CELL , host_layer , Staggering
30
30
31
31
__all__ = ['Function' , 'TimeFunction' , 'SubFunction' , 'TempFunction' ]
32
32
@@ -1010,6 +1010,10 @@ def _cache_meta(self):
1010
1010
def __init_finalize__ (self , * args , ** kwargs ):
1011
1011
super ().__init_finalize__ (* args , ** kwargs )
1012
1012
1013
+ # Staggering
1014
+ self ._staggered = self .__staggered_setup__ (self .dimensions ,
1015
+ staggered = kwargs .get ('staggered' ))
1016
+
1013
1017
# Space order
1014
1018
space_order = kwargs .get ('space_order' , 1 )
1015
1019
if isinstance (space_order , int ):
@@ -1042,7 +1046,7 @@ def __fd_setup__(self):
1042
1046
1043
1047
@cached_property
1044
1048
def _fd_priority (self ):
1045
- return 1 if self .staggered in [ NODE , None ] else 2
1049
+ return 1 if self .staggered . on_node else 2
1046
1050
1047
1051
@property
1048
1052
def is_parameter (self ):
@@ -1059,26 +1063,33 @@ def _eval_at(self, func):
1059
1063
return self
1060
1064
1061
1065
@classmethod
1062
- def __staggered_setup__ (cls , dimensions , ** kwargs ):
1066
+ def __staggered_setup__ (cls , dimensions , staggered = None , ** kwargs ):
1063
1067
"""
1064
1068
Setup staggering-related metadata. This method assigns:
1065
1069
1066
1070
* 0 to non-staggered dimensions;
1067
1071
* 1 to staggered dimensions.
1068
1072
"""
1069
- stagg = kwargs .get ('staggered' , None )
1070
- if stagg is CELL :
1071
- staggered = (sympy .S .One for d in dimensions )
1072
- elif stagg in [None , NODE ]:
1073
- staggered = (sympy .S .Zero for d in dimensions )
1074
- elif all (is_integer (s ) for s in as_tuple (stagg )):
1073
+ if not staggered :
1074
+ processed = ()
1075
+ elif staggered is CELL :
1076
+ processed = (sympy .S .One ,)* len (dimensions )
1077
+ elif staggered is NODE :
1078
+ processed = (sympy .S .Zero ,)* len (dimensions )
1079
+ elif all (is_integer (s ) for s in as_tuple (staggered )):
1075
1080
# Staggering is already a tuple likely from rebuild
1076
- assert len (stagg ) == len (dimensions )
1077
- return tuple ( stagg )
1081
+ assert len (staggered ) == len (dimensions )
1082
+ processed = staggered
1078
1083
else :
1079
- staggered = (sympy .S .One if d in as_tuple (stagg ) else sympy .S .Zero
1080
- for d in dimensions )
1081
- return tuple (staggered )
1084
+ processed = []
1085
+ for d in dimensions :
1086
+ if d in as_tuple (staggered ):
1087
+ processed .append (sympy .S .One )
1088
+ elif - d in as_tuple (staggered ):
1089
+ processed .append (sympy .S .NegativeOne )
1090
+ else :
1091
+ processed .append (sympy .S .Zero )
1092
+ return tuple (processed )
1082
1093
1083
1094
@classmethod
1084
1095
def __indices_setup__ (cls , * args , ** kwargs ):
@@ -1097,14 +1108,27 @@ def __indices_setup__(cls, *args, **kwargs):
1097
1108
assert len (args ) == len (dimensions )
1098
1109
staggered_indices = tuple (args )
1099
1110
else :
1100
- # Staggered indices
1101
- staggered_indices = (d + i * d .spacing / 2
1102
- for d , i in zip (dimensions , staggered ))
1103
- return tuple (dimensions ), tuple (staggered_indices ), staggered
1111
+ if not staggered :
1112
+ staggered_indices = (d for d in dimensions )
1113
+ else :
1114
+ # Staggered indices
1115
+ staggered_indices = (d + i * d .spacing / 2
1116
+ for d , i in zip (dimensions , staggered ))
1117
+ return tuple (dimensions ), tuple (staggered_indices )
1118
+
1119
+ @property
1120
+ def staggered (self ):
1121
+ """The staggered indices of the object."""
1122
+ if self ._staggered :
1123
+ return Staggering (* self ._staggered , getters = self .dimensions )
1124
+ else :
1125
+ return Staggering (getters = self .dimensions )
1104
1126
1105
1127
@property
1106
1128
def is_Staggered (self ):
1107
- return self .staggered is not None
1129
+ if not self .staggered :
1130
+ return False
1131
+ return True
1108
1132
1109
1133
@classmethod
1110
1134
def __shape_setup__ (cls , ** kwargs ):
@@ -1392,7 +1416,6 @@ def __fd_setup__(self):
1392
1416
@classmethod
1393
1417
def __indices_setup__ (cls , * args , ** kwargs ):
1394
1418
dimensions = kwargs .get ('dimensions' )
1395
- staggered = kwargs .get ('staggered' )
1396
1419
1397
1420
if dimensions is None :
1398
1421
save = kwargs .get ('save' )
@@ -1407,7 +1430,7 @@ def __indices_setup__(cls, *args, **kwargs):
1407
1430
dimensions .insert (cls ._time_position , time_dim )
1408
1431
1409
1432
return Function .__indices_setup__ (
1410
- * args , dimensions = dimensions , staggered = staggered
1433
+ * args , dimensions = dimensions , staggered = kwargs . get ( ' staggered' )
1411
1434
)
1412
1435
1413
1436
@classmethod
@@ -1446,7 +1469,7 @@ def __shape_setup__(cls, **kwargs):
1446
1469
1447
1470
@cached_property
1448
1471
def _fd_priority (self ):
1449
- return 2.1 if self .staggered in [ NODE , None ] else 2.2
1472
+ return 2.1 if self .staggered . on_node else 2.2
1450
1473
1451
1474
@property
1452
1475
def time_order (self ):
@@ -1600,7 +1623,7 @@ def __indices_setup__(cls, **kwargs):
1600
1623
# Sanity check
1601
1624
assert not any (d .is_NonlinearDerived for d in dimensions )
1602
1625
1603
- return dimensions , dimensions , ( sympy . S . Zero for _ in dimensions )
1626
+ return dimensions , dimensions
1604
1627
1605
1628
def __halo_setup__ (self , ** kwargs ):
1606
1629
pointer_dim = kwargs .get ('pointer_dim' )
0 commit comments