Skip to content

Commit 0d6259c

Browse files
committed
api: fix Mul arguments processing
1 parent be099e8 commit 0d6259c

File tree

2 files changed

+21
-15
lines changed

2 files changed

+21
-15
lines changed

devito/finite_differences/differentiable.py

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -548,20 +548,24 @@ def __new__(cls, *args, **kwargs):
548548
nested, others = split(args, lambda e: isinstance(e, Mul))
549549
args = flatten(e.args for e in nested) + list(others)
550550

551+
# Gather all numbers and simplify
552+
nums, others = split(args, lambda e: isinstance(e, (int, float,
553+
sympy.Number, np.number)))
554+
scalar = sympy.Mul(*nums)
555+
try:
556+
scalar = sympy.Integer(scalar)
557+
except TypeError:
558+
pass
559+
551560
# a*0 -> 0
552-
if any(i == 0 for i in args):
561+
if scalar == 0:
553562
return sympy.S.Zero
554563

555564
# a*1 -> a
556-
args = [i for i in args if i != 1]
557-
558-
# a*-1 -> a*-1
559-
# a*-1*-1 -> a
560-
# a*-1*-1*-1 -> a*-1
561-
nminus = len([i for i in args if i == sympy.S.NegativeOne])
562-
args = [i for i in args if i != sympy.S.NegativeOne]
563-
if nminus % 2 == 1:
564-
args.append(sympy.S.NegativeOne)
565+
if scalar == 1:
566+
args = others
567+
else:
568+
args = [scalar] + others
565569

566570
# Reorder for homogeneity with pure SymPy types
567571
_mulsort(args)

tests/test_symbolics.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,14 +8,13 @@
88
from devito import (Constant, Dimension, Grid, Function, solve, TimeFunction, Eq, # noqa
99
Operator, SubDimension, norm, Le, Ge, Gt, Lt, Abs, sin, cos,
1010
Min, Max, SubDomain)
11-
from devito.finite_differences.differentiable import SafeInv, Weights
11+
from devito.finite_differences.differentiable import SafeInv, Weights, Mul
1212
from devito.ir import Expression, FindNodes, ccode
1313
from devito.symbolics import (retrieve_functions, retrieve_indexed, evalrel, # noqa
1414
CallFromPointer, Cast, DefFunction, FieldFromPointer,
1515
INT, FieldFromComposite, IntDiv, Namespace, Rvalue,
1616
ReservedWord, ListInitializer, uxreplace, pow_to_mul,
1717
retrieve_derivatives, BaseCast)
18-
from devito.symbolics.unevaluation import Mul as UnevalMul
1918
from devito.tools import as_tuple
2019
from devito.types import (Array, Bundle, FIndexed, LocalObject, Object,
2120
ComponentAccess, StencilDimension, Symbol as dSymbol)
@@ -877,16 +876,19 @@ def test_assumptions(self, op, expr, assumptions, expected):
877876
assert evalrel(op, eqn, assumptions) == expected
878877

879878

880-
def test_issue_2577():
879+
def test_issue_2577a():
881880

882881
u = TimeFunction(name='u', grid=Grid((2,)))
883-
eq = Eq(u.forward, UnevalMul(-1, -1., u))
882+
x = u.grid.dimensions[0]
883+
expr = Mul(-1, -1., x, u)
884+
assert expr.args == (x, u)
885+
eq = Eq(u.forward, expr)
884886
op = Operator(eq)
885887

886888
assert '--' not in str(op.ccode)
887889

888890

889-
def test_issue_2577a():
891+
def test_issue_2577b():
890892
class SD0(SubDomain):
891893
name = 'sd0'
892894

0 commit comments

Comments
 (0)