Skip to content

Commit

Permalink
update reset_state() function in DynamicalSystem
Browse files Browse the repository at this point in the history
  • Loading branch information
chaoming0625 committed Aug 24, 2023
1 parent 5f7461b commit a713a52
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 19 deletions.
11 changes: 11 additions & 0 deletions brainpy/_src/dynold/synapses/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,10 +341,21 @@ def g_max(self, v):
self.comm.weight = v

def reset_state(self, *args, **kwargs):
self.syn.reset_bef_updates(*args, **kwargs)
self.syn.reset_state(*args, **kwargs)
self.syn.reset_aft_updates(*args, **kwargs)

self.comm.reset_bef_updates(*args, **kwargs)
self.comm.reset_state(*args, **kwargs)
self.comm.reset_aft_updates(*args, **kwargs)

self.output.reset_bef_updates(*args, **kwargs)
self.output.reset_state(*args, **kwargs)
self.output.reset_aft_updates(*args, **kwargs)

if self.stp is not None:
self.stp.reset_bef_updates(*args, **kwargs)
self.stp.reset_state(*args, **kwargs)
self.stp.reset_aft_updates(*args, **kwargs)


69 changes: 50 additions & 19 deletions brainpy/_src/dynsys.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,20 +187,32 @@ def update(self, *args, **kwargs):
raise NotImplementedError('Must implement "update" function by subclass self.')

def reset(self, *args, **kwargs):
"""Reset function which resets the whole variables in the model.
"""Reset function which reset the whole variables in the model.
"""
child_nodes = self.nodes(level=-1, include_self=True).subset(DynamicalSystem).unique()
for node in child_nodes.values():
node.reset_state(*args, **kwargs)
self.reset_bef_updates(*args, **kwargs)
self.reset_state(*args, **kwargs)
self.reset_aft_updates(*args, **kwargs)

def reset_state(self, *args, **kwargs):
"""Reset function which reset the states in the model.
The main interface for resetting the states of the model.
"""
pass
child_nodes = self.nodes(level=1, include_self=False).subset(DynamicalSystem).unique()
if len(child_nodes) > 0:
for node in child_nodes.values():
node.reset_bef_updates(*args, **kwargs)
node.reset_state(*args, **kwargs)
node.reset_aft_updates(*args, **kwargs)
self.reset_local_delays(child_nodes)
else:
raise NotImplementedError(f'Must implement "reset_state" function by subclass self. Error of {self.name}')

def clear_input(self):
def clear_input(self, *args, **kwargs):
"""Clear the input at the current time step."""
pass
nodes = self.nodes(level=1, include_self=False).subset(DynamicalSystem).unique().not_subset(DynView)
for node in nodes.values():
node.clear_input()

def step_run(self, i, *args, **kwargs):
"""The step run function.
Expand Down Expand Up @@ -393,6 +405,8 @@ def __rrshift__(self, other):
return self.__call__(other)




class DynSysGroup(DynamicalSystem, Container):
"""A group of :py:class:`~.DynamicalSystem`s in which the updating order does not matter.
Expand Down Expand Up @@ -441,35 +455,35 @@ def update(self, *args, **kwargs):
# TODO: Will be deprecated in the future
self.update_local_delays(nodes)

def reset_state(self, batch_size=None):
def reset_state(self, batch_or_mode=None):
nodes = self.nodes(level=1, include_self=False).subset(DynamicalSystem).unique().not_subset(DynView)

# reset projections
for node in nodes.subset(Projection).values():
node.reset_state(batch_size)
node.reset_bef_updates(batch_or_mode)
node.reset_state(batch_or_mode)
node.reset_aft_updates(batch_or_mode)

# reset dynamics
for node in nodes.subset(Dynamic).values():
node.reset_state(batch_size)
node.reset_bef_updates(batch_or_mode)
node.reset_state(batch_or_mode)
node.reset_aft_updates(batch_or_mode)

# reset other types of nodes, including delays, ...
for node in nodes.not_subset(Dynamic).not_subset(Projection).values():
node.reset_state(batch_size)
node.reset_bef_updates(batch_or_mode)
node.reset_state(batch_or_mode)
node.reset_aft_updates(batch_or_mode)

# reset
self.reset_aft_updates(batch_size)
self.reset_bef_updates(batch_size)
self.reset_aft_updates(batch_or_mode)
self.reset_bef_updates(batch_or_mode)

# reset delays
# TODO: will be removed in the future
self.reset_local_delays(nodes)

def clear_input(self):
"""Clear inputs in the children classes."""
nodes = self.nodes(level=1, include_self=False).subset(DynamicalSystem).unique().not_subset(DynView)
for node in nodes.values():
node.clear_input()


class Network(DynSysGroup):
"""A group of :py:class:`~.DynamicalSystem`s which defines the nodes and edges in a network.
Expand Down Expand Up @@ -579,7 +593,9 @@ def reset_state(self, *args, **kwargs):
nodes = tuple(self.nodes(level=1, include_self=False).subset(DynamicalSystem).unique().values())
if len(nodes):
for node in nodes:
node.reset_bef_updates(*args, **kwargs)
node.reset_state(*args, **kwargs)
node.reset_aft_updates(*args, **kwargs)
else:
raise ValueError('Do not implement the reset_state() function.')

Expand All @@ -591,6 +607,14 @@ def update(self, *args, **kwargs):
else:
raise ValueError('Do not implement the update() function.')

def clear_input(self, *args, **kwargs):
"""Empty function of clearing inputs."""
pass

def reset_state(self, *args, **kwargs):
raise NotImplementedError(f'Must implement "reset_state" function by subclass self. Error of {self.name}')


class Dynamic(DynamicalSystem):
"""Base class to model dynamics.
Expand Down Expand Up @@ -701,6 +725,13 @@ def __repr__(self):
def __getitem__(self, item):
return DynView(target=self, index=item)

def clear_input(self, *args, **kwargs):
"""Empty function of clearing inputs."""
pass

def reset_state(self, *args, **kwargs):
raise NotImplementedError(f'Must implement "reset_state" function by subclass self. Error of {self.name}')


class DynView(Dynamic):
"""DSView, an object used to get a view of a dynamical system instance.
Expand Down

0 comments on commit a713a52

Please sign in to comment.