Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix delay bug #650

Merged
merged 6 commits into from
Mar 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion brainpy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,10 @@
Sequential as Sequential,
Dynamic as Dynamic, # category
Projection as Projection,
receive_update_input, # decorators
receive_update_output,
not_receive_update_input,
not_receive_update_output,
)
DynamicalSystemNS = DynamicalSystem
Network = DynSysGroup
Expand All @@ -84,7 +88,6 @@
load_state as load_state,
clear_input as clear_input)


# Part: Running #
# --------------- #
from brainpy._src.runners import (DSRunner as DSRunner)
Expand Down
42 changes: 30 additions & 12 deletions brainpy/_src/delay.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,21 @@
]


delay_identifier = '_*_delay_*_'
delay_identifier = '_*_delay_of_'


def _get_delay(delay_time, delay_step):
if delay_time is None:
if delay_step is None:
return None, None
else:
assert isinstance(delay_step, int), '"delay_step" should be an integer.'
delay_time = delay_step * bm.get_dt()
else:
assert delay_step is None, '"delay_step" should be None if "delay_time" is given.'
assert isinstance(delay_time, (int, float))
delay_step = math.ceil(delay_time / bm.get_dt())
return delay_time, delay_step


class Delay(DynamicalSystem, ParamDesc):
Expand Down Expand Up @@ -97,13 +111,15 @@ def __init__(
def register_entry(
self,
entry: str,
delay_time: Optional[Union[float, bm.Array, Callable]],
delay_time: Optional[Union[float, bm.Array, Callable]] = None,
delay_step: Optional[int] = None
) -> 'Delay':
"""Register an entry to access the data.

Args:
entry: str. The entry to access the delay data.
delay_time: The delay time of the entry (can be a float).
delay_step: The delay step of the entry (must be an int). ``delay_step = delay_time / dt``.

Returns:
Return the self.
Expand Down Expand Up @@ -237,13 +253,15 @@ def __init__(
def register_entry(
self,
entry: str,
delay_time: Optional[Union[int, float]],
delay_time: Optional[Union[int, float]] = None,
delay_step: Optional[int] = None,
) -> 'Delay':
"""Register an entry to access the data.

Args:
entry: str. The entry to access the delay data.
delay_time: The delay time of the entry (can be a float).
delay_step: The delay step of the entry (must be an int). ``delat_step = delay_time / dt``.

Returns:
Return the self.
Expand All @@ -258,12 +276,7 @@ def register_entry(
assert delay_time.size == 1 and delay_time.ndim == 0
delay_time = delay_time.item()

if delay_time is None:
delay_step = None
delay_time = 0.
else:
assert isinstance(delay_time, (int, float))
delay_step = math.ceil(delay_time / bm.get_dt())
_, delay_step = _get_delay(delay_time, delay_step)

# delay variable
if delay_step is not None:
Expand Down Expand Up @@ -354,24 +367,29 @@ def update(
"""Update delay variable with the new data.
"""
if self.data is not None:
# jax.debug.print('last value == target value {} ', jnp.allclose(latest_value, self.target.value))

# get the latest target value
if latest_value is None:
latest_value = self.target.value

# update the delay data at the rotation index
if self.method == ROTATE_UPDATE:
i = share.load('i')
idx = bm.as_jax((-i - 1) % self.max_length, dtype=jnp.int32)
self.data[idx] = latest_value
idx = bm.as_jax(-i % self.max_length, dtype=jnp.int32)
self.data[jax.lax.stop_gradient(idx)] = latest_value

# update the delay data at the first position
elif self.method == CONCAT_UPDATE:
if self.max_length > 1:
latest_value = bm.expand_dims(latest_value, 0)
self.data.value = bm.concat([latest_value, self.data[1:]], axis=0)
self.data.value = bm.concat([latest_value, self.data[:-1]], axis=0)
else:
self.data[0] = latest_value

else:
raise ValueError(f'Unknown updating method "{self.method}"')

def reset_state(self, batch_size: int = None, **kwargs):
"""Reset the delay data.
"""
Expand Down
25 changes: 9 additions & 16 deletions brainpy/_src/dynold/synapses/abstract_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,12 +115,7 @@ def __init__(
self.g_max, self.conn_mask = self._init_weights(g_max, comp_method=comp_method, sparse_data='csr')

# register delay
self.pre.register_local_delay("spike", self.name, delay_step)

def reset_state(self, batch_size=None):
self.output.reset_state(batch_size)
if self.stp is not None:
self.stp.reset_state(batch_size)
self.pre.register_local_delay("spike", self.name, delay_step=delay_step)

def update(self, pre_spike=None):
# pre-synaptic spikes
Expand Down Expand Up @@ -232,7 +227,6 @@ class Exponential(TwoEndConn):
method: str
The numerical integration methods.


"""

def __init__(
Expand Down Expand Up @@ -283,17 +277,16 @@ def __init__(
else:
raise ValueError(f'Does not support {comp_method}, only "sparse" or "dense".')

# variables
self.g = self.syn.g

# delay
self.pre.register_local_delay("spike", self.name, delay_step)
self.pre.register_local_delay("spike", self.name, delay_step=delay_step)

def reset_state(self, batch_size=None):
self.syn.reset_state(batch_size)
self.output.reset_state(batch_size)
if self.stp is not None:
self.stp.reset_state(batch_size)
@property
def g(self):
return self.syn.g

@g.setter
def g(self, value):
self.syn.g = value

def update(self, pre_spike=None):
# delays
Expand Down
14 changes: 2 additions & 12 deletions brainpy/_src/dynold/synapses/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,7 @@
from brainpy._src.dyn.base import NeuDyn
from brainpy._src.dynsys import DynamicalSystem
from brainpy._src.initialize import parameter
from brainpy._src.mixin import (ParamDesc, JointType,
SupportAutoDelay, BindCondData, ReturnInfo)
from brainpy._src.mixin import (ParamDesc, JointType, SupportAutoDelay, BindCondData, ReturnInfo)
from brainpy.errors import UnsupportedError
from brainpy.types import ArrayType

Expand Down Expand Up @@ -47,9 +46,6 @@ def isregistered(self, val: bool):
raise ValueError('Must be an instance of bool.')
self._registered = val

def reset_state(self, batch_size=None):
pass

def register_master(self, master: SynConn):
if not isinstance(master, SynConn):
raise TypeError(f'master must be instance of {SynConn.__name__}, but we got {type(master)}')
Expand Down Expand Up @@ -296,7 +292,7 @@ def __init__(
mode=mode)

# delay
self.pre.register_local_delay("spike", self.name, delay_step)
self.pre.register_local_delay("spike", self.name, delay_step=delay_step)

# synaptic dynamics
self.syn = syn
Expand Down Expand Up @@ -340,11 +336,5 @@ def g_max(self, v):
UserWarning)
self.comm.weight = v

def reset_state(self, *args, **kwargs):
self.syn.reset(*args, **kwargs)
self.comm.reset(*args, **kwargs)
self.output.reset(*args, **kwargs)
if self.stp is not None:
self.stp.reset(*args, **kwargs)


129 changes: 119 additions & 10 deletions brainpy/_src/dynsys.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,17 +93,41 @@ def __init__(

# Attribute for "SupportInputProj"
# each instance of "SupportInputProj" should have a "cur_inputs" attribute
self.current_inputs = bm.node_dict()
self.delta_inputs = bm.node_dict()
self._current_inputs: Optional[Dict[str, Callable]] = None
self._delta_inputs: Optional[Dict[str, Callable]] = None

# the before- / after-updates used for computing
# added after the version of 2.4.3
self.before_updates: Dict[str, Callable] = bm.node_dict()
self.after_updates: Dict[str, Callable] = bm.node_dict()
self._before_updates: Optional[Dict[str, Callable]] = None
self._after_updates: Optional[Dict[str, Callable]] = None

# super initialization
super().__init__(name=name)

@property
def current_inputs(self):
if self._current_inputs is None:
self._current_inputs = bm.node_dict()
return self._current_inputs

@property
def delta_inputs(self):
if self._delta_inputs is None:
self._delta_inputs = bm.node_dict()
return self._delta_inputs

@property
def before_updates(self):
if self._before_updates is None:
self._before_updates = bm.node_dict()
return self._before_updates

@property
def after_updates(self):
if self._after_updates is None:
self._after_updates = bm.node_dict()
return self._after_updates

def add_bef_update(self, key: Any, fun: Callable):
"""Add the before update into this node"""
if key in self.before_updates:
Expand Down Expand Up @@ -220,25 +244,32 @@ def register_local_delay(
self,
var_name: str,
delay_name: str,
delay: Union[numbers.Number, ArrayType] = None,
delay_time: Union[numbers.Number, ArrayType] = None,
delay_step: Union[numbers.Number, ArrayType] = None,
):
"""Register local relay at the given delay time.

Args:
var_name: str. The name of the delay target variable.
delay_name: str. The name of the current delay data.
delay: The delay time.
delay_time: The delay time. Float.
delay_step: The delay step. Int. ``delay_step`` and ``delay_time`` are exclusive. ``delay_step = delay_time / dt``.
"""
delay_identifier, init_delay_by_return = _get_delay_tool()
delay_identifier = delay_identifier + var_name
# check whether the "var_name" has been registered
try:
target = getattr(self, var_name)
except AttributeError:
raise AttributeError(f'This node {self} does not has attribute of "{var_name}".')
if not self.has_aft_update(delay_identifier):
self.add_aft_update(delay_identifier, init_delay_by_return(target))
# add a model to receive the return of the target model
# moreover, the model should not receive the return of the update function
model = not_receive_update_output(init_delay_by_return(target))
# register the model
self.add_aft_update(delay_identifier, model)
delay_cls = self.get_aft_update(delay_identifier)
delay_cls.register_entry(delay_name, delay)
delay_cls.register_entry(delay_name, delay_time=delay_time, delay_step=delay_step)

def get_local_delay(self, var_name, delay_name):
"""Get the delay at the given identifier (`name`).
Expand Down Expand Up @@ -381,14 +412,20 @@ def __call__(self, *args, **kwargs):

# ``before_updates``
for model in self.before_updates.values():
model()
if hasattr(model, '_receive_update_input'):
model(*args, **kwargs)
else:
model()

# update the model self
ret = self.update(*args, **kwargs)

# ``after_updates``
for model in self.after_updates.values():
model(ret)
if hasattr(model, '_not_receive_update_output'):
model()
else:
model(ret)
return ret

def __rrshift__(self, other):
Expand Down Expand Up @@ -832,3 +869,75 @@ def _slice_to_num(slice_: slice, length: int):
start += step
num += 1
return num


def receive_update_output(cls: object):
"""
The decorator to mark the object (as the after updates) to receive the output of the update function.

That is, the `aft_update` will receive the return of the update function::

ret = model.update(*args, **kwargs)
for fun in model.aft_updates:
fun(ret)

"""
# assert isinstance(cls, DynamicalSystem), 'The input class should be instance of DynamicalSystem.'
if hasattr(cls, '_not_receive_update_output'):
delattr(cls, '_not_receive_update_output')
return cls


def not_receive_update_output(cls: object):
"""
The decorator to mark the object (as the after updates) to not receive the output of the update function.

That is, the `aft_update` will not receive the return of the update function::

ret = model.update(*args, **kwargs)
for fun in model.aft_updates:
fun()

"""
# assert isinstance(cls, DynamicalSystem), 'The input class should be instance of DynamicalSystem.'
cls._not_receive_update_output = True
return cls


def receive_update_input(cls: object):
"""
The decorator to mark the object (as the before updates) to receive the input of the update function.

That is, the `bef_update` will receive the input of the update function::


for fun in model.bef_updates:
fun(*args, **kwargs)
model.update(*args, **kwargs)

"""
# assert isinstance(cls, DynamicalSystem), 'The input class should be instance of DynamicalSystem.'
cls._receive_update_input = True
return cls


def not_receive_update_input(cls: object):
"""
The decorator to mark the object (as the before updates) to not receive the input of the update function.

That is, the `bef_update` will not receive the input of the update function::

for fun in model.bef_updates:
fun()
model.update()

"""
# assert isinstance(cls, DynamicalSystem), 'The input class should be instance of DynamicalSystem.'
if hasattr(cls, '_receive_update_input'):
delattr(cls, '_receive_update_input')
return cls





Loading
Loading