Skip to content

Commit

Permalink
Merge pull request brainpy#523 from chaoming0625/updates
Browse files Browse the repository at this point in the history
[projection] upgrade projections so that APIs are reused across different models
  • Loading branch information
chaoming0625 authored Oct 29, 2023
2 parents e849a9a + 6e57e2b commit 06276ee
Show file tree
Hide file tree
Showing 15 changed files with 703 additions and 251 deletions.
35 changes: 31 additions & 4 deletions brainpy/_src/delay.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from brainpy._src.dynsys import DynamicalSystem
from brainpy._src.initialize import variable_
from brainpy._src.math.delayvars import ROTATE_UPDATE, CONCAT_UPDATE
from brainpy._src.mixin import ParamDesc, ReturnInfo
from brainpy._src.mixin import ParamDesc, ReturnInfo, JointType, SupportAutoDelay
from brainpy.check import jit_error


Expand Down Expand Up @@ -461,22 +461,33 @@ def __init__(
self,
delay: Delay,
time: Union[None, int, float],
*indices
*indices,
delay_entry: str = None
):
super().__init__(mode=delay.mode)
self.refs = {'delay': delay}
assert isinstance(delay, Delay)
delay.register_entry(self.name, time)
self._delay_entry = delay_entry or self.name
delay.register_entry(self._delay_entry, time)
self.indices = indices

def update(self):
return self.refs['delay'].at(self.name, *self.indices)
return self.refs['delay'].at(self._delay_entry, *self.indices)

def reset_state(self, *args, **kwargs):
pass


def init_delay_by_return(info: Union[bm.Variable, ReturnInfo], initial_delay_data=None) -> Delay:
"""Initialize a delay class by the return info (usually is created by ``.return_info()`` function).
Args:
info: the return information.
initial_delay_data: The initial delay data.
Returns:
The decay instance.
"""
if isinstance(info, bm.Variable):
return VarDelay(info, init=initial_delay_data)

Expand Down Expand Up @@ -513,3 +524,19 @@ def init_delay_by_return(info: Union[bm.Variable, ReturnInfo], initial_delay_dat
return DataDelay(target, data_init=info.data, init=initial_delay_data)
else:
raise TypeError


def register_delay_by_return(target: JointType[DynamicalSystem, SupportAutoDelay]):
"""Register delay class for the given target.
Args:
target: The target class to register delay.
Returns:
The delay registered for the given target.
"""
if not target.has_aft_update(delay_identifier):
delay_ins = init_delay_by_return(target.return_info())
target.add_aft_update(delay_identifier, delay_ins)
delay_cls = target.get_aft_update(delay_identifier)
return delay_cls
Loading

0 comments on commit 06276ee

Please sign in to comment.