diff --git a/brian2/groups/neurongroup.py b/brian2/groups/neurongroup.py index b8f3e7f6e..78944f6e3 100644 --- a/brian2/groups/neurongroup.py +++ b/brian2/groups/neurongroup.py @@ -518,7 +518,7 @@ def __init__( method_options=None, threshold=None, reset=None, - refractory=False, + refractory=None, # Updated to None instead of False events=None, namespace=None, dtype=None, @@ -576,6 +576,26 @@ def __init__( } ) + # Handle events + if events is None: + events = {} + self.events = {'spike': threshold} if threshold else {} + self.events.update(events) + + # Handle refractory + if refractory is not None: + if isinstance(refractory, (str, Quantity)): + refractory = {'spike': refractory} + elif isinstance(refractory, dict): + for event in refractory: + if event not in self.events: + raise ValueError(f"Unknown event '{event}' in refractory dictionary.") + else: + raise TypeError("refractory must be a string, Quantity, or dictionary") + else: + refractory = {} + self._refractory = refractory + # add refractoriness #: The original equations as specified by the user (i.e. without #: the multiplied `int(not_refractory)` term for equations marked as @@ -794,9 +814,12 @@ def _create_variables(self, user_dtype, events): # Standard variables always present for event in events: - self.variables.add_array( - f"_{event}space", size=self._N + 1, dtype=np.int32, constant=False - ) + if event == 'spike': + self.variables.add_array("_spikespace", size=self._N + 1, dtype=np.int32, constant=False) + else: + eventspace_name = f"_{event.replace(' ', '_')}space" + self.variables.add_array(eventspace_name, size=self._N + 1, dtype=np.int32, constant=False) + # Add the special variable "i" which can be used to refer to the neuron index self.variables.add_arange("i", size=self._N, constant=True, read_only=True) # Add the clock variables @@ -836,12 +859,48 @@ def _create_variables(self, user_dtype, events): else: raise AssertionError(f"Unknown type of equation: {eq.eq_type}") - # Add the conditional-write attribute for variables with the - # "unless refractory" flag - if self._refractory is not False: - for eq in self.equations.values(): - if eq.type == DIFFERENTIAL_EQUATION and "unless refractory" in eq.flags: - not_refractory_var = self.variables["not_refractory"] + # refractory variable setup for spike event + if 'spike' in self.events: + self.variables.add_array('lastspike', size=self._N, dtype=float, constant=False, value=-1e100) + if self._refractory.get('spike', False): + self.variables.add_subexpression('not_refractory', 'True') + # For other events + for event in self.events: + if event != 'spike' and event in self._refractory: + event_name = event.replace(' ', '_') + self.variables.add_array(f'_lastevent_{event_name}', size=self._N, dtype=float, + constant=False, value=-1e100) + refr = self._refractory[event] + if isinstance(refr, Quantity): + self.variables.add_array(f'_refractory_until_{event_name}', size=self._N, dtype=float, + constant=False, value=-1e100) + self.variables.add_subexpression(f'not_refractory_{event_name}', + f't >= _refractory_until_{event_name}') + elif isinstance(refr, str): + self.variables.add_subexpression(f'not_refractory_{event_name}', f'not ({refr})') + + + # Events without refractory + for event in self.events: + event_name = event.replace(' ', '_') + if event not in self._refractory: + self.variables.add_subexpression(f'not_refractory_{event_name}', 'True') + + if 'spike' in self.events and 'spike' in self._refractory: + refr = self._refractory['spike'] + if isinstance(refr, Quantity): + self.variables.add_array('_refractory_until', size=self._N, dtype=float, + constant=False, value=-1e100) + self.variables.add_subexpression('not_refractory', 't >= _refractory_until') + elif isinstance(refr, str): + self.variables.add_subexpression('not_refractory', f'not ({refr})') + + for eq in self.equations.values(): + if eq.type == DIFFERENTIAL_EQUATION and "unless refractory" in eq.flags: + for event in self.events: + event_name = event.replace(' ', '_') + not_refractory_var = self.variables.get(f'not_refractory_{event_name}', None) + if not_refractory_var: var = self.variables[eq.varname] var.set_conditional_write(not_refractory_var)