diff --git a/brian2/groups/neurongroup.py b/brian2/groups/neurongroup.py index 5058e1bcc..8bf613994 100644 --- a/brian2/groups/neurongroup.py +++ b/brian2/groups/neurongroup.py @@ -16,10 +16,12 @@ from brian2.core.spikesource import SpikeSource from brian2.core.variables import (Variables, LinkedVariable, DynamicArrayVariable, Subexpression) +from brian2.core.namespace import get_local_namespace from brian2.equations.equations import (Equations, DIFFERENTIAL_EQUATION, SUBEXPRESSION, PARAMETER, check_subexpressions, - extract_constant_subexpressions) + extract_constant_subexpressions, + SingleEquation) from brian2.equations.refractory import add_refractoriness from brian2.parsing.expressions import (parse_expression_dimensions, is_boolean_expression) @@ -31,10 +33,16 @@ fail_for_dimension_mismatch) from brian2.utils.logger import get_logger from brian2.utils.stringtools import get_identifiers - +from brian2.codegen.runtime.numpy_rt.numpy_rt import NumpyCodeObject from .group import Group, CodeRunner, get_dtype from .subgroup import Subgroup +try: + from scipy.optimize import root + scipy_available = True +except ImportError: + scipy_available = False + __all__ = ['NeuronGroup'] logger = get_logger(__name__) @@ -920,3 +928,94 @@ def add_event_to_text(event): add_event_to_text(event) return '\n'.join(text) + + def resting_state(self, x0 = {}): + ''' + Calculate resting state of the system. + + Parameters + ---------- + x0 : dict + Initial guess for the state variables. If any of the system's state variables are not + added, default value of 0 is mapped as the initial guess to the missing state variables. + Note: Time elapsed to locate the resting state would be lesser for better initial guesses. + + Returns + ------- + rest_state : dict + Dictioary with pair of state variables and resting state values. Returned values + are represented in SI units. + ''' + # check scipy availability + if scipy_available == False: + raise NotImplementedError("Scipy is not available for using `scipy.optimize.root()`") + # check state variables defined in initial guess are valid + if(x0.keys() - self.equations.diff_eq_names): + raise KeyError("Unknown State Variable: {}".format(next(iter(x0.keys() - + self.equations.diff_eq_names)))) + + # Add 0 as the intial value for non-mentioned state variables in x0 + x0.update({name : 0 for name in self.equations.diff_eq_names - x0.keys()}) + sorted_variable_values = list(dict(sorted(x0.items())).values()) + result = root(_wrapper, sorted_variable_values, args = (self.equations, get_local_namespace(1))) + # check the result message for the status of convergence + if result.success == False: + raise Exception("The model failed to converge at a resting state. Trying better initial guess shall fix the problem") + return dict(zip(sorted(self.equations.diff_eq_names), result.x)) + +def _evaluate_rhs(eqs, values, namespace=None): + """ + Evaluates the RHS of a system of differential equations for given state + variable values. External constants can be provided via the namespace or + will be taken from the local namespace. + This function could be used for example to find a resting state of the + system, i.e. a fixed point where the RHS of all equations are approximately + 0. + Parameters + ---------- + eqs : `Equations` + The equations + values : dict-like + Values for each of the state variables (differential equations and + parameters). + Returns + ------- + rhs : dict + A dictionary with the names of all variables defined by differential + equations as keys and the respective RHS of the equations as values. + """ + # Make a new set of equations, where differential equations are replaced + # by parameters, and a new subexpression defines their RHS. + # E.g. for 'dv/dt = -v / tau : volt' use: + # '''v : volt + # RHS_v = -v / tau : volt''' + new_equations = [] + for eq in eqs.values(): + if eq.type == DIFFERENTIAL_EQUATION: + new_equations.append(SingleEquation(PARAMETER, eq.varname, + dimensions=eq.dim, + var_type=eq.var_type)) + new_equations.append(SingleEquation(SUBEXPRESSION, 'RHS_'+eq.varname, + dimensions=eq.dim/second.dim, + var_type=eq.var_type, + expr=eq.expr)) + else: + new_equations.append(eq) + # TODO: Hide this from standalone mode + group = NeuronGroup(1, model=Equations(new_equations), + codeobj_class=NumpyCodeObject, + namespace=namespace) + + # Set the values of the state variables/parameters and units are not taken into account + group.set_states(values, units = False) + + # Get the values of all RHS_... subexpressions + states = ['RHS_' + name for name in eqs.diff_eq_names] + return group.get_states(states) + +def _wrapper(args, equations, namespace): + """ + Function for which root needs to be calculated. Callable function of scipy.optimize.root() + """ + rhs = _evaluate_rhs(equations, {name : arg for name, arg in zip(sorted(equations.diff_eq_names), args)}, namespace) + return [float(rhs['RHS_{}'.format(name)]) for name in sorted(equations.diff_eq_names)] diff --git a/brian2/tests/test_neurongroup.py b/brian2/tests/test_neurongroup.py index 85bacf38b..b6923e766 100644 --- a/brian2/tests/test_neurongroup.py +++ b/brian2/tests/test_neurongroup.py @@ -23,8 +23,9 @@ from brian2.units.allunits import second, volt from brian2.units.fundamentalunits import (DimensionMismatchError, have_same_dimensions) -from brian2.units.stdunits import ms, mV, Hz +from brian2.units.stdunits import ms, mV, Hz, cm, msiemens, nA from brian2.units.unitsafefunctions import linspace +from brian2.units.allunits import second, volt, umetre, siemens, ufarad from brian2.utils.logger import catch_logs @@ -1716,6 +1717,57 @@ def test_semantics_mod(): assert_allclose(G.x[:], float_values % 3) assert_allclose(G.y[:], float_values % 3) +def test_simple_resting_value(): + """ + Test the resting state values of the system + """ + # simple model with single dependent variable, here it is not necessary + # to run the model as the resting value is certain + El = - 100 + tau = 1 * ms + eqs = ''' + dv/dt = (El - v)/tau : 1 + ''' + grp = NeuronGroup(1, eqs, method = 'exact') + resting_state = grp.resting_state() + assert_allclose(resting_state['v'], El) + + # one more example + area = 100 * umetre ** 2 + g_L = 1e-2 * siemens * cm ** -2 * area + E_L = 1000 + Cm = 1 * ufarad * cm ** -2 * area + grp = NeuronGroup(10, '''dv/dt = I_leak / Cm : volt + I_leak = g_L*(E_L - v) : amp''') + resting_state = grp.resting_state({'v': float(10000)}) + assert_allclose(resting_state['v'], E_L) + +def test_failed_resting_state(): + # check the failed to converge system is correctly notified to the user + area = 20000 * umetre ** 2 + Cm = 1 * ufarad * cm ** -2 * area + gl = 5e-5 * siemens * cm ** -2 * area + El = -65 * mV + EK = -90 * mV + ENa = 50 * mV + g_na = 100 * msiemens * cm ** -2 * area + g_kd = 30 * msiemens * cm ** -2 * area + VT = -63 * mV + I = 0.01*nA + eqs = Equations(''' + dv/dt = (gl*(El-v) - g_na*(m*m*m)*h*(v-ENa) - g_kd*(n*n*n*n)*(v-EK) + I)/Cm : volt + dm/dt = 0.32*(mV**-1)*(13.*mV-v+VT)/ + (exp((13.*mV-v+VT)/(4.*mV))-1.)/ms*(1-m)-0.28*(mV**-1)*(v-VT-40.*mV)/ + (exp((v-VT-40.*mV)/(5.*mV))-1.)/ms*m : 1 + dn/dt = 0.032*(mV**-1)*(15.*mV-v+VT)/ + (exp((15.*mV-v+VT)/(5.*mV))-1.)/ms*(1.-n)-.5*exp((10.*mV-v+VT)/(40.*mV))/ms*n : 1 + dh/dt = 0.128*exp((17.*mV-v+VT)/(18.*mV))/ms*(1.-h)-4./(1+exp((40.*mV-v+VT)/(5.*mV)))/ms*h : 1 + ''') + group = NeuronGroup(1, eqs, method='exponential_euler') + group.v = -70*mV + # very poor choice of initial values causing the convergence to fail + with pytest.raises(Exception): + group.resting_state({'v': 0, 'm': 100000000, 'n': 1000000, 'h': 100000000}) if __name__ == '__main__': test_set_states() @@ -1792,3 +1844,5 @@ def test_semantics_mod(): test_semantics_floor_division() test_semantics_floating_point_division() test_semantics_mod() + test_simple_resting_value() + test_failed_resting_state() \ No newline at end of file