diff --git a/brian2/codegen/cpp_prefs.py b/brian2/codegen/cpp_prefs.py index 87bb3b57f..aeb468854 100644 --- a/brian2/codegen/cpp_prefs.py +++ b/brian2/codegen/cpp_prefs.py @@ -22,9 +22,9 @@ from setuptools.msvc import msvc14_get_vc_env as _get_vc_env except ImportError: # Setuptools 0.74.0 removed this function try: - from distutils._msvccompiler import _get_vc_env - except ImportError: # Things keep moving around in distutils/setuptools from distutils.compilers.C.msvc import _get_vc_env + except ImportError: # Things keep moving around in distutils/setuptools + from distutils._msvccompiler import _get_vc_env from distutils.ccompiler import get_default_compiler diff --git a/brian2/core/clocks.py b/brian2/core/clocks.py index 40cdd8d11..7ab0d7978 100644 --- a/brian2/core/clocks.py +++ b/brian2/core/clocks.py @@ -4,6 +4,8 @@ __docformat__ = "restructuredtext en" +from abc import ABC, abstractmethod + import numpy as np from brian2.core.names import Nameable @@ -11,9 +13,10 @@ from brian2.groups.group import VariableOwner from brian2.units.allunits import second from brian2.units.fundamentalunits import Quantity, check_units +from brian2.units.stdunits import ms from brian2.utils.logger import get_logger -__all__ = ["Clock", "defaultclock"] +__all__ = ["BaseClock", "Clock", "defaultclock", "EventClock"] logger = get_logger(__name__) @@ -62,31 +65,25 @@ def check_dt(new_dt, old_dt, target_t): ) -class Clock(VariableOwner): +class BaseClock(VariableOwner, ABC): """ - An object that holds the simulation time and the time step. + Abstract base class for all clocks in the simulator. + + This class should never be instantiated directly, use one of the subclasses + like Clock or EventClock instead. Parameters ---------- - dt : float - The time step of the simulation as a float name : str, optional An explicit name, if not specified gives an automatically generated name - - Notes - ----- - Clocks are run in the same `Network.run` iteration if `~Clock.t` is the - same. The condition for two - clocks to be considered as having the same time is - ``abs(t1-t2) other.variables["t"].get_value().item() + ) + + def __le__(self, other): + return self.__lt__(other) or self.same_time(other) + + def __ge__(self, other): + return self.__gt__(other) or self.same_time(other) + + @abstractmethod + def same_time(self, other): + """ + Check if two clocks are at the same time (within epsilon). + + Parameters + ---------- + other : BaseClock + The other clock to compare with + + Returns + ------- + bool + True if both clocks are at the same time + """ + pass + + +class EventClock(BaseClock): + """ + A clock that advances through a predefined sequence of times. + + Parameters + ---------- + times : array-like + The sequence of times for the clock to advance through + name : str, optional + An explicit name, if not specified gives an automatically generated name + """ + + def __init__(self, times, name="eventclock*"): + super().__init__(name=name) + times = Quantity(times) + from brian2.units.fundamentalunits import fail_for_dimension_mismatch + + fail_for_dimension_mismatch( + times, + second.dim, + error_message="'times' must have dimensions of time", + dim=times, + ) + + times_array = np.asarray(times, dtype=float) + unique_times = np.unique(times_array) + if len(unique_times) != len(times_array): + raise ValueError( + "The times provided to EventClock must not contain duplicates." + ) + + self._times = sorted(times) + self._times.append(np.inf * ms) self.variables.add_array( - "dt", + "times", dimensions=second.dim, - size=1, - values=float(dt), + size=len(self._times), + values=self._times, dtype=np.float64, read_only=True, - constant=True, - scalar=True, ) - self.variables.add_constant("N", value=1) - self._enable_group_attributes() - self.dt = dt - logger.diagnostic(f"Created clock {self.name} with dt={self.dt}") + self.variables["t"].set_value(self._times[0]) - @check_units(t=second) - def _set_t_update_dt(self, target_t=0 * second): - new_dt = self.dt_ - old_dt = self._old_dt - target_t = float(target_t) - if old_dt is not None and new_dt != old_dt: - self._old_dt = None - # Only allow a new dt which allows to correctly set the new time step - check_dt(new_dt, old_dt, target_t) + logger.diagnostic(f"Created event clock {self.name}") - new_timestep = self._calc_timestep(target_t) - # Since these attributes are read-only for normal users, we have to - # update them via the variables object directly - self.variables["timestep"].set_value(new_timestep) - self.variables["t"].set_value(new_timestep * new_dt) - logger.diagnostic(f"Setting Clock {self.name} to t={self.t}, dt={self.dt}") + def advance(self): + """ + Advance to the next time in the sequence. + """ + new_ts = self.variables["timestep"].get_value().item() + if self._i_end is not None and new_ts + 1 > self._i_end: + raise StopIteration( + "EventClock has reached the end of its available times." + ) + new_ts += 1 + self.variables["timestep"].set_value(new_ts) + self.variables["t"].set_value(self._times[new_ts]) - def _calc_timestep(self, target_t): + @check_units(start=second, end=second) + def set_interval(self, start, end): """ - Calculate the integer time step for the target time. If it cannot be - exactly represented (up to 0.01% of dt), round up. + Set the start and end time of the simulation. Parameters ---------- - target_t : float - The target time in seconds + start : second + The start time of the simulation + end : second + The end time of the simulation + """ + start = float(start) + end = float(end) + + start_idx = np.searchsorted(self._times, start) + end_idx = np.searchsorted(self._times, end) + + self.variables["timestep"].set_value(start_idx) + self.variables["t"].set_value(self._times[start_idx]) + + self._i_end = end_idx + + def __getitem__(self, timestep): + """ + Get the time at a specific timestep. + + Parameters + ---------- + timestep : int + The timestep to get the time for Returns ------- - timestep : int - The target time in integers (based on dt) + float + The time at the specified timestep """ - new_i = np.int64(np.round(target_t / self.dt_)) - new_t = new_i * self.dt_ - if new_t == target_t or np.abs(new_t - target_t) / self.dt_ <= Clock.epsilon_dt: - new_timestep = new_i + return self._times[timestep] + + def same_time(self, other): + """ + Check if two clocks are at the same time. + + For comparisons with `Clock` objects, uses the Clock's dt and epsilon_dt. + For comparisons with other `EventClock` or `BaseClock` objects, uses the base + epsilon value. + + Parameters + ---------- + other : BaseClock + The other clock to compare with + + Returns + ------- + bool + True if both clocks are at the same time + """ + t1 = self.variables["t"].get_value().item() + t2 = other.variables["t"].get_value().item() + + if isinstance(other, Clock): + return abs(t1 - t2) / other.dt_ < other.epsilon_dt else: - new_timestep = np.int64(np.ceil(target_t / self.dt_)) - return new_timestep + # Both are pure EventClocks without dt. + return abs(t1 - t2) < self.epsilon + + def __le__(self, other): + return self.__lt__(other) or self.same_time(other) + + def __ge__(self, other): + return self.__gt__(other) or self.same_time(other) + + +class Clock(BaseClock): + """ + An object that holds the simulation time and the time step. + + Parameters + ---------- + dt : float + The time step of the simulation as a float + name : str, optional + An explicit name, if not specified gives an automatically generated name + + Notes + ----- + Clocks are run in the same `Network.run` iteration if `~Clock.t` is the + same. The condition for two + clocks to be considered as having the same time is + ``abs(t1-t2) self._i_end: + raise StopIteration("Clock has reached the end of its available times.") + + self.variables["timestep"].set_value(new_ts) + new_t = new_ts * self.dt_ + self.variables["t"].set_value(new_t) + def _get_dt_(self): return self.variables["dt"].get_value().item() @@ -180,20 +371,75 @@ def _set_dt(self, dt): doc="""The time step of the simulation as a float (in seconds)""", ) + def _calc_timestep(self, target_t): + """ + Calculate the integer time step for the target time. If it cannot be + exactly represented (up to epsilon_dt of dt), round up. + + Parameters + ---------- + target_t : float + The target time in seconds + + Returns + ------- + timestep : int + The target time in integers (based on dt) + """ + new_i = np.int64(np.round(target_t / self.dt_)) + new_t = new_i * self.dt_ + if new_t == target_t or np.abs(new_t - target_t) / self.dt_ <= Clock.epsilon_dt: + new_timestep = new_i + else: + new_timestep = np.int64(np.ceil(target_t / self.dt_)) + return new_timestep + + @check_units(target_t=second) + def _set_t_update_dt(self, target_t=0 * second): + """ + Set the time to a specific value, checking if dt has changed. + + Parameters + ---------- + target_t : second + The target time to set + """ + new_dt = self.dt_ + old_dt = self._old_dt + target_t = float(target_t) + + if old_dt is not None and new_dt != old_dt: + self._old_dt = None + check_dt(new_dt, old_dt, target_t) + + new_timestep = self._calc_timestep(target_t) + + self.variables["timestep"].set_value(new_timestep) + self.variables["t"].set_value(new_timestep * self.dt_) + set_t = self.variables["t"].get_value().item() + + logger.diagnostic(f"Setting Clock {self.name} to t={set_t}, dt={new_dt}") + @check_units(start=second, end=second) def set_interval(self, start, end): """ - set_interval(self, start, end) - Set the start and end time of the simulation. Sets the start and end value of the clock precisely if - possible (using epsilon) or rounding up if not. This assures that + possible (using epsilon_dt) or rounding up if not. This assures that multiple calls to `Network.run` will not re-run the same time step. + + Parameters + ---------- + start : second + The start time of the simulation + end : second + The end time of the simulation """ self._set_t_update_dt(target_t=start) end = float(end) self._i_end = self._calc_timestep(end) + if self._i_end > 2**40: logger.warn( "The end time of the simulation has been set to " @@ -206,9 +452,35 @@ def set_interval(self, start, end): "many_timesteps", ) - #: The relative difference for times (in terms of dt) so that they are - #: considered identical. - epsilon_dt = 1e-4 + def same_time(self, other): + """ + Check if two clocks are at the same time (within epsilon_dt * dt). + + Parameters + ---------- + other : BaseClock + The other clock to compare with + + Returns + ------- + bool + True if both clocks are at the same time + """ + t1 = self.variables["t"].get_value().item() + t2 = other.variables["t"].get_value().item() + + if isinstance(other, Clock): + # Both are pure Clocks with dt so we take the min. + dt = min(self.dt_, other.dt_) + return abs(t1 - t2) / dt < self.epsilon_dt + else: + return abs(t1 - t2) / self.dt_ < self.epsilon_dt + + def __le__(self, other): + return self.__lt__(other) or self.same_time(other) + + def __ge__(self, other): + return self.__gt__(other) or self.same_time(other) class DefaultClockProxy: diff --git a/brian2/core/network.py b/brian2/core/network.py index 34df6345b..8f290b1d8 100644 --- a/brian2/core/network.py +++ b/brian2/core/network.py @@ -17,7 +17,7 @@ from collections.abc import Mapping, Sequence from brian2.core.base import BrianObject, BrianObjectException -from brian2.core.clocks import Clock, defaultclock +from brian2.core.clocks import defaultclock from brian2.core.names import Nameable from brian2.core.namespace import get_local_namespace from brian2.core.preferences import BrianPreference, prefs @@ -1019,7 +1019,7 @@ def before_run(self, run_namespace): "creating a new one." ) - clocknames = ", ".join(f"{obj.name} (dt={obj.dt})" for obj in self._clocks) + clocknames = ", ".join(f"{obj.name}" for obj in self._clocks) logger.debug( f"Network '{self.name}' uses {len(self._clocks)} clocks: {clocknames}", "before_run", @@ -1035,19 +1035,11 @@ def after_run(self): obj.after_run() def _nextclocks(self): - clocks_times_dt = [ - (c, self._clock_variables[c][1][0], self._clock_variables[c][2][0]) - for c in self._clocks - ] - minclock, min_time, minclock_dt = min(clocks_times_dt, key=lambda k: k[1]) - curclocks = { - clock - for clock, time, dt in clocks_times_dt - if ( - time == min_time - or abs(time - min_time) / min(minclock_dt, dt) < Clock.epsilon_dt - ) - } + + minclock = min(self._clocks, key=lambda c: c.variables["t"].get_value().item()) + + curclocks = {clock for clock in self._clocks if clock.same_time(minclock)} + return minclock, curclocks @device_override("network_run") @@ -1136,7 +1128,6 @@ def run( c: ( c.variables["timestep"].get_value(), c.variables["t"].get_value(), - c.variables["dt"].get_value(), ) for c in self._clocks } @@ -1183,15 +1174,15 @@ def run( profiling_info = defaultdict(float) if single_clock: - timestep, t, dt = ( + timestep, t = ( clock.variables["timestep"].get_value(), clock.variables["t"].get_value(), - clock.variables["dt"].get_value(), ) + else: # Find the first clock to be updated (see note below) clock, curclocks = self._nextclocks() - timestep, _, _ = self._clock_variables[clock] + timestep, t = self._clock_variables[clock] running = timestep[0] < clock._i_end @@ -1199,7 +1190,7 @@ def run( while running and not self._stopped and not Network._globally_stopped: if not single_clock: - timestep, t, dt = self._clock_variables[clock] + timestep, t = self._clock_variables[clock] # update the network time to this clock's time self.t_ = t[0] if report is not None: @@ -1224,8 +1215,8 @@ def run( for obj in active_objects: obj.run() - timestep[0] += 1 - t[0] = timestep[0] * dt[0] + clock.advance() + else: if profile: for obj in active_objects: @@ -1239,15 +1230,13 @@ def run( obj.run() for c in curclocks: - timestep, t, dt = self._clock_variables[c] - timestep[0] += 1 - t[0] = timestep[0] * dt[0] + c.advance() # find the next clocks to be updated. The < operator for Clock # determines that the first clock to be updated should be the one # with the smallest t value, unless there are several with the # same t value in which case we update all of them clock, curclocks = self._nextclocks() - timestep, _, _ = self._clock_variables[clock] + timestep, t = self._clock_variables[clock] if ( device._maximum_run_time is not None diff --git a/brian2/devices/cpp_standalone/brianlib/clocks.h b/brian2/devices/cpp_standalone/brianlib/clocks.h index d559bfb35..1898bf2f5 100644 --- a/brian2/devices/cpp_standalone/brianlib/clocks.h +++ b/brian2/devices/cpp_standalone/brianlib/clocks.h @@ -2,6 +2,7 @@ #define _BRIAN_CLOCKS_H #include #include +#include #include #include @@ -12,20 +13,29 @@ namespace { }; }; -class Clock +class BaseClock { public: - double epsilon; - double *dt; int64_t *timestep; double *t; + virtual void tick() = 0; + virtual void set_interval(double start, double end) = 0; + inline bool running() { return timestep[0]=1700) objects.push_back(std::make_pair(std::move(clock), std::move(func))); @@ -46,7 +46,7 @@ void Network::run(const double duration, void (*report_func)(const double, const compute_clocks(); // set interval for all clocks - for(std::set::iterator i=clocks.begin(); i!=clocks.end(); i++) + for(std::set::iterator i=clocks.begin(); i!=clocks.end(); i++) (*i)->set_interval(t, t_end); {% if openmp_pragma('with_openmp') %} @@ -59,7 +59,7 @@ void Network::run(const double duration, void (*report_func)(const double, const report_func(0.0, 0.0, t_start, duration); } - Clock* clock = next_clocks(); + BaseClock* clock = next_clocks(); double elapsed_realtime; bool did_break_early = false; @@ -83,7 +83,7 @@ void Network::run(const double duration, void (*report_func)(const double, const next_report_time += report_period; } } - Clock *obj_clock = objects[i].first; + BaseClock *obj_clock = objects[i].first; // Only execute the object if it uses the right clock for this step if (curclocks.find(obj_clock) != curclocks.end()) { @@ -92,7 +92,7 @@ void Network::run(const double duration, void (*report_func)(const double, const func(); } } - for(std::set::iterator i=curclocks.begin(); i!=curclocks.end(); i++) + for(std::set::iterator i=curclocks.begin(); i!=curclocks.end(); i++) (*i)->tick(); clock = next_clocks(); @@ -133,21 +133,21 @@ void Network::compute_clocks() clocks.clear(); for(int i=0; i::iterator i=clocks.begin(); i!=clocks.end(); i++) + for(std::set::iterator i=clocks.begin(); i!=clocks.end(); i++) { - Clock *clock = *i; + BaseClock *clock = *i; if(clock->t[0]t[0]) minclock = clock; } @@ -155,9 +155,9 @@ Clock* Network::next_clocks() curclocks.clear(); double t = minclock->t[0]; - for(std::set::iterator i=clocks.begin(); i!=clocks.end(); i++) + for(std::set::iterator i=clocks.begin(); i!=clocks.end(); i++) { - Clock *clock = *i; + BaseClock *clock = *i; double s = clock->t[0]; if(s==t || fabs(s-t)<=Clock_epsilon) curclocks.insert(clock); @@ -181,18 +181,18 @@ typedef void (*codeobj_func)(); class Network { - std::set clocks, curclocks; + std::set clocks, curclocks; void compute_clocks(); - Clock* next_clocks(); + BaseClock* next_clocks(); public: - std::vector< std::pair< Clock*, codeobj_func > > objects; + std::vector< std::pair< BaseClock*, codeobj_func > > objects; double t; static double _last_run_time; static double _last_run_completed_fraction; Network(); void clear(); - void add(Clock *clock, codeobj_func func); + void add(BaseClock *clock, codeobj_func func); void run(const double duration, void (*report_func)(const double, const double, const double, const double), const double report_period); }; diff --git a/brian2/devices/cpp_standalone/templates/objects.cpp b/brian2/devices/cpp_standalone/templates/objects.cpp index fab6bf3a0..f8d6a69ba 100644 --- a/brian2/devices/cpp_standalone/templates/objects.cpp +++ b/brian2/devices/cpp_standalone/templates/objects.cpp @@ -177,8 +177,13 @@ SynapticPathway {{path.name}}( {% endfor %} //////////////// clocks /////////////////// +// attributes will be set in run.cpp {% for clock in clocks | sort(attribute='name') %} -Clock {{clock.name}}; // attributes will be set in run.cpp +{% if clock.__class__.__name__ == "EventClock" %} +EventClock {{clock.name}}; +{% else %} +Clock {{clock.name}}; +{% endif %} {% endfor %} {% if profiled_codeobjects is defined %} @@ -438,7 +443,11 @@ extern std::vector< RandomGenerator > _random_generators; //////////////// clocks /////////////////// {% for clock in clocks | sort(attribute='name') %} +{% if clock.__class__.__name__ == "EventClock" %} +extern EventClock {{clock.name}}; +{% else %} extern Clock {{clock.name}}; +{% endif %} {% endfor %} //////////////// networks ///////////////// diff --git a/brian2/devices/cpp_standalone/templates/run.cpp b/brian2/devices/cpp_standalone/templates/run.cpp index 798493dd0..b52dfcba2 100644 --- a/brian2/devices/cpp_standalone/templates/run.cpp +++ b/brian2/devices/cpp_standalone/templates/run.cpp @@ -18,9 +18,14 @@ void brian_start() _load_arrays(); // Initialize clocks (link timestep and dt to the respective arrays) {% for clock in clocks | sort(attribute='name') %} - brian::{{clock.name}}.timestep = brian::{{array_specs[clock.variables['timestep']]}}; - brian::{{clock.name}}.dt = brian::{{array_specs[clock.variables['dt']]}}; + brian::{{clock.name}}.timestep = brian::{{array_specs[clock.variables['timestep']]}}; brian::{{clock.name}}.t = brian::{{array_specs[clock.variables['t']]}}; + {% if clock.__class__.__name__ == "EventClock" %} {# FIXME: A bit ugly... #} + brian::{{clock.name}}.times = brian::{{array_specs[clock.variables['times']]}}; + brian::{{clock.name}}.n_times = {{clock.variables['times'].size}}; + {% else %} + brian::{{clock.name}}.dt = brian::{{array_specs[clock.variables['dt']]}}; + {% endif %} {% endfor %} } diff --git a/brian2/groups/group.py b/brian2/groups/group.py index 4e13b1a9e..c364f4a88 100644 --- a/brian2/groups/group.py +++ b/brian2/groups/group.py @@ -1169,12 +1169,60 @@ def run_regularly( obj : `CodeRunner` A reference to the object that will be run. """ + from brian2.core.clocks import Clock # Avoid circular import + if name is None: names = [o.name for o in self.contained_objects] name = find_name(f"{self.name}_run_regularly*", names) if dt is None and clock is None: clock = self._clock + elif clock is None: + clock = Clock(dt=dt) + + return self.run_on_clock(code, clock, when, order, name, codeobj_class) + + def run_on_clock( + self, + code, + clock=None, + when="start", + order=0, + name=None, + codeobj_class=None, + ): + """ + This method is used by `run_regularly` and `run_at` to register operations + that are executed at times determined by a user-supplied or internally generated `Clock`. + The resulting `CodeRunner` is automatically added to the group. + + Parameters + ---------- + code : str + The abstract code to run. + clock : `Clock`, optional + The update clock to use for this operation. If neither a clock nor + the `dt` argument is specified, defaults to the clock of the group. + when : str, optional + When to run within a time step, defaults to the ``'start'`` slot. + See :ref:`scheduling` for possible values. + name : str, optional + A unique name, if non is given the name of the group appended with + 'run_custom_clock', 'run_custom_clock_1', etc. will be used. If a + name is given explicitly, it will be used as given (i.e. the group + name will not be prepended automatically). + codeobj_class : class, optional + The `CodeObject` class to run code with. If not specified, defaults + to the `group`'s ``codeobj_class`` attribute. + + Returns + ------- + obj : `CodeRunner` + A reference to the object that will be run. + """ + if name is None: + names = [o.name for o in self.contained_objects] + name = find_name(f"{self.name}_run_custom*", names) # Subgroups are normally not included in their parent's # contained_objects list, since there's no need to include them in the @@ -1195,7 +1243,6 @@ def run_regularly( "stateupdate", code=code, name=name, - dt=dt, clock=clock, when=when, order=order, @@ -1204,6 +1251,53 @@ def run_regularly( self.contained_objects.append(runner) return runner + def run_at( + self, + code, + times, + when="start", + order=0, + name=None, + codeobj_class=None, + ): + """ + Run abstract code in the group's namespace. The created `CodeRunner` + object will be automatically added to the group, it therefore does not + need to be added to the network manually. However, a reference to the + object will be returned, which can be used to later remove it from the + group or to set it to inactive. + + Parameters + ---------- + code : str + The abstract code to run. + times : array-like + The specific simulation times at which to execute the code. + when : str, optional + When to run within a time step, defaults to the ``'start'`` slot. + See :ref:`scheduling` for possible values. + name : str, optional + A unique name, if non is given the name of the group appended with + 'run_at', 'run_at_1', etc. will be used. If a + name is given explicitly, it will be used as given (i.e. the group + name will not be prepended automatically). + codeobj_class : class, optional + The `CodeObject` class to run code with. If not specified, defaults + to the `group`'s ``codeobj_class`` attribute. + + Returns + ------- + obj : `CodeRunner` + A reference to the object that will be run. + """ + if name is None: + names = [o.name for o in self.contained_objects] + name = find_name(f"{self.name}_run_at*", names) + from brian2.core.clocks import EventClock + + clock = EventClock(times) + return self.run_on_clock(code, clock, when, order, name, codeobj_class) + def _check_for_invalid_states(self): """ Checks if any state variables updated by differential equations have diff --git a/brian2/tests/test_clocks.py b/brian2/tests/test_clocks.py index ef9249dd5..b2c0675a7 100644 --- a/brian2/tests/test_clocks.py +++ b/brian2/tests/test_clocks.py @@ -2,6 +2,9 @@ from numpy.testing import assert_array_equal, assert_equal from brian2 import * +from brian2.core.clocks import EventClock +from brian2.tests.test_network import NameLister +from brian2.units.fundamentalunits import DimensionMismatchError from brian2.utils.logger import catch_logs @@ -57,6 +60,71 @@ def test_set_interval_warning(): assert logs[0][1].endswith("many_timesteps") +@pytest.mark.codegen_independent +def test_event_clock(): + times = [0.0 * ms, 0.3 * ms, 0.5 * ms, 0.6 * ms] + event_clock = EventClock(times) + + assert_equal(event_clock.variables["t"].get_value(), 0.0 * ms) + assert_equal(event_clock[1], 0.3 * ms) + + event_clock.advance() + assert_equal(event_clock.variables["timestep"].get_value(), 1) + assert_equal(event_clock.variables["t"].get_value(), 0.0003) + + event_clock.set_interval(0.3 * ms, 0.6 * ms) + assert_equal(event_clock.variables["timestep"].get_value(), 1) + assert_equal(event_clock.variables["t"].get_value(), 0.0003) + event_clock.advance() + event_clock.advance() + + with pytest.raises(StopIteration): + event_clock.advance() + + invalid_times = [0.0 * volt, 0.5 * volt] + with pytest.raises(DimensionMismatchError) as excinfo: + EventClock(invalid_times) + + +@pytest.mark.codegen_independent +def test_combined_clocks_with_run_at(): + + # Reset updates + NameLister.updates[:] = [] + + # Regular NameLister at 1ms interval + regular_lister = NameLister(name="x", dt=1 * ms, order=0) + + # Event NameLister at specific times + event_times = [0.5 * ms, 2.5 * ms, 4 * ms] + event_lister = NameLister(name="y", clock=EventClock(times=event_times), order=1) + + # Create and run the network + net = Network(regular_lister, event_lister) + net.run(5 * ms) + + # Get update string + updates = "".join(NameLister.updates) + + # Expected output: "x" at 0,1,2,3,4ms = 5 times + # "y" at 0.5, 2.5, 4.0ms = 3 times + expected_x_count = 5 + expected_y_count = 3 + + x_count = updates.count("x") + y_count = updates.count("y") + + assert ( + x_count == expected_x_count + ), f"Expected {expected_x_count} x's, got {x_count}" + assert ( + y_count == expected_y_count + ), f"Expected {expected_y_count} y's, got {y_count}" + + expected_output = "xyxxyxxy" + assert updates == expected_output, f"Expected {expected_output}, got {updates}" + + if __name__ == "__main__": test_clock_attributes() restore_initial_state() @@ -64,3 +132,5 @@ def test_set_interval_warning(): restore_initial_state() test_defaultclock() test_set_interval_warning() + test_event_clock() + test_combined_clocks_with_run_at() diff --git a/brian2/tests/test_network.py b/brian2/tests/test_network.py index 9b623fd36..5b9385ff3 100644 --- a/brian2/tests/test_network.py +++ b/brian2/tests/test_network.py @@ -1676,12 +1676,14 @@ def test_small_runs(): # One long run and multiple small runs should give the same results group_1 = NeuronGroup(10, "dv/dt = -v / (10*ms) : 1") group_1.v = "(i + 1) / N" + group_1.run_at("v += 0.1", times=[100 * ms, 300 * ms]) mon_1 = StateMonitor(group_1, "v", record=True) net_1 = Network(group_1, mon_1) net_1.run(1 * second) group_2 = NeuronGroup(10, "dv/dt = -v / (10*ms) : 1") group_2.v = "(i + 1) / N" + group_2.run_at("v += 0.1", times=[100 * ms, 300 * ms]) mon_2 = StateMonitor(group_2, "v", record=True) net_2 = Network(group_2, mon_2) runtime = 1 * ms diff --git a/brian2/tests/test_neurongroup.py b/brian2/tests/test_neurongroup.py index 6796a14d3..166bda517 100644 --- a/brian2/tests/test_neurongroup.py +++ b/brian2/tests/test_neurongroup.py @@ -3,7 +3,7 @@ import numpy as np import pytest import sympy -from numpy.testing import assert_equal, assert_raises +from numpy.testing import assert_equal from brian2.core.base import BrianObjectException from brian2.core.clocks import defaultclock @@ -88,7 +88,7 @@ def test_variables(): assert "not_refractory" not in G.variables and "lastspike" not in G.variables G = NeuronGroup(1, "dv/dt = -v/tau + xi*tau**-0.5: 1") - assert not "tau" in G.variables and "xi" in G.variables + assert "tau" not in G.variables and "xi" in G.variables # NeuronGroup with refractoriness G = NeuronGroup(1, "dv/dt = -v/(10*ms) : 1", refractory=5 * ms) @@ -2181,6 +2181,16 @@ def test_run_regularly_dt(): assert_allclose(np.diff(M.v[0]), np.tile([0, 1], 5)[:-1]) +@pytest.mark.standalone_compatible +def test_run_at(): + G = NeuronGroup(1, "v : 1") + G.run_at("v += 1", times=[0, 1, 3] * defaultclock.dt) + M = StateMonitor(G, "v", record=0, when="end") + run(4 * defaultclock.dt) + assert_allclose(G.v[:], 3) + assert_allclose(M.v[0], [1, 2, 2, 3]) + + @pytest.mark.standalone_compatible def test_run_regularly_shared(): # Check that shared variables are handled correctly in run_regularly diff --git a/examples/standalone/simple_case.py b/examples/standalone/simple_case.py index 8b5f4e02e..45691577c 100644 --- a/examples/standalone/simple_case.py +++ b/examples/standalone/simple_case.py @@ -2,18 +2,21 @@ """ The most simple case how to use standalone mode. """ + from brian2 import * -set_device('cpp_standalone') # ← only difference to "normal" simulation -tau = 10*ms -eqs = ''' +set_device("cpp_standalone") # ← only difference to "normal" simulation + +tau = 10 * ms +eqs = """ dv/dt = (1-v)/tau : 1 -''' -G = NeuronGroup(10, eqs, method='exact') -G.v = 'rand()' -mon = StateMonitor(G, 'v', record=True) -run(100*ms) +""" +G = NeuronGroup(10, eqs, method="exact") +G.v = "rand()" +G.run_at("v += 5", times=[5, 50]*ms) +mon = StateMonitor(G, "v", record=True) +run(100 * ms) -plt.plot(mon.t/ms, mon.v.T) -plt.gca().set(xlabel='t (ms)', ylabel='v') +plt.plot(mon.t / ms, mon.v.T) +plt.gca().set(xlabel="t (ms)", ylabel="v") plt.show()