diff --git a/brian2/equations/codestrings.py b/brian2/equations/codestrings.py index 39fc07a03..7e6f1644e 100644 --- a/brian2/equations/codestrings.py +++ b/brian2/equations/codestrings.py @@ -3,13 +3,18 @@ information about its namespace. Only serves as a parent class, its subclasses `Expression` and `Statements` are the ones that are actually used. ''' - +import re +import string from collections.abc import Hashable +from typing import Sequence +import numbers import sympy +import numpy as np from brian2.utils.logger import get_logger from brian2.utils.stringtools import get_identifiers +from brian2.utils.topsort import topsort from brian2.parsing.sympytools import str_to_sympy, sympy_to_str __all__ = ['Expression', 'Statements'] @@ -35,6 +40,7 @@ def __init__(self, code): # : Set of identifiers in the code string self.identifiers = get_identifiers(code) + self.template_identifiers = get_identifiers(code, template=True) code = property(lambda self: self._code, doc='The code string') @@ -77,6 +83,11 @@ class Statements(CodeString): pass +class Default(dict): + def __missing__(self, key): + return f'{{{key}}}' + + class Expression(CodeString): ''' Class for representing an expression. @@ -100,7 +111,7 @@ def __init__(self, code=None, sympy_expression=None): if code is None: code = sympy_to_str(sympy_expression) - else: + elif '{' not in code: # Just try to convert it to a sympy expression to get syntax errors # for incorrect expressions str_to_sympy(code) @@ -193,6 +204,95 @@ def __ne__(self, other): def __hash__(self): return hash(self.code) + def _do_substitution(self, to_replace, replacement): + # Replacements can be lists, deal with single replacements + # as single-element lists + replaced_name = False + replaced_placeholder = False + if not isinstance(replacement, Sequence) or isinstance(replacement, str): + replacement = [replacement] + replacement_strs = [] + for one_replacement in replacement: + if isinstance(one_replacement, str): + if any(c not in string.ascii_letters + '_{}' + for c in one_replacement): + # Check whether the replacement can be interpreted as an expression + try: + expr = Expression(one_replacement) + replacement_strs.append(expr.code) + except SyntaxError: + raise SyntaxError(f'Replacement \'{one_replacement}\' for' + f'\'{to_replace}\' is neither a name nor a ' + f'valid expression.') + else: + replacement_strs.append(one_replacement) + elif isinstance(one_replacement, (numbers.Number, np.ndarray)): + if not getattr(one_replacement, 'shape', ()) == (): + raise TypeError(f'Cannot replace variable \'{to_replace}\' with an ' + f'array of values.') + replacement_strs.append(repr(one_replacement)) + elif isinstance(one_replacement, Expression): + replacement_strs.append(one_replacement.code) + else: + raise TypeError(f'Cannot replace \'{to_replace}\' with an object of type ' + f'\'{type(one_replacement)}\'.') + + if len(replacement_strs) == 1: + replacement_str = replacement_strs[0] + # Be careful if the string is more than just a name/number + if any(c not in string.ascii_letters + string.digits + '_.{}' + for c in replacement_str): + replacement_str = '(' + replacement_str + ')' + else: + replacement_str = '(' + (' + '.join(replacement_strs)) + ')' + + new_expr = self + if to_replace in new_expr.identifiers: + code = new_expr.code + new_expr = Expression(re.sub(r'(? 0) + stochastic_variables = property(lambda self: {variable for variable in self.identifiers if variable =='xi' or variable.startswith('xi_')}, doc='Stochastic variables in the RHS of this equation') + template = property(lambda self: self.varname.startswith('{') or + (self.expr is not None and '{' in self.expr.code)) + def __eq__(self, other): if not isinstance(other, SingleEquation): return NotImplemented @@ -555,7 +570,7 @@ class Equations(Hashable, Mapping): def __init__(self, eqns, **kwds): if isinstance(eqns, str): - self._equations = parse_string_equations(eqns) + self._equations = parse_string_equations(eqns, template=True) # Do a basic check for the identifiers self.check_identifiers() else: @@ -569,6 +584,7 @@ def __init__(self, eqns, **kwds): eq.varname) self._equations[eq.varname] = eq + self._orig_equations = self._equations self._equations = self._substitute(kwds) # Check for special symbol xi (stochastic term) @@ -593,62 +609,149 @@ def __init__(self, eqns, **kwds): #: Cache for equations with the subexpressions substituted self._substituted_expressions = None - def _substitute(self, replacements): - if len(replacements) == 0: - return self._equations - + def _do_substitution(self, equations, to_replace, replacement): + # Replacements can be lists, deal with single replacements + # as single-element lists new_equations = {} - for eq in self.values(): - # Replace the name of a model variable (works only for strings) - if eq.varname in replacements: - new_varname = replacements[eq.varname] - if not isinstance(new_varname, str): - raise ValueError(('Cannot replace model variable "%s" ' - 'with a value') % eq.varname) - if new_varname in self or new_varname in new_equations: - raise EquationError( - ('Cannot replace model variable "%s" ' - 'with "%s", duplicate definition ' - 'of "%s".' % (eq.varname, new_varname, - new_varname))) - # make sure that the replacement is a valid identifier - Equations.check_identifier(new_varname) + additional_equations = {} + replaced_name = False + replaced_placeholder = False + if not isinstance(replacement, Sequence) or isinstance(replacement, str): + replacement = [replacement] + replacement_strs = [] + for one_replacement in replacement: + if isinstance(one_replacement, str): + if any(c not in string.ascii_letters + '_{}' + for c in one_replacement): + # Check whether the replacement can be interpreted as an expression + try: + expr = Expression(one_replacement) + replacement_strs.append(expr.code) + except SyntaxError: + raise SyntaxError(f'Replacement \'{one_replacement}\' for' + f'\'{to_replace}\' is neither a name nor a ' + f'valid expression.') + else: + replacement_strs.append(one_replacement) + elif isinstance(one_replacement, (numbers.Number, Quantity)): + if not getattr(one_replacement, 'shape', ()) == (): + raise TypeError(f'Cannot replace variable \'{to_replace}\' with an ' + f'array of values.') + replacement_strs.append(repr(one_replacement)) + elif isinstance(one_replacement, Expression): + replacement_strs.append(one_replacement.code) + elif isinstance(one_replacement, Equations): + replacement_strs.append(list(one_replacement.keys())[0]) # name of first equation + for additional_eq in one_replacement: + if additional_eq in equations or additional_eq in additional_equations: + raise SyntaxError(f'Adding equations to replace \'{to_replace}\' leads ' + f'to duplicated definition of variable \'{additional_eq}\'') + additional_equations[additional_eq] = one_replacement[additional_eq] else: - new_varname = eq.varname - - if eq.type in [SUBEXPRESSION, DIFFERENTIAL_EQUATION]: - # Replace values in the RHS of the equation - new_code = eq.expr.code - for to_replace, replacement in replacements.items(): - if to_replace in eq.identifiers: - if isinstance(replacement, str): - # replace the name with another name - new_code = re.sub('\\b' + to_replace + '\\b', - replacement, new_code) - else: - # replace the name with a value - new_code = re.sub('\\b' + to_replace + '\\b', - '(' + repr(replacement) + ')', - new_code) - try: - Expression(new_code) - except ValueError as ex: - raise ValueError( - ('Replacing "%s" with "%r" failed: %s') % - (to_replace, replacement, ex)) - new_equations[new_varname] = SingleEquation(eq.type, new_varname, - dimensions=eq.dim, - var_type=eq.var_type, - expr=Expression(new_code), - flags=eq.flags) + raise TypeError(f'Cannot replace \'{to_replace}\' with an object of type ' + f'\'{type(one_replacement)}\'.') + if len(replacement_strs) == 1: + replacement_str = replacement_strs[0] + # Be careful if the string is more than just a name/number + if any(c not in string.ascii_letters + string.digits + '_.{}' + for c in replacement_str): + replacement_str = '(' + replacement_str + ')' + else: + replacement_str = '(' + (' + '.join(replacement_strs)) + ')' + + for eq in equations.values(): + # Check whether the variable name itself (or part of it) will be replaced + if eq.varname == to_replace: + if not len(replacement) == 1: + raise TypeError(f'Cannot replace variable name \'{to_replace}\' with ' + f'a list of values.') + if not isinstance(replacement[0], str): + raise TypeError(f'Cannot replace variable name \'{to_replace}\' with ' + f'an object of type \'{type(replacement[0])}\'.') + new_varname = replacement[0] + replaced_name = True + elif '{' + to_replace + '}' in eq.varname: + if not len(replacement) == 1: + raise TypeError(f'Cannot replace \'{{{to_replace}}}\' as a part of a variable' + f'name with a list of values.') + if not isinstance(replacement[0], str): + raise TypeError(f'Cannot replace \'{{{to_replace}}}\' as a part of a variable' + f'name with an object of type \'{type(replacement[0])}\'.') + new_varname = eq.varname.replace('{' + to_replace + '}', replacement[0]) + replaced_placeholder = True else: - new_equations[new_varname] = SingleEquation(eq.type, new_varname, - dimensions=eq.dim, - var_type=eq.var_type, - flags=eq.flags) + new_varname = eq.varname + # Check whether the new variable name is still valid + if '{' not in new_varname: + Equations.check_identifier(new_varname) + if new_varname != eq.varname: + if new_varname in equations: + raise EquationError(f'Cannot replace \'{eq.varname}\' by \'{new_varname}\', ' \ + 'this name is already used for another variable.') + # Replace occurrences in the RHS of equations + new_expr = eq.expr + if to_replace in eq.identifiers and eq.expr is not None: + code = eq.expr.code + new_expr = Expression(re.sub(r'(? 0: + identifiers = [f'\'{identifier}\'' + for identifier in model.template_identifiers] + identifier_list = ', '.join(sorted(identifiers)) + raise TypeError('The model equations contain placeholders, substitute ' + 'names/values for the ' f'following before passing ' + f'them to {self.__class__.__name__}: {identifier_list}.') # Check flags model.check_flags({DIFFERENTIAL_EQUATION: ('unless refractory',), diff --git a/brian2/only.py b/brian2/only.py index f73167cc3..cf374d76a 100644 --- a/brian2/only.py +++ b/brian2/only.py @@ -81,7 +81,8 @@ def restore_initial_state(): 'DEFAULT_FUNCTIONS', 'Function', 'implementation', 'declare_types', 'PreferenceError', 'BrianPreference', 'prefs', 'brian_prefs', 'Clock', 'defaultclock', - 'Equations', 'Expression', 'Statements', + 'Equations', 'Expression', + 'Statements', 'BrianObject', 'BrianObjectException', 'Network', 'profiling_summary', 'scheduling_summary', diff --git a/brian2/synapses/synapses.py b/brian2/synapses/synapses.py index 17f507b32..b57e8a020 100644 --- a/brian2/synapses/synapses.py +++ b/brian2/synapses/synapses.py @@ -740,6 +740,14 @@ def __init__(self, source, target=None, model=None, on_pre=None, raise TypeError(('model has to be a string or an Equations ' 'object, is "%s" instead.') % type(model)) + if len(model.template_identifiers) > 0: + identifiers = [f'\'{identifier}\'' + for identifier in model.template_identifiers] + identifier_list = ', '.join(sorted(identifiers)) + raise TypeError('The model equations contain placeholders, substitute ' + 'names/values for the ' f'following before passing ' + f'them to {self.__class__.__name__}: {identifier_list}.') + # Check flags model.check_flags({DIFFERENTIAL_EQUATION: ['event-driven', 'clock-driven'], SUBEXPRESSION: ['summed', 'shared', diff --git a/brian2/tests/test_equations.py b/brian2/tests/test_equations.py index 0c614750c..814a806b4 100644 --- a/brian2/tests/test_equations.py +++ b/brian2/tests/test_equations.py @@ -221,13 +221,17 @@ def test_wrong_replacements(): ''', v='x') # Replacing a model variable name with a value - with pytest.raises(ValueError): + with pytest.raises(TypeError): Equations('dv/dt = -v / tau : 1', v=3 * mV) # Replacing with an illegal value - with pytest.raises(SyntaxError): + with pytest.raises(TypeError): Equations('dv/dt = -v/tau : 1', tau=np.arange(5)) + # Replacing something that does not exist + with pytest.raises(KeyError): + Equations('dv/dt = -v/tau : 1', x='y') + @pytest.mark.codegen_independent def test_substitute(): diff --git a/brian2/utils/stringtools.py b/brian2/utils/stringtools.py index c37239280..ccf17922f 100644 --- a/brian2/utils/stringtools.py +++ b/brian2/utils/stringtools.py @@ -152,11 +152,12 @@ def replace(s, substitutions): KEYWORDS = {'and', 'or', 'not', 'True', 'False'} -def get_identifiers(expr, include_numbers=False): +def get_identifiers(expr, include_numbers=False, template=False): ''' Return all the identifiers in a given string ``expr``, that is everything - that matches a programming language variable like expression, which is - here implemented as the regexp ``\\b[A-Za-z_][A-Za-z0-9_]*\\b``. + that matches a programming language variable like expression. Placeholder + arguments of the form ``{name}`` are handled separately and only returned + if the ``template`` argument is ``True``. Parameters ---------- @@ -164,7 +165,9 @@ def get_identifiers(expr, include_numbers=False): The string to analyze include_numbers : bool, optional Whether to include number literals in the output. Defaults to ``False``. - + template : bool, optional + Whether to only return template identifiers, i.e. identifiers enclosed + by curly braces. Defaults to ``False`` Returns ------- identifiers : set @@ -174,20 +177,30 @@ def get_identifiers(expr, include_numbers=False): -------- >>> expr = '3-a*_b+c5+8+f(A - .3e-10, tau_2)*17' >>> ids = get_identifiers(expr) - >>> print(sorted(list(ids))) + >>> print(sorted(ids)) ['A', '_b', 'a', 'c5', 'f', 'tau_2'] >>> ids = get_identifiers(expr, include_numbers=True) - >>> print(sorted(list(ids))) + >>> print(sorted(ids)) ['.3e-10', '17', '3', '8', 'A', '_b', 'a', 'c5', 'f', 'tau_2'] + >>> template_expr = '{name}_{suffix} = a*{name}_{suffix} + b' + >>> template_ids = get_identifiers(template_expr, template=True) + >>> print(sorted(template_ids)) + ['name', 'suffix'] ''' - identifiers = set(re.findall(r'\b[A-Za-z_][A-Za-z0-9_]*\b', expr)) - if include_numbers: - # only the number, not a + or - - numbers = set(re.findall(r'(?<=[^A-Za-z_])[0-9]*\.?[0-9]+(?:[eE][-+]?[0-9]+)?|^[0-9]*\.?[0-9]+(?:[eE][-+]?[0-9]+)?', - expr)) + if include_numbers and template: + raise ValueError('Cannot combine the \'template\' and \'include_numbers\' arguments.') + + if template: + return set(re.findall(r'(?:[{])([A-Za-z_][A-Za-z0-9_]*)(?:[}])', expr)) else: - numbers = set() - return (identifiers - KEYWORDS) | numbers + if include_numbers: + # only the number, not a + or - + number_regexp = r'(?<=[^A-Za-z_])[0-9]*\.?[0-9]+(?:[eE][-+]?[0-9]+)?|^[0-9]*\.?[0-9]+(?:[eE][-+]?[0-9]+)?' + numbers = set(re.findall(number_regexp, expr)) + else: + numbers = set() + identifiers = set(re.findall(r'(?