diff --git a/brian2/core/clocks.py b/brian2/core/clocks.py index 40cdd8d11..a6b7a7664 100644 --- a/brian2/core/clocks.py +++ b/brian2/core/clocks.py @@ -4,6 +4,8 @@ __docformat__ = "restructuredtext en" +from abc import ABC, abstractmethod + import numpy as np from brian2.core.names import Nameable @@ -11,9 +13,10 @@ from brian2.groups.group import VariableOwner from brian2.units.allunits import second from brian2.units.fundamentalunits import Quantity, check_units +from brian2.units.stdunits import ms from brian2.utils.logger import get_logger -__all__ = ["Clock", "defaultclock"] +__all__ = ["BaseClock", "Clock", "defaultclock", "EventClock"] logger = get_logger(__name__) @@ -62,31 +65,25 @@ def check_dt(new_dt, old_dt, target_t): ) -class Clock(VariableOwner): +class BaseClock(VariableOwner, ABC): """ - An object that holds the simulation time and the time step. + Abstract base class for all clocks in the simulator. + + This class should never be instantiated directly, use one of the subclasses + like Clock or EventClock instead. Parameters ---------- - dt : float - The time step of the simulation as a float name : str, optional An explicit name, if not specified gives an automatically generated name - - Notes - ----- - Clocks are run in the same `Network.run` iteration if `~Clock.t` is the - same. The condition for two - clocks to be considered as having the same time is - ``abs(t1-t2) other.variables["t"].get_value().item() + ) + + def __le__(self, other): + return self.__lt__(other) or self.same_time(other) + + def __ge__(self, other): + return self.__gt__(other) or self.same_time(other) + + @abstractmethod + def same_time(self, other): + """ + Check if two clocks are at the same time (within epsilon). + + Parameters + ---------- + other : BaseClock + The other clock to compare with + + Returns + ------- + bool + True if both clocks are at the same time + """ + pass + + +class EventClock(BaseClock): + """ + A clock that advances through a predefined sequence of times. + + Parameters + ---------- + times : array-like + The sequence of times for the clock to advance through + name : str, optional + An explicit name, if not specified gives an automatically generated name + """ + + def __init__(self, times, name="eventclock*"): + super().__init__(name=name) + times = Quantity(times) + from brian2.units.fundamentalunits import fail_for_dimension_mismatch + + fail_for_dimension_mismatch( + times, + second.dim, + error_message="'times' must have dimensions of time", + dim=times, + ) + self._times = sorted(times) + seen = set() + duplicates = [] + for time in self._times: + if float(time) in seen: + duplicates.append(time) + else: + seen.add(float(time)) + if duplicates: + raise ValueError( + "The times provided to EventClock must not contain duplicates. " + f"Duplicates found: {duplicates}" + ) + + self._times.append(np.inf * ms) self.variables.add_array( - "dt", + "times", dimensions=second.dim, - size=1, - values=float(dt), + size=len(self._times), + values=self._times, dtype=np.float64, read_only=True, - constant=True, - scalar=True, ) - self.variables.add_constant("N", value=1) - self._enable_group_attributes() - self.dt = dt - logger.diagnostic(f"Created clock {self.name} with dt={self.dt}") + self.variables["t"].set_value(self._times[0]) - @check_units(t=second) - def _set_t_update_dt(self, target_t=0 * second): - new_dt = self.dt_ - old_dt = self._old_dt - target_t = float(target_t) - if old_dt is not None and new_dt != old_dt: - self._old_dt = None - # Only allow a new dt which allows to correctly set the new time step - check_dt(new_dt, old_dt, target_t) + logger.diagnostic(f"Created event clock {self.name}") - new_timestep = self._calc_timestep(target_t) - # Since these attributes are read-only for normal users, we have to - # update them via the variables object directly - self.variables["timestep"].set_value(new_timestep) - self.variables["t"].set_value(new_timestep * new_dt) - logger.diagnostic(f"Setting Clock {self.name} to t={self.t}, dt={self.dt}") + def advance(self): + """ + Advance to the next time in the sequence. + """ + new_ts = self.variables["timestep"].get_value().item() + if self._i_end is not None and new_ts + 1 > self._i_end: + raise StopIteration( + "EventClock has reached the end of its available times." + ) + new_ts += 1 + self.variables["timestep"].set_value(new_ts) + self.variables["t"].set_value(self._times[new_ts]) - def _calc_timestep(self, target_t): + @check_units(start=second, end=second) + def set_interval(self, start, end): """ - Calculate the integer time step for the target time. If it cannot be - exactly represented (up to 0.01% of dt), round up. + Set the start and end time of the simulation. Parameters ---------- - target_t : float - The target time in seconds + start : second + The start time of the simulation + end : second + The end time of the simulation + """ + start = float(start) + end = float(end) + + start_idx = np.searchsorted(self._times, start) + end_idx = np.searchsorted(self._times, end) + + self.variables["timestep"].set_value(start_idx) + self.variables["t"].set_value(self._times[start_idx]) + + self._i_end = end_idx + + def __getitem__(self, timestep): + """ + Get the time at a specific timestep. + + Parameters + ---------- + timestep : int + The timestep to get the time for Returns ------- - timestep : int - The target time in integers (based on dt) + float + The time at the specified timestep """ - new_i = np.int64(np.round(target_t / self.dt_)) - new_t = new_i * self.dt_ - if new_t == target_t or np.abs(new_t - target_t) / self.dt_ <= Clock.epsilon_dt: - new_timestep = new_i + return self._times[timestep] + + def same_time(self, other): + """ + Check if two clocks are at the same time. + + For comparisons with `Clock` objects, uses the Clock's dt and epsilon_dt. + For comparisons with other `EventClock` or `BaseClock` objects, uses the base + epsilon value. + + Parameters + ---------- + other : BaseClock + The other clock to compare with + + Returns + ------- + bool + True if both clocks are at the same time + """ + t1 = self.variables["t"].get_value().item() + t2 = other.variables["t"].get_value().item() + + if isinstance(other, Clock): + return abs(t1 - t2) / other.dt_ < other.epsilon_dt else: - new_timestep = np.int64(np.ceil(target_t / self.dt_)) - return new_timestep + # Both are pure EventClocks without dt. + return abs(t1 - t2) < self.epsilon + + def __le__(self, other): + return self.__lt__(other) or self.same_time(other) + + def __ge__(self, other): + return self.__gt__(other) or self.same_time(other) + + +class Clock(BaseClock): + """ + An object that holds the simulation time and the time step. + + Parameters + ---------- + dt : float + The time step of the simulation as a float + name : str, optional + An explicit name, if not specified gives an automatically generated name + + Notes + ----- + Clocks are run in the same `Network.run` iteration if `~Clock.t` is the + same. The condition for two + clocks to be considered as having the same time is + ``abs(t1-t2) self._i_end: + raise StopIteration("Clock has reached the end of its available times.") + + self.variables["timestep"].set_value(new_ts) + new_t = new_ts * self.dt_ + self.variables["t"].set_value(new_t) + def _get_dt_(self): return self.variables["dt"].get_value().item() @@ -180,20 +376,75 @@ def _set_dt(self, dt): doc="""The time step of the simulation as a float (in seconds)""", ) + def _calc_timestep(self, target_t): + """ + Calculate the integer time step for the target time. If it cannot be + exactly represented (up to epsilon_dt of dt), round up. + + Parameters + ---------- + target_t : float + The target time in seconds + + Returns + ------- + timestep : int + The target time in integers (based on dt) + """ + new_i = np.int64(np.round(target_t / self.dt_)) + new_t = new_i * self.dt_ + if new_t == target_t or np.abs(new_t - target_t) / self.dt_ <= Clock.epsilon_dt: + new_timestep = new_i + else: + new_timestep = np.int64(np.ceil(target_t / self.dt_)) + return new_timestep + + @check_units(target_t=second) + def _set_t_update_dt(self, target_t=0 * second): + """ + Set the time to a specific value, checking if dt has changed. + + Parameters + ---------- + target_t : second + The target time to set + """ + new_dt = self.dt_ + old_dt = self._old_dt + target_t = float(target_t) + + if old_dt is not None and new_dt != old_dt: + self._old_dt = None + check_dt(new_dt, old_dt, target_t) + + new_timestep = self._calc_timestep(target_t) + + self.variables["timestep"].set_value(new_timestep) + self.variables["t"].set_value(new_timestep * self.dt_) + set_t = self.variables["t"].get_value().item() + + logger.diagnostic(f"Setting Clock {self.name} to t={set_t}, dt={new_dt}") + @check_units(start=second, end=second) def set_interval(self, start, end): """ - set_interval(self, start, end) - Set the start and end time of the simulation. Sets the start and end value of the clock precisely if - possible (using epsilon) or rounding up if not. This assures that + possible (using epsilon_dt) or rounding up if not. This assures that multiple calls to `Network.run` will not re-run the same time step. + + Parameters + ---------- + start : second + The start time of the simulation + end : second + The end time of the simulation """ self._set_t_update_dt(target_t=start) end = float(end) self._i_end = self._calc_timestep(end) + if self._i_end > 2**40: logger.warn( "The end time of the simulation has been set to " @@ -206,9 +457,35 @@ def set_interval(self, start, end): "many_timesteps", ) - #: The relative difference for times (in terms of dt) so that they are - #: considered identical. - epsilon_dt = 1e-4 + def same_time(self, other): + """ + Check if two clocks are at the same time (within epsilon_dt * dt). + + Parameters + ---------- + other : BaseClock + The other clock to compare with + + Returns + ------- + bool + True if both clocks are at the same time + """ + t1 = self.variables["t"].get_value().item() + t2 = other.variables["t"].get_value().item() + + if isinstance(other, Clock): + # Both are pure Clocks with dt so we take the min. + dt = min(self.dt_, other.dt_) + return abs(t1 - t2) / dt < self.epsilon_dt + else: + return abs(t1 - t2) / self.dt_ < self.epsilon_dt + + def __le__(self, other): + return self.__lt__(other) or self.same_time(other) + + def __ge__(self, other): + return self.__gt__(other) or self.same_time(other) class DefaultClockProxy: diff --git a/brian2/core/network.py b/brian2/core/network.py index 34df6345b..8f290b1d8 100644 --- a/brian2/core/network.py +++ b/brian2/core/network.py @@ -17,7 +17,7 @@ from collections.abc import Mapping, Sequence from brian2.core.base import BrianObject, BrianObjectException -from brian2.core.clocks import Clock, defaultclock +from brian2.core.clocks import defaultclock from brian2.core.names import Nameable from brian2.core.namespace import get_local_namespace from brian2.core.preferences import BrianPreference, prefs @@ -1019,7 +1019,7 @@ def before_run(self, run_namespace): "creating a new one." ) - clocknames = ", ".join(f"{obj.name} (dt={obj.dt})" for obj in self._clocks) + clocknames = ", ".join(f"{obj.name}" for obj in self._clocks) logger.debug( f"Network '{self.name}' uses {len(self._clocks)} clocks: {clocknames}", "before_run", @@ -1035,19 +1035,11 @@ def after_run(self): obj.after_run() def _nextclocks(self): - clocks_times_dt = [ - (c, self._clock_variables[c][1][0], self._clock_variables[c][2][0]) - for c in self._clocks - ] - minclock, min_time, minclock_dt = min(clocks_times_dt, key=lambda k: k[1]) - curclocks = { - clock - for clock, time, dt in clocks_times_dt - if ( - time == min_time - or abs(time - min_time) / min(minclock_dt, dt) < Clock.epsilon_dt - ) - } + + minclock = min(self._clocks, key=lambda c: c.variables["t"].get_value().item()) + + curclocks = {clock for clock in self._clocks if clock.same_time(minclock)} + return minclock, curclocks @device_override("network_run") @@ -1136,7 +1128,6 @@ def run( c: ( c.variables["timestep"].get_value(), c.variables["t"].get_value(), - c.variables["dt"].get_value(), ) for c in self._clocks } @@ -1183,15 +1174,15 @@ def run( profiling_info = defaultdict(float) if single_clock: - timestep, t, dt = ( + timestep, t = ( clock.variables["timestep"].get_value(), clock.variables["t"].get_value(), - clock.variables["dt"].get_value(), ) + else: # Find the first clock to be updated (see note below) clock, curclocks = self._nextclocks() - timestep, _, _ = self._clock_variables[clock] + timestep, t = self._clock_variables[clock] running = timestep[0] < clock._i_end @@ -1199,7 +1190,7 @@ def run( while running and not self._stopped and not Network._globally_stopped: if not single_clock: - timestep, t, dt = self._clock_variables[clock] + timestep, t = self._clock_variables[clock] # update the network time to this clock's time self.t_ = t[0] if report is not None: @@ -1224,8 +1215,8 @@ def run( for obj in active_objects: obj.run() - timestep[0] += 1 - t[0] = timestep[0] * dt[0] + clock.advance() + else: if profile: for obj in active_objects: @@ -1239,15 +1230,13 @@ def run( obj.run() for c in curclocks: - timestep, t, dt = self._clock_variables[c] - timestep[0] += 1 - t[0] = timestep[0] * dt[0] + c.advance() # find the next clocks to be updated. The < operator for Clock # determines that the first clock to be updated should be the one # with the smallest t value, unless there are several with the # same t value in which case we update all of them clock, curclocks = self._nextclocks() - timestep, _, _ = self._clock_variables[clock] + timestep, t = self._clock_variables[clock] if ( device._maximum_run_time is not None diff --git a/brian2/groups/group.py b/brian2/groups/group.py index 4e13b1a9e..c364f4a88 100644 --- a/brian2/groups/group.py +++ b/brian2/groups/group.py @@ -1169,12 +1169,60 @@ def run_regularly( obj : `CodeRunner` A reference to the object that will be run. """ + from brian2.core.clocks import Clock # Avoid circular import + if name is None: names = [o.name for o in self.contained_objects] name = find_name(f"{self.name}_run_regularly*", names) if dt is None and clock is None: clock = self._clock + elif clock is None: + clock = Clock(dt=dt) + + return self.run_on_clock(code, clock, when, order, name, codeobj_class) + + def run_on_clock( + self, + code, + clock=None, + when="start", + order=0, + name=None, + codeobj_class=None, + ): + """ + This method is used by `run_regularly` and `run_at` to register operations + that are executed at times determined by a user-supplied or internally generated `Clock`. + The resulting `CodeRunner` is automatically added to the group. + + Parameters + ---------- + code : str + The abstract code to run. + clock : `Clock`, optional + The update clock to use for this operation. If neither a clock nor + the `dt` argument is specified, defaults to the clock of the group. + when : str, optional + When to run within a time step, defaults to the ``'start'`` slot. + See :ref:`scheduling` for possible values. + name : str, optional + A unique name, if non is given the name of the group appended with + 'run_custom_clock', 'run_custom_clock_1', etc. will be used. If a + name is given explicitly, it will be used as given (i.e. the group + name will not be prepended automatically). + codeobj_class : class, optional + The `CodeObject` class to run code with. If not specified, defaults + to the `group`'s ``codeobj_class`` attribute. + + Returns + ------- + obj : `CodeRunner` + A reference to the object that will be run. + """ + if name is None: + names = [o.name for o in self.contained_objects] + name = find_name(f"{self.name}_run_custom*", names) # Subgroups are normally not included in their parent's # contained_objects list, since there's no need to include them in the @@ -1195,7 +1243,6 @@ def run_regularly( "stateupdate", code=code, name=name, - dt=dt, clock=clock, when=when, order=order, @@ -1204,6 +1251,53 @@ def run_regularly( self.contained_objects.append(runner) return runner + def run_at( + self, + code, + times, + when="start", + order=0, + name=None, + codeobj_class=None, + ): + """ + Run abstract code in the group's namespace. The created `CodeRunner` + object will be automatically added to the group, it therefore does not + need to be added to the network manually. However, a reference to the + object will be returned, which can be used to later remove it from the + group or to set it to inactive. + + Parameters + ---------- + code : str + The abstract code to run. + times : array-like + The specific simulation times at which to execute the code. + when : str, optional + When to run within a time step, defaults to the ``'start'`` slot. + See :ref:`scheduling` for possible values. + name : str, optional + A unique name, if non is given the name of the group appended with + 'run_at', 'run_at_1', etc. will be used. If a + name is given explicitly, it will be used as given (i.e. the group + name will not be prepended automatically). + codeobj_class : class, optional + The `CodeObject` class to run code with. If not specified, defaults + to the `group`'s ``codeobj_class`` attribute. + + Returns + ------- + obj : `CodeRunner` + A reference to the object that will be run. + """ + if name is None: + names = [o.name for o in self.contained_objects] + name = find_name(f"{self.name}_run_at*", names) + from brian2.core.clocks import EventClock + + clock = EventClock(times) + return self.run_on_clock(code, clock, when, order, name, codeobj_class) + def _check_for_invalid_states(self): """ Checks if any state variables updated by differential equations have diff --git a/brian2/tests/test_clocks.py b/brian2/tests/test_clocks.py index ef9249dd5..189742c7c 100644 --- a/brian2/tests/test_clocks.py +++ b/brian2/tests/test_clocks.py @@ -2,6 +2,9 @@ from numpy.testing import assert_array_equal, assert_equal from brian2 import * +from brian2.core.clocks import EventClock +from brian2.tests.test_network import NameLister +from brian2.units.fundamentalunits import DimensionMismatchError from brian2.utils.logger import catch_logs @@ -57,6 +60,74 @@ def test_set_interval_warning(): assert logs[0][1].endswith("many_timesteps") +@pytest.mark.codegen_independent +def test_event_clock(): + times = [0.0 * ms, 0.3 * ms, 0.5 * ms, 0.6 * ms] + event_clock = EventClock(times) + + for i in range(4): + print(event_clock[i]) + + assert_equal(event_clock.variables["t"].get_value(), 0.0 * ms) + assert_equal(event_clock[1], 0.3 * ms) + + event_clock.advance() + assert_equal(event_clock.variables["timestep"].get_value(), 1) + assert_equal(event_clock.variables["t"].get_value(), 0.0003) + + event_clock.set_interval(0.3 * ms, 0.6 * ms) + assert_equal(event_clock.variables["timestep"].get_value(), 1) + assert_equal(event_clock.variables["t"].get_value(), 0.0003) + event_clock.advance() + event_clock.advance() + + with pytest.raises(StopIteration): + event_clock.advance() + + invalid_times = [0.0 * volt, 0.5 * volt] + with pytest.raises(DimensionMismatchError) as excinfo: + EventClock(invalid_times) + + +@pytest.mark.codegen_independent +def test_combined_clocks_with_run_at(): + + # Reset updates + NameLister.updates[:] = [] + + # Regular NameLister at 1ms interval + regular_lister = NameLister(name="x", dt=1 * ms, order=0) + + # Event NameLister at specific times + event_times = [0.5 * ms, 2.5 * ms, 4 * ms] + event_lister = NameLister(name="y", clock=EventClock(times=event_times), order=1) + + # Create and run the network + net = Network(regular_lister, event_lister) + net.run(5 * ms) + + # Get update string + updates = "".join(NameLister.updates) + + # Expected output: "x" at 0,1,2,3,4ms = 5 times + # "y" at 0.5, 2.5, 4.0ms = 3 times + expected_x_count = 5 + expected_y_count = 3 + + x_count = updates.count("x") + y_count = updates.count("y") + + assert ( + x_count == expected_x_count + ), f"Expected {expected_x_count} x's, got {x_count}" + assert ( + y_count == expected_y_count + ), f"Expected {expected_y_count} y's, got {y_count}" + + # Optional: check full string if needed + print(updates) + + if __name__ == "__main__": test_clock_attributes() restore_initial_state() @@ -64,3 +135,5 @@ def test_set_interval_warning(): restore_initial_state() test_defaultclock() test_set_interval_warning() + test_event_clock() + test_combined_clocks_with_run_at() diff --git a/brian2/tests/test_network.py b/brian2/tests/test_network.py index 9b623fd36..5b9385ff3 100644 --- a/brian2/tests/test_network.py +++ b/brian2/tests/test_network.py @@ -1676,12 +1676,14 @@ def test_small_runs(): # One long run and multiple small runs should give the same results group_1 = NeuronGroup(10, "dv/dt = -v / (10*ms) : 1") group_1.v = "(i + 1) / N" + group_1.run_at("v += 0.1", times=[100 * ms, 300 * ms]) mon_1 = StateMonitor(group_1, "v", record=True) net_1 = Network(group_1, mon_1) net_1.run(1 * second) group_2 = NeuronGroup(10, "dv/dt = -v / (10*ms) : 1") group_2.v = "(i + 1) / N" + group_2.run_at("v += 0.1", times=[100 * ms, 300 * ms]) mon_2 = StateMonitor(group_2, "v", record=True) net_2 = Network(group_2, mon_2) runtime = 1 * ms