Skip to content

Fix issue run at #1602

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 23 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
128 changes: 106 additions & 22 deletions brian2/core/clocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,98 @@ def check_dt(new_dt, old_dt, target_t):
f"time {t} is not a multiple of {new}."
)

class ClockArray:
def __init__(self, clock):
self.clock = clock

def __getitem__(self, timestep):
return self.clock.dt * timestep

class EventClock(VariableOwner):
def __init__(self, times, name="eventclock*"):
Nameable.__init__(self, name=name)
self.variables = Variables(self)
self.times = times
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should probably sort these times (given that we assume them to be sorted later). Also, maybe include a check that it does not contain the same time twice?

self.variables.add_array(
"timestep", size=1, dtype=np.int64, read_only=True, scalar=True
)
self.variables.add_array(
"t",
dimensions=second.dim,
size=1,
dtype=np.float64,
read_only=True,
scalar=True,
)
self.variables["timestep"].set_value(0)
self.variables["t"].set_value(self.times[0])

self.variables.add_constant("N", value=1)

self._enable_group_attributes()

self._i_end = None
logger.diagnostic(f"Created clock {self.name}")

def advance(self):
"""
Advance the clock to the next timestep.
"""
current_timestep = self.variables["timestep"].get_value().item()
next_timestep = current_timestep + 1
if self._i_end is not None and next_timestep > self._i_end:
raise StopIteration("Clock has reached the end of its available times.")
else:
self.variables["timestep"].set_value(next_timestep)
self.variables["t"].set_value(self.times[next_timestep])

@check_units(start=second, end=second)
def set_interval(self, start, end):
"""
Set the start and end time of the simulation.
"""

if not isinstance(self.times, ClockArray):

class Clock(VariableOwner):
start_idx = np.searchsorted(self.times, float(start))
end_idx = np.searchsorted(self.times, float(end))

self.variables["timestep"].set_value(start_idx)
self.variables["t"].set_value(self.times[start_idx])
self._i_end = end_idx - 1
else:

pass

def __lt__(self, other):
return self.variables["t"].get_value().item() < other.variables["t"].get_value().item()

def __eq__(self, other):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

While I get the appeal of using __eq__ here, I don't think it is the right approach – also, it breaks everything :) The reason is that we don't want to consider two Clocks to be equal just because their time is the same. E.g. there are places where we check whether two objects use the same clock, and those would always return True at the start of a simulation when those clocks are at 0ms. Also, we have places where we put clocks into a set, and Python objects that have __eq__ but not __hash__ are not hashable (and the hash of an object is not allowed to change and needs to be consistent with __eq__ so we cannot simply add a __hash__ function. I think the best approach would be to add a specific function to compare them, say same_time(self, other)

t1 = self.variables["t"].get_value().item()
t2 = other.variables["t"].get_value().item()

if hasattr(self, 'dt'):
dt = self.variables["dt"].get_value().item()
return abs(t1 - t2) / dt < self.epsilon_dt
elif hasattr(other, 'dt'):
dt = other.variables["dt"].get_value().item()
return abs(t1 - t2) / dt < self.epsilon_dt
else:
# Both are pure EventClocks without dt
epsilon = 1e-10
return abs(t1 - t2) < epsilon

def __le__(self, other):
return self.__lt__(other) or self.__eq__(other)

def __gt__(self, other):
return not self.__le__(other)

def __ge__(self, other):
return not self.__lt__(other)


class RegularClock(EventClock):
"""
An object that holds the simulation time and the time step.

Expand All @@ -82,23 +172,13 @@ class Clock(VariableOwner):
point values. The value of ``epsilon`` is ``1e-14``.
"""

def __init__(self, dt, name="clock*"):
def __init__(self, dt, name="regularclock*"):
# We need a name right away because some devices (e.g. cpp_standalone)
# need a name for the object when creating the variables
Nameable.__init__(self, name=name)
self._old_dt = None
self.variables = Variables(self)
self.variables.add_array(
"timestep", size=1, dtype=np.int64, read_only=True, scalar=True
)
self.variables.add_array(
"t",
dimensions=second.dim,
size=1,
dtype=np.float64,
read_only=True,
scalar=True,
)
self._dt = float(dt)
self._old_dt = None
times = ClockArray(self)
super().__init__(times, name=name)
self.variables.add_array(
"dt",
dimensions=second.dim,
Expand All @@ -109,10 +189,6 @@ def __init__(self, dt, name="clock*"):
constant=True,
scalar=True,
)
self.variables.add_constant("N", value=1)
self._enable_group_attributes()
self.dt = dt
logger.diagnostic(f"Created clock {self.name} with dt={self.dt}")

@check_units(t=second)
def _set_t_update_dt(self, target_t=0 * second):
Expand All @@ -129,7 +205,10 @@ def _set_t_update_dt(self, target_t=0 * second):
# update them via the variables object directly
self.variables["timestep"].set_value(new_timestep)
self.variables["t"].set_value(new_timestep * new_dt)
logger.diagnostic(f"Setting Clock {self.name} to t={self.t}, dt={self.dt}")
# Use self.variables["t"].get_value().item() and self.variables["dt"].get_value().item() for logging
t_value = self.variables["t"].get_value().item()
dt_value = self.variables["dt"].get_value().item()
logger.diagnostic(f"Setting Clock {self.name} to t={t_value}, dt={dt_value}")

def _calc_timestep(self, target_t):
"""
Expand Down Expand Up @@ -211,6 +290,11 @@ def set_interval(self, start, end):
epsilon_dt = 1e-4


class Clock(RegularClock):
def __init__(self, dt, name="clock*"):
super().__init__(dt, name)


class DefaultClockProxy:
"""
Method proxy to access the defaultclock of the currently active device
Expand All @@ -230,4 +314,4 @@ def __setattr__(self, key, value):


#: The standard clock, used for objects that do not specify any clock or dt
defaultclock = DefaultClockProxy()
defaultclock = DefaultClockProxy()
18 changes: 7 additions & 11 deletions brian2/core/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -1035,21 +1035,17 @@ def after_run(self):
obj.after_run()

def _nextclocks(self):
clocks_times_dt = [
(c, self._clock_variables[c][1][0], self._clock_variables[c][2][0])
for c in self._clocks
]
minclock, min_time, minclock_dt = min(clocks_times_dt, key=lambda k: k[1])

minclock = min(self._clocks)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

While this (i.e. defining __lt__ and __le__) is less of an issue than the equality check, for consistency/readability we should probably also have an explicit function here, or simply use something along the lines of min(self._clocks, key=lambda c: c.t)


curclocks = {
clock
for clock, time, dt in clocks_times_dt
if (
time == min_time
or abs(time - min_time) / min(minclock_dt, dt) < Clock.epsilon_dt
)
for clock in self._clocks
if clock == minclock
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See my comment about __eq__ above

}

return minclock, curclocks

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nothing to worry about before the PR is finalized, but have a look at our documentation on how to enable pre-commit so that it will make sure that the source code is formatted consistently: https://brian2.readthedocs.io/en/stable/developer/guidelines/style.html#code-style

@device_override("network_run")
@check_units(duration=second, report_period=second)
def run(
Expand Down
Loading