Skip to content

Maintain backward compatibility for spike event variables in create_variables #1617

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 3 commits into
base: master
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 69 additions & 10 deletions brian2/groups/neurongroup.py
Original file line number Diff line number Diff line change
Expand Up @@ -518,7 +518,7 @@ def __init__(
method_options=None,
threshold=None,
reset=None,
refractory=False,
refractory=None, # Updated to None instead of False
events=None,
namespace=None,
dtype=None,
Expand Down Expand Up @@ -576,6 +576,26 @@ def __init__(
}
)

# Handle events
if events is None:
events = {}
self.events = {'spike': threshold} if threshold else {}
self.events.update(events)

# Handle refractory
if refractory is not None:
if isinstance(refractory, (str, Quantity)):
refractory = {'spike': refractory}
elif isinstance(refractory, dict):
for event in refractory:
if event not in self.events:
raise ValueError(f"Unknown event '{event}' in refractory dictionary.")
else:
raise TypeError("refractory must be a string, Quantity, or dictionary")
else:
refractory = {}
self._refractory = refractory

# add refractoriness
#: The original equations as specified by the user (i.e. without
#: the multiplied `int(not_refractory)` term for equations marked as
Expand Down Expand Up @@ -794,9 +814,12 @@ def _create_variables(self, user_dtype, events):

# Standard variables always present
for event in events:
self.variables.add_array(
f"_{event}space", size=self._N + 1, dtype=np.int32, constant=False
)
if event == 'spike':
self.variables.add_array("_spikespace", size=self._N + 1, dtype=np.int32, constant=False)
else:
eventspace_name = f"_{event.replace(' ', '_')}space"
self.variables.add_array(eventspace_name, size=self._N + 1, dtype=np.int32, constant=False)

# Add the special variable "i" which can be used to refer to the neuron index
self.variables.add_arange("i", size=self._N, constant=True, read_only=True)
# Add the clock variables
Expand Down Expand Up @@ -836,12 +859,48 @@ def _create_variables(self, user_dtype, events):
else:
raise AssertionError(f"Unknown type of equation: {eq.eq_type}")

# Add the conditional-write attribute for variables with the
# "unless refractory" flag
if self._refractory is not False:
for eq in self.equations.values():
if eq.type == DIFFERENTIAL_EQUATION and "unless refractory" in eq.flags:
not_refractory_var = self.variables["not_refractory"]
# refractory variable setup for spike event
if 'spike' in self.events:
self.variables.add_array('lastspike', size=self._N, dtype=float, constant=False, value=-1e100)
if self._refractory.get('spike', False):
self.variables.add_subexpression('not_refractory', 'True')
# For other events
for event in self.events:
if event != 'spike' and event in self._refractory:
event_name = event.replace(' ', '_')
self.variables.add_array(f'_lastevent_{event_name}', size=self._N, dtype=float,
constant=False, value=-1e100)
refr = self._refractory[event]
if isinstance(refr, Quantity):
self.variables.add_array(f'_refractory_until_{event_name}', size=self._N, dtype=float,
constant=False, value=-1e100)
self.variables.add_subexpression(f'not_refractory_{event_name}',
f't >= _refractory_until_{event_name}')
elif isinstance(refr, str):
self.variables.add_subexpression(f'not_refractory_{event_name}', f'not ({refr})')


# Events without refractory
for event in self.events:
event_name = event.replace(' ', '_')
if event not in self._refractory:
self.variables.add_subexpression(f'not_refractory_{event_name}', 'True')

if 'spike' in self.events and 'spike' in self._refractory:
refr = self._refractory['spike']
if isinstance(refr, Quantity):
self.variables.add_array('_refractory_until', size=self._N, dtype=float,
constant=False, value=-1e100)
self.variables.add_subexpression('not_refractory', 't >= _refractory_until')
elif isinstance(refr, str):
self.variables.add_subexpression('not_refractory', f'not ({refr})')

for eq in self.equations.values():
if eq.type == DIFFERENTIAL_EQUATION and "unless refractory" in eq.flags:
for event in self.events:
event_name = event.replace(' ', '_')
not_refractory_var = self.variables.get(f'not_refractory_{event_name}', None)
if not_refractory_var:
var = self.variables[eq.varname]
var.set_conditional_write(not_refractory_var)

Expand Down
Loading