From 8d42360446642a0ffa6a9773aa1d63495bb5fd91 Mon Sep 17 00:00:00 2001 From: Marcel Stimberg Date: Thu, 19 Dec 2019 15:39:32 +0100 Subject: [PATCH] More strict check that equation is conditionally linear --- brian2/stateupdaters/exponential_euler.py | 9 ++++++--- brian2/tests/test_stateupdaters.py | 11 +++++++++++ 2 files changed, 17 insertions(+), 3 deletions(-) diff --git a/brian2/stateupdaters/exponential_euler.py b/brian2/stateupdaters/exponential_euler.py index f81645a51..4b3bad69a 100644 --- a/brian2/stateupdaters/exponential_euler.py +++ b/brian2/stateupdaters/exponential_euler.py @@ -56,13 +56,16 @@ def get_conditionally_linear_system(eqs, variables=None): s_expr = sp.collect(s_expr, var, evaluate=False) - if len(s_expr) > 2 or var not in s_expr: + if (len(s_expr) > 2 or + var not in s_expr or + s_expr.get(sp.S.One, sp.S.Zero).has(var) + ): raise ValueError(('The expression "%s", defining the variable %s, ' 'could not be separated into linear components') % (expr, name)) - coefficients[name] = (s_expr[var], s_expr.get(1, 0)) + coefficients[name] = (s_expr[var], s_expr.get(sp.S.One, sp.S.Zero)) else: - coefficients[name] = (0, s_expr) + coefficients[name] = (sp.S.Zero, s_expr) return coefficients diff --git a/brian2/tests/test_stateupdaters.py b/brian2/tests/test_stateupdaters.py index 105e5db44..30f249099 100644 --- a/brian2/tests/test_stateupdaters.py +++ b/brian2/tests/test_stateupdaters.py @@ -398,6 +398,17 @@ def test_priority(): check_integration(eqs, variables, can_integrate) + # Equation that both linearly and non-linearly depends on the variable, + # and is therefore not "conditionally linear". The exponential Euler updater + # should therefore decline to integrate the equations. + param = 1 + eqs = Equations('''dv/dt = (-v + exp(-v) + 1.0)/tau : 1''') + updater(eqs, variables) # should not raise an error + can_integrate = {linear: False, euler: True, exponential_euler: False, + rk2: True, rk4: True, heun: True, milstein: True} + + check_integration(eqs, variables, can_integrate) + # Equations resulting in complex linear solution for older versions of sympy eqs = Equations('''dv/dt = (ge+gi-(v+49*mV))/(20*ms) : volt dge/dt = -ge/(5*ms) : volt