Skip to content

Commit

Permalink
upgrade reset_state() function in DynamicalSystem
Browse files Browse the repository at this point in the history
  • Loading branch information
chaoming0625 committed Aug 24, 2023
1 parent a713a52 commit d818bdc
Showing 1 changed file with 24 additions and 32 deletions.
56 changes: 24 additions & 32 deletions brainpy/_src/dynsys.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
from brainpy._src.deprecations import _update_deprecate_msg
from brainpy._src.context import share


__all__ = [
# general
'DynamicalSystem',
Expand Down Expand Up @@ -188,22 +187,31 @@ def update(self, *args, **kwargs):

def reset(self, *args, **kwargs):
"""Reset function which reset the whole variables in the model.
``reset()`` function is a collective behavior which resets states in the current node,
nodes in ``before_updates``, and nodes in ``after_updates``.
"""
self.reset_bef_updates(*args, **kwargs)
self.reset_state(*args, **kwargs)
self.reset_aft_updates(*args, **kwargs)

def reset_state(self, *args, **kwargs):
"""Reset function which reset the states in the model.
"""Reset function which resets the states in the model.
The main interface for resetting the states of the model.
If the model behaves like a gather or collector, it will rest all states
(by calling ``reset()`` function) in children nodes.
If the model behaves as a single module, it requires users to implement this
rest function.
Simply speaking, this function should implement the logic of resetting of
local variables in this node.
"""
child_nodes = self.nodes(level=1, include_self=False).subset(DynamicalSystem).unique()
if len(child_nodes) > 0:
for node in child_nodes.values():
node.reset_bef_updates(*args, **kwargs)
node.reset_state(*args, **kwargs)
node.reset_aft_updates(*args, **kwargs)
node.reset(*args, **kwargs)
self.reset_local_delays(child_nodes)
else:
raise NotImplementedError(f'Must implement "reset_state" function by subclass self. Error of {self.name}')
Expand Down Expand Up @@ -405,8 +413,6 @@ def __rrshift__(self, other):
return self.__call__(other)




class DynSysGroup(DynamicalSystem, Container):
"""A group of :py:class:`~.DynamicalSystem`s in which the updating order does not matter.
Expand Down Expand Up @@ -460,25 +466,15 @@ def reset_state(self, batch_or_mode=None):

# reset projections
for node in nodes.subset(Projection).values():
node.reset_bef_updates(batch_or_mode)
node.reset_state(batch_or_mode)
node.reset_aft_updates(batch_or_mode)
node.reset(batch_or_mode)

# reset dynamics
for node in nodes.subset(Dynamic).values():
node.reset_bef_updates(batch_or_mode)
node.reset_state(batch_or_mode)
node.reset_aft_updates(batch_or_mode)
node.reset(batch_or_mode)

# reset other types of nodes, including delays, ...
for node in nodes.not_subset(Dynamic).not_subset(Projection).values():
node.reset_bef_updates(batch_or_mode)
node.reset_state(batch_or_mode)
node.reset_aft_updates(batch_or_mode)

# reset
self.reset_aft_updates(batch_or_mode)
self.reset_bef_updates(batch_or_mode)
node.reset(batch_or_mode)

# reset delays
# TODO: will be removed in the future
Expand Down Expand Up @@ -589,31 +585,27 @@ def __repr__(self):


class Projection(DynamicalSystem):
def reset_state(self, *args, **kwargs):

def update(self, *args, **kwargs):
nodes = tuple(self.nodes(level=1, include_self=False).subset(DynamicalSystem).unique().values())
if len(nodes):
for node in nodes:
node.reset_bef_updates(*args, **kwargs)
node.reset_state(*args, **kwargs)
node.reset_aft_updates(*args, **kwargs)
node.update(*args, **kwargs)
else:
raise ValueError('Do not implement the reset_state() function.')
raise ValueError('Do not implement the update() function.')

def update(self, *args, **kwargs):
def reset_state(self, *args, **kwargs):
nodes = tuple(self.nodes(level=1, include_self=False).subset(DynamicalSystem).unique().values())
if len(nodes):
for node in nodes:
node.update(*args, **kwargs)
node.reset(*args, **kwargs)
else:
raise ValueError('Do not implement the update() function.')
raise ValueError('Do not implement the reset_state() function.')

def clear_input(self, *args, **kwargs):
"""Empty function of clearing inputs."""
pass

def reset_state(self, *args, **kwargs):
raise NotImplementedError(f'Must implement "reset_state" function by subclass self. Error of {self.name}')


class Dynamic(DynamicalSystem):
"""Base class to model dynamics.
Expand Down

0 comments on commit d818bdc

Please sign in to comment.