Skip to content

Commit 7a6b725

Browse files
authored
Merge pull request #2578 from devitocodes/issue_2577
compiler: Fix issue 2577 - edit printer for unevaluation Mul
2 parents 77f8725 + 624bea2 commit 7a6b725

File tree

8 files changed

+68
-20
lines changed

8 files changed

+68
-20
lines changed

.github/workflows/pytest-core-nompi.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,7 @@ jobs:
172172
173173
- name: Test with pytest
174174
run: |
175-
${{ env.RUN_CMD }} pytest -k "${{ matrix.test-set }}" -m "not parallel" --cov --cov-config=.coveragerc --cov-report=xml ${{ env.TESTS }}
175+
${{ env.RUN_CMD }} pytest -k "${{ matrix.test-set }}" -m "not parallel" --cov --cov-config=.coveragerc --cov-report=xml tests/
176176
177177
- name: Upload coverage to Codecov
178178
if: "!contains(matrix.name, 'docker')"

devito/finite_differences/differentiable.py

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from devito.finite_differences.tools import make_shift_x0, coeff_priority
1818
from devito.logger import warning
1919
from devito.tools import (as_tuple, filter_ordered, flatten, frozendict,
20-
infer_dtype, is_integer, split)
20+
infer_dtype, is_integer, split, is_number)
2121
from devito.types import Array, DimensionTuple, Evaluable, StencilDimension
2222
from devito.types.basic import AbstractFunction
2323

@@ -548,20 +548,19 @@ def __new__(cls, *args, **kwargs):
548548
nested, others = split(args, lambda e: isinstance(e, Mul))
549549
args = flatten(e.args for e in nested) + list(others)
550550

551+
# Gather all numbers and simplify
552+
nums, others = split(args, lambda e: is_number(e))
553+
scalar = sympy.Mul(*nums)
554+
551555
# a*0 -> 0
552-
if any(i == 0 for i in args):
556+
if scalar == 0:
553557
return sympy.S.Zero
554558

555559
# a*1 -> a
556-
args = [i for i in args if i != 1]
557-
558-
# a*-1 -> a*-1
559-
# a*-1*-1 -> a
560-
# a*-1*-1*-1 -> a*-1
561-
nminus = len([i for i in args if i == sympy.S.NegativeOne])
562-
args = [i for i in args if i != sympy.S.NegativeOne]
563-
if nminus % 2 == 1:
564-
args.append(sympy.S.NegativeOne)
560+
if scalar - 1 == 0:
561+
args = others
562+
else:
563+
args = [scalar] + others
565564

566565
# Reorder for homogeneity with pure SymPy types
567566
_mulsort(args)

devito/symbolics/extended_sympy.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -622,6 +622,10 @@ def __new__(cls, name, arguments=None, template=None, **kwargs):
622622

623623
return obj
624624

625+
def _eval_is_commutative(self):
626+
# DefFunction defaults to commutative
627+
return True
628+
625629
@property
626630
def name(self):
627631
return self._name

devito/tools/utils.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
'roundm', 'powerset', 'invert', 'flatten', 'single_or', 'filter_ordered',
1313
'as_mapper', 'filter_sorted', 'pprint', 'sweep', 'all_equal', 'as_list',
1414
'indices_to_slices', 'indices_to_sections', 'transitive_closure',
15-
'humanbytes', 'contains_val', 'sorted_priority', 'as_set']
15+
'humanbytes', 'contains_val', 'sorted_priority', 'as_set', 'is_number']
1616

1717

1818
def prod(iterable, initial=1):
@@ -82,6 +82,13 @@ def is_integer(value):
8282
return isinstance(value, (int, np.integer, sympy.Integer))
8383

8484

85+
def is_number(value):
86+
"""
87+
A thorough instance comparison for all number types.
88+
"""
89+
return isinstance(value, (int, float, np.number, sympy.Number))
90+
91+
8592
def contains_val(val, items):
8693
try:
8794
return val in items

examples/seismic/tutorials/05_staggered_acoustic.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@
130130
"$\\displaystyle \\left[\\begin{matrix}v_x(t + dt, x + h_x/2, z)\\\\v_z(t + dt, x, z + h_z/2)\\end{matrix}\\right] = \\left[\\begin{matrix}dt \\left(1.0 \\frac{\\partial}{\\partial x} p(t, x, z) + \\frac{v_x(t, x + h_x/2, z)}{dt}\\right)\\\\dt \\left(1.0 \\frac{\\partial}{\\partial z} p(t, x, z) + \\frac{v_z(t, x, z + h_z/2)}{dt}\\right)\\end{matrix}\\right]$"
131131
],
132132
"text/plain": [
133-
"Eq(Vector(v_x(t + dt, x + h_x/2, z), v_z(t + dt, x, z + h_z/2)), Vector(dt*(1.0*Derivative(p(t, x, z), x) + v_x(t, x + h_x/2, z)/dt), dt*(1.0*Derivative(p(t, x, z), z) + v_z(t, x, z + h_z/2)/dt)))"
133+
"Eq(Vector(v_x(t + dt, x + h_x/2, z), v_z(t + dt, x, z + h_z/2)), Vector(dt*(Derivative(p(t, x, z), x) + v_x(t, x + h_x/2, z)/dt), dt*(Derivative(p(t, x, z), z) + v_z(t, x, z + h_z/2)/dt)))"
134134
]
135135
},
136136
"execution_count": 7,

