Skip to content

Commit 24422fa

Browse files
committed
api: fix non-integer Mul args
1 parent 17d68fe commit 24422fa

File tree

3 files changed

+4
-8
lines changed

3 files changed

+4
-8
lines changed

devito/finite_differences/differentiable.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -552,17 +552,13 @@ def __new__(cls, *args, **kwargs):
552552
nums, others = split(args, lambda e: isinstance(e, (int, float,
553553
sympy.Number, np.number)))
554554
scalar = sympy.Mul(*nums)
555-
try:
556-
scalar = sympy.Integer(scalar)
557-
except TypeError:
558-
pass
559555

560556
# a*0 -> 0
561557
if scalar == 0:
562558
return sympy.S.Zero
563559

564560
# a*1 -> a
565-
if scalar == 1:
561+
if scalar - 1 == 0:
566562
args = others
567563
else:
568564
args = [scalar] + others

examples/seismic/tutorials/05_staggered_acoustic.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@
130130
"$\\displaystyle \\left[\\begin{matrix}v_x(t + dt, x + h_x/2, z)\\\\v_z(t + dt, x, z + h_z/2)\\end{matrix}\\right] = \\left[\\begin{matrix}dt \\left(1.0 \\frac{\\partial}{\\partial x} p(t, x, z) + \\frac{v_x(t, x + h_x/2, z)}{dt}\\right)\\\\dt \\left(1.0 \\frac{\\partial}{\\partial z} p(t, x, z) + \\frac{v_z(t, x, z + h_z/2)}{dt}\\right)\\end{matrix}\\right]$"
131131
],
132132
"text/plain": [
133-
"Eq(Vector(v_x(t + dt, x + h_x/2, z), v_z(t + dt, x, z + h_z/2)), Vector(dt*(1.0*Derivative(p(t, x, z), x) + v_x(t, x + h_x/2, z)/dt), dt*(1.0*Derivative(p(t, x, z), z) + v_z(t, x, z + h_z/2)/dt)))"
133+
"Eq(Vector(v_x(t + dt, x + h_x/2, z), v_z(t + dt, x, z + h_z/2)), Vector(dt*(Derivative(p(t, x, z), x) + v_x(t, x + h_x/2, z)/dt), dt*(Derivative(p(t, x, z), z) + v_z(t, x, z + h_z/2)/dt)))"
134134
]
135135
},
136136
"execution_count": 7,

tests/test_symbolic_coefficients.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -202,8 +202,8 @@ def test_staggered_equation(self):
202202

203203
eq_f = Eq(f, f.dx2(weights=weights))
204204

205-
expected = 'Eq(f(x + h_x/2), 1.0*f(x - h_x/2)/h_x**2 - 2.0*f(x + h_x/2)/h_x**2 '\
206-
'+ 1.0*f(x + 3*h_x/2)/h_x**2)'
205+
expected = 'Eq(f(x + h_x/2), f(x - h_x/2)/h_x**2 - 2.0*f(x + h_x/2)/h_x**2 '\
206+
'+ f(x + 3*h_x/2)/h_x**2)'
207207
assert(str(eq_f.evaluate) == expected)
208208

209209
@pytest.mark.parametrize('stagger', [True, False])

0 commit comments

Comments
 (0)