tests/test_dse.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ def test_pow_to_mul(expr, expected):
8888

8989

9090
@pytest.mark.parametrize('expr,expected', [
91-
('s - SizeOf("int")*fa[x]', 's - fa[x]*sizeof(int)'),
91+
('s - SizeOf("int")*fa[x]', 's - sizeof(int)*fa[x]'),
9292
('foo(4*fa[x] + 4*fb[x])', 'foo(4*(fa[x] + fb[x]))'),
9393
('floor(0.1*a + 0.1*fa[x])', 'floor(0.1*(a + fa[x]))'),
9494
('floor(0.1*(a + fa[x]))', 'floor(0.1*(a + fa[x]))'),

tests/test_symbolic_coefficients.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -202,8 +202,8 @@ def test_staggered_equation(self):
202202

203203
eq_f = Eq(f, f.dx2(weights=weights))
204204

205-
expected = 'Eq(f(x + h_x/2), 1.0*f(x - h_x/2)/h_x**2 - 2.0*f(x + h_x/2)/h_x**2 '\
206-
'+ 1.0*f(x + 3*h_x/2)/h_x**2)'
205+
expected = 'Eq(f(x + h_x/2), f(x - h_x/2)/h_x**2 - 2.0*f(x + h_x/2)/h_x**2 '\
206+
'+ f(x + 3*h_x/2)/h_x**2)'
207207
assert(str(eq_f.evaluate) == expected)
208208

209209
@pytest.mark.parametrize('stagger', [True, False])

tests/test_symbolics.py

Lines changed: 41 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,14 @@
77
from sympy import Expr, Number, Symbol
88
from devito import (Constant, Dimension, Grid, Function, solve, TimeFunction, Eq, # noqa
99
Operator, SubDimension, norm, Le, Ge, Gt, Lt, Abs, sin, cos,
10-
Min, Max)
11-
from devito.finite_differences.differentiable import SafeInv, Weights
10+
Min, Max, SubDomain)
11+
from devito.finite_differences.differentiable import SafeInv, Weights, Mul
1212
from devito.ir import Expression, FindNodes, ccode
1313
from devito.symbolics import (retrieve_functions, retrieve_indexed, evalrel, # noqa
1414
CallFromPointer, Cast, DefFunction, FieldFromPointer,
1515
INT, FieldFromComposite, IntDiv, Namespace, Rvalue,
1616
ReservedWord, ListInitializer, uxreplace, pow_to_mul,
17-
retrieve_derivatives, BaseCast)
17+
retrieve_derivatives, BaseCast, SizeOf)
1818
from devito.tools import as_tuple
1919
from devito.types import (Array, Bundle, FIndexed, LocalObject, Object,
2020
ComponentAccess, StencilDimension, Symbol as dSymbol)
@@ -874,3 +874,41 @@ def test_assumptions(self, op, expr, assumptions, expected):
874874
assumptions = eval(assumptions)
875875
expected = eval(expected)
876876
assert evalrel(op, eqn, assumptions) == expected
877+
878+
879+
def test_issue_2577a():
880+
u = TimeFunction(name='u', grid=Grid((2,)))
881+
x = u.grid.dimensions[0]
882+
expr = Mul(-1, -1., x, u)
883+
assert expr.args == (x, u)
884+
eq = Eq(u.forward, expr)
885+
op = Operator(eq)
886+
887+
assert '--' not in str(op.ccode)
888+
889+
890+
def test_issue_2577b():
891+
class SD0(SubDomain):
892+
name = 'sd0'
893+
894+
def define(self, dimensions):
895+
x, = dimensions
896+
return {x: ('middle', 1, 1)}
897+
898+
grid = Grid(shape=(11,))
899+
900+
sd0 = SD0(grid=grid)
901+
902+
u = Function(name='u', grid=grid, space_order=2)
903+
904+
eq_u = Eq(u, -(u*u).dxc, subdomain=sd0)
905+
906+
op = Operator(eq_u)
907+
assert '--' not in str(op.ccode)
908+
909+
910+
def test_print_div():
911+
a = SizeOf(np.int32)
912+
b = SizeOf(np.int64)
913+
cstr = ccode(a / b)
914+
assert cstr == 'sizeof(int)/sizeof(long)'

0 commit comments

Comments
 (0)