From 4d8e1649921abd059ea193380a1790bfa2e8f08a Mon Sep 17 00:00:00 2001 From: chaoming Date: Wed, 25 Oct 2023 09:47:14 +0800 Subject: [PATCH 1/7] add `calculate_gain` --- brainpy/initialize.py | 3 +++ docs/apis/initialize.rst | 11 +++++++++++ 2 files changed, 14 insertions(+) diff --git a/brainpy/initialize.py b/brainpy/initialize.py index f8cbaaee3..d2e946527 100644 --- a/brainpy/initialize.py +++ b/brainpy/initialize.py @@ -16,6 +16,9 @@ ) +from brainpy._src.initialize.random_inits import ( + calculate_gain, +) from brainpy._src.initialize.random_inits import ( Normal as Normal, Uniform as Uniform, diff --git a/docs/apis/initialize.rst b/docs/apis/initialize.rst index fcce922c8..f516aa5b5 100644 --- a/docs/apis/initialize.rst +++ b/docs/apis/initialize.rst @@ -8,6 +8,8 @@ :local: :depth: 1 + + Basic Initialization Classes ---------------------------- @@ -66,3 +68,12 @@ Decay Initializers DOGDecay +Helper Functions +---------------- + + +.. autosummary:: + :toctree: generated/ + + calculate_gain + From 3b0b800566261dcb5824280f96f453b7fb39544a Mon Sep 17 00:00:00 2001 From: chaoming Date: Fri, 27 Oct 2023 20:03:41 +0800 Subject: [PATCH 2/7] compatible with jax>=0.4.16 --- brainpy/_src/math/object_transform/autograd.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/brainpy/_src/math/object_transform/autograd.py b/brainpy/_src/math/object_transform/autograd.py index 299ed4202..6122f6cd8 100644 --- a/brainpy/_src/math/object_transform/autograd.py +++ b/brainpy/_src/math/object_transform/autograd.py @@ -6,7 +6,12 @@ import jax import numpy as np -from jax import linear_util, dtypes, vmap, numpy as jnp, core +if jax.__version__ >= '0.4.16': + from jax.extend import linear_util +else: + from jax import linear_util + +from jax import dtypes, vmap, numpy as jnp, core from jax._src.api import (_vjp, _jvp) from jax.api_util import argnums_partial from jax.interpreters import xla From 6fdfa427a26ba92065f95d020713b08ab21af23d Mon Sep 17 00:00:00 2001 From: chaoming Date: Fri, 27 Oct 2023 20:04:09 +0800 Subject: [PATCH 3/7] updates --- brainpy/_src/math/random.py | 4 +--- brainpy/dyn/neurons.py | 4 ++++ 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/brainpy/_src/math/random.py b/brainpy/_src/math/random.py index b2b6017c9..964d3f51e 100644 --- a/brainpy/_src/math/random.py +++ b/brainpy/_src/math/random.py @@ -57,12 +57,10 @@ def _formalize_key(key): def _size2shape(size): if size is None: return () - elif isinstance(size, int): - return (size,) elif isinstance(size, (tuple, list)): return tuple(size) else: - raise ValueError(f'Must be a list/tuple of int, but got {size}') + return (size, ) def _check_shape(name, shape, *param_shapes): diff --git a/brainpy/dyn/neurons.py b/brainpy/dyn/neurons.py index c8304c875..26b9fb1d1 100644 --- a/brainpy/dyn/neurons.py +++ b/brainpy/dyn/neurons.py @@ -1,5 +1,9 @@ +from brainpy._src.dyn.neurons.base import ( + GradNeuDyn, +) + from brainpy._src.dyn.neurons.lif import ( Lif, LifLTC, From b1391bfe6dfde14777683af5587e587dbaba9d86 Mon Sep 17 00:00:00 2001 From: chaoming Date: Sat, 28 Oct 2023 18:21:48 +0800 Subject: [PATCH 4/7] [projection] upgrade projections so that APIs are reused across different models --- brainpy/_src/delay.py | 32 +- brainpy/_src/dyn/projections/aligns.py | 242 +++++------ brainpy/_src/dyn/projections/plasticity.py | 124 +++--- .../_src/dyn/projections/tests/test_aligns.py | 410 ++++++++++++++++++ brainpy/_src/dynsys.py | 3 + brainpy/_src/math/modes.py | 9 + brainpy/_src/math/object_transform/base.py | 13 +- brainpy/_src/mixin.py | 12 +- brainpy/_src/tests/test_mixin.py | 12 +- brainpy/mixin.py | 1 - 10 files changed, 628 insertions(+), 230 deletions(-) create mode 100644 brainpy/_src/dyn/projections/tests/test_aligns.py diff --git a/brainpy/_src/delay.py b/brainpy/_src/delay.py index d6cdfd682..086a1ba87 100644 --- a/brainpy/_src/delay.py +++ b/brainpy/_src/delay.py @@ -16,7 +16,7 @@ from brainpy._src.dynsys import DynamicalSystem from brainpy._src.initialize import variable_ from brainpy._src.math.delayvars import ROTATE_UPDATE, CONCAT_UPDATE -from brainpy._src.mixin import ParamDesc, ReturnInfo +from brainpy._src.mixin import ParamDesc, ReturnInfo, JointType, SupportAutoDelay from brainpy.check import jit_error @@ -461,12 +461,13 @@ def __init__( self, delay: Delay, time: Union[None, int, float], - *indices + *indices, + delay_entry: str = None ): super().__init__(mode=delay.mode) self.refs = {'delay': delay} assert isinstance(delay, Delay) - delay.register_entry(self.name, time) + delay.register_entry(delay_entry or self.name, time) self.indices = indices def update(self): @@ -477,6 +478,15 @@ def reset_state(self, *args, **kwargs): def init_delay_by_return(info: Union[bm.Variable, ReturnInfo], initial_delay_data=None) -> Delay: + """Initialize a delay class by the return info (usually is created by ``.return_info()`` function). + + Args: + info: the return information. + initial_delay_data: The initial delay data. + + Returns: + The decay instance. + """ if isinstance(info, bm.Variable): return VarDelay(info, init=initial_delay_data) @@ -513,3 +523,19 @@ def init_delay_by_return(info: Union[bm.Variable, ReturnInfo], initial_delay_dat return DataDelay(target, data_init=info.data, init=initial_delay_data) else: raise TypeError + + +def register_delay_by_return(target: JointType[DynamicalSystem, SupportAutoDelay]): + """Register delay class for the given target. + + Args: + target: The target class to register delay. + + Returns: + The delay registered for the given target. + """ + if not target.has_aft_update(delay_identifier): + delay_ins = init_delay_by_return(target.return_info()) + target.add_aft_update(delay_identifier, delay_ins) + delay_cls = target.get_aft_update(delay_identifier) + return delay_cls diff --git a/brainpy/_src/dyn/projections/aligns.py b/brainpy/_src/dyn/projections/aligns.py index c19f45844..d8c5a4d47 100644 --- a/brainpy/_src/dyn/projections/aligns.py +++ b/brainpy/_src/dyn/projections/aligns.py @@ -1,9 +1,10 @@ from typing import Optional, Callable, Union from brainpy import math as bm, check -from brainpy._src.delay import Delay, DelayAccess, delay_identifier, init_delay_by_return +from brainpy._src.delay import (Delay, DelayAccess, delay_identifier, + init_delay_by_return, register_delay_by_return) from brainpy._src.dynsys import DynamicalSystem, Projection -from brainpy._src.mixin import (JointType, ParamDescInit, ReturnInfo, +from brainpy._src.mixin import (JointType, ParamDescriber, ReturnInfo, SupportAutoDelay, BindCondData, AlignPost) __all__ = [ @@ -15,6 +16,64 @@ ] +def get_post_repr(out_label, syn, out): + return f'{out_label} // {syn.identifier} // {out.identifier}' + + +def add_inp_fun(out_label, proj_name, out, post): + # synapse and output initialization + if out_label is None: + out_name = proj_name + else: + out_name = f'{out_label} // {proj_name}' + post.add_inp_fun(out_name, out) + + +def align_post_init_bef_update(out_label, syn_desc, out_desc, post, proj_name): + # synapse and output initialization + _post_repr = get_post_repr(out_label, syn_desc, out_desc) + if not post.has_bef_update(_post_repr): + syn_cls = syn_desc() + out_cls = out_desc() + + # synapse and output initialization + if out_label is None: + out_name = proj_name + else: + out_name = f'{out_label} // {proj_name}' + post.add_inp_fun(out_name, out_cls) + post.add_bef_update(_post_repr, _AlignPost(syn_cls, out_cls)) + syn = post.get_bef_update(_post_repr).syn + out = post.get_bef_update(_post_repr).out + return syn, out + + +def align_pre2_add_bef_update(syn_desc, delay, delay_cls, proj_name=None): + _syn_id = f'Delay({str(delay)}) // {syn_desc.identifier}' + if not delay_cls.has_bef_update(_syn_id): + # delay + delay_access = DelayAccess(delay_cls, delay, delay_entry=proj_name) + # synapse + syn_cls = syn_desc() + # add to "after_updates" + delay_cls.add_bef_update(_syn_id, _AlignPreMg(delay_access, syn_cls)) + syn = delay_cls.get_bef_update(_syn_id).syn + return syn + + +def align_pre1_add_bef_update(syn_desc, pre): + _syn_id = f'{syn_desc.identifier} // Delay' + if not pre.has_aft_update(_syn_id): + # "syn_cls" needs an instance of "ProjAutoDelay" + syn_cls: SupportAutoDelay = syn_desc() + delay_cls = init_delay_by_return(syn_cls.return_info()) + # add to "after_updates" + pre.add_aft_update(_syn_id, _AlignPre(syn_cls, delay_cls)) + delay_cls: Delay = pre.get_aft_update(_syn_id).delay + syn = pre.get_aft_update(_syn_id).syn + return delay_cls, syn + + class _AlignPre(DynamicalSystem): def __init__(self, syn, delay=None): super().__init__() @@ -141,9 +200,6 @@ def update(self, x): self.refs['out'].bind_cond(current) return current - def reset_state(self, *args, **kwargs): - pass - class ProjAlignPostMg1(Projection): r"""Synaptic projection which defines the synaptic computation with the dimension of postsynaptic neuron group. @@ -197,8 +253,8 @@ def update(self, input): def __init__( self, comm: DynamicalSystem, - syn: ParamDescInit[JointType[DynamicalSystem, AlignPost]], - out: ParamDescInit[JointType[DynamicalSystem, BindCondData]], + syn: ParamDescriber[JointType[DynamicalSystem, AlignPost]], + out: ParamDescriber[JointType[DynamicalSystem, BindCondData]], post: DynamicalSystem, out_label: Optional[str] = None, name: Optional[str] = None, @@ -208,27 +264,18 @@ def __init__( # synaptic models check.is_instance(comm, DynamicalSystem) - check.is_instance(syn, ParamDescInit[JointType[DynamicalSystem, AlignPost]]) - check.is_instance(out, ParamDescInit[JointType[DynamicalSystem, BindCondData]]) + check.is_instance(syn, ParamDescriber[JointType[DynamicalSystem, AlignPost]]) + check.is_instance(out, ParamDescriber[JointType[DynamicalSystem, BindCondData]]) check.is_instance(post, DynamicalSystem) self.comm = comm # synapse and output initialization - self._post_repr = f'{out_label} // {syn.identifier} // {out.identifier}' - if not post.has_bef_update(self._post_repr): - syn_cls = syn() - out_cls = out() - if out_label is None: - out_name = self.name - else: - out_name = f'{out_label} // {self.name}' - post.add_inp_fun(out_name, out_cls) - post.add_bef_update(self._post_repr, _AlignPost(syn_cls, out_cls)) + syn, out = align_post_init_bef_update(out_label, syn_desc=syn, out_desc=out, post=post, proj_name=self.name) # references self.refs = dict(post=post) # invisible to ``self.nodes()`` - self.refs['syn'] = post.get_bef_update(self._post_repr).syn - self.refs['out'] = post.get_bef_update(self._post_repr).out + self.refs['syn'] = syn + self.refs['out'] = out self.refs['comm'] = comm # unify the access def update(self, x): @@ -236,9 +283,6 @@ def update(self, x): self.refs['syn'].add_current(current) # synapse post current return current - def reset_state(self, *args, **kwargs): - pass - class ProjAlignPostMg2(Projection): """Synaptic projection which defines the synaptic computation with the dimension of postsynaptic neuron group. @@ -315,8 +359,8 @@ def __init__( pre: JointType[DynamicalSystem, SupportAutoDelay], delay: Union[None, int, float], comm: DynamicalSystem, - syn: ParamDescInit[JointType[DynamicalSystem, AlignPost]], - out: ParamDescInit[JointType[DynamicalSystem, BindCondData]], + syn: ParamDescriber[JointType[DynamicalSystem, AlignPost]], + out: ParamDescriber[JointType[DynamicalSystem, BindCondData]], post: DynamicalSystem, out_label: Optional[str] = None, name: Optional[str] = None, @@ -327,36 +371,22 @@ def __init__( # synaptic models check.is_instance(pre, JointType[DynamicalSystem, SupportAutoDelay]) check.is_instance(comm, DynamicalSystem) - check.is_instance(syn, ParamDescInit[JointType[DynamicalSystem, AlignPost]]) - check.is_instance(out, ParamDescInit[JointType[DynamicalSystem, BindCondData]]) + check.is_instance(syn, ParamDescriber[JointType[DynamicalSystem, AlignPost]]) + check.is_instance(out, ParamDescriber[JointType[DynamicalSystem, BindCondData]]) check.is_instance(post, DynamicalSystem) self.comm = comm # delay initialization - if not pre.has_aft_update(delay_identifier): - # pre should support "ProjAutoDelay" - delay_cls = init_delay_by_return(pre.return_info()) - # add to "after_updates" - pre.add_aft_update(delay_identifier, delay_cls) - delay_cls: Delay = pre.get_aft_update(delay_identifier) + delay_cls = register_delay_by_return(pre) delay_cls.register_entry(self.name, delay) # synapse and output initialization - self._post_repr = f'{out_label} // {syn.identifier} // {out.identifier}' - if not post.has_bef_update(self._post_repr): - syn_cls = syn() - out_cls = out() - if out_label is None: - out_name = self.name - else: - out_name = f'{out_label} // {self.name}' - post.add_inp_fun(out_name, out_cls) - post.add_bef_update(self._post_repr, _AlignPost(syn_cls, out_cls)) + syn, out = align_post_init_bef_update(out_label, syn_desc=syn, out_desc=out, post=post, proj_name=self.name) # references self.refs = dict(pre=pre, post=post) # invisible to ``self.nodes()`` - self.refs['syn'] = post.get_bef_update(self._post_repr).syn # invisible to ``self.node()`` - self.refs['out'] = post.get_bef_update(self._post_repr).out # invisible to ``self.node()`` + self.refs['syn'] = syn # invisible to ``self.node()`` + self.refs['out'] = out # invisible to ``self.node()`` # unify the access self.refs['comm'] = comm self.refs['delay'] = pre.get_aft_update(delay_identifier) @@ -367,9 +397,6 @@ def update(self): self.refs['syn'].add_current(current) # synapse post current return current - def reset_state(self, *args, **kwargs): - pass - class ProjAlignPost1(Projection): """Synaptic projection which defines the synaptic computation with the dimension of postsynaptic neuron group. @@ -433,32 +460,27 @@ def __init__( check.is_instance(out, JointType[DynamicalSystem, BindCondData]) check.is_instance(post, DynamicalSystem) self.comm = comm + self.syn = syn + self.out = out # synapse and output initialization - if out_label is None: - out_name = self.name - else: - out_name = f'{out_label} // {self.name}' - post.add_inp_fun(out_name, out) - post.add_bef_update(self.name, _AlignPost(syn, out)) + add_inp_fun(out_label, self.name, out, post) # reference self.refs = dict() # invisible to ``self.nodes()`` self.refs['post'] = post - self.refs['syn'] = post.get_bef_update(self.name).syn - self.refs['out'] = post.get_bef_update(self.name).out + self.refs['syn'] = syn + self.refs['out'] = out # unify the access self.refs['comm'] = comm def update(self, x): current = self.comm(x) - self.refs['syn'].add_current(current) + g = self.syn(self.comm(x)) + self.refs['out'].bind_cond(g) # synapse post current return current - def reset_state(self, *args, **kwargs): - pass - class ProjAlignPost2(Projection): """Synaptic projection which defines the synaptic computation with the dimension of postsynaptic neuron group. @@ -550,20 +572,11 @@ def __init__( self.syn = syn # delay initialization - if not pre.has_aft_update(delay_identifier): - # pre should support "ProjAutoDelay" - delay_cls = init_delay_by_return(pre.return_info()) - # add to "after_updates" - pre.add_aft_update(delay_identifier, delay_cls) - delay_cls: Delay = pre.get_aft_update(delay_identifier) + delay_cls = register_delay_by_return(pre) delay_cls.register_entry(self.name, delay) # synapse and output initialization - if out_label is None: - out_name = self.name - else: - out_name = f'{out_label} // {self.name}' - post.add_inp_fun(out_name, out) + add_inp_fun(out_label, self.name, out, post) # references self.refs = dict() @@ -572,19 +585,16 @@ def __init__( self.refs['post'] = post self.refs['out'] = out # unify the access - self.refs['delay'] = pre.get_aft_update(delay_identifier) + self.refs['delay'] = delay_cls self.refs['comm'] = comm self.refs['syn'] = syn def update(self): - x = self.refs['pre'].get_aft_update(delay_identifier).at(self.name) + x = self.refs['delay'].at(self.name) g = self.syn(self.comm(x)) self.refs['out'].bind_cond(g) # synapse post current return g - def reset_state(self, *args, **kwargs): - pass - class ProjAlignPreMg1(Projection): """Synaptic projection which defines the synaptic computation with the dimension of presynaptic neuron group. @@ -655,7 +665,7 @@ def update(self, inp): def __init__( self, pre: DynamicalSystem, - syn: ParamDescInit[JointType[DynamicalSystem, SupportAutoDelay]], + syn: ParamDescriber[JointType[DynamicalSystem, SupportAutoDelay]], delay: Union[None, int, float], comm: DynamicalSystem, out: JointType[DynamicalSystem, BindCondData], @@ -668,29 +678,18 @@ def __init__( # synaptic models check.is_instance(pre, DynamicalSystem) - check.is_instance(syn, ParamDescInit[JointType[DynamicalSystem, SupportAutoDelay]]) + check.is_instance(syn, ParamDescriber[JointType[DynamicalSystem, SupportAutoDelay]]) check.is_instance(comm, DynamicalSystem) check.is_instance(out, JointType[DynamicalSystem, BindCondData]) check.is_instance(post, DynamicalSystem) self.comm = comm # synapse and delay initialization - self._syn_id = f'{syn.identifier} // Delay' - if not pre.has_aft_update(self._syn_id): - # "syn_cls" needs an instance of "ProjAutoDelay" - syn_cls: SupportAutoDelay = syn() - delay_cls = init_delay_by_return(syn_cls.return_info()) - # add to "after_updates" - pre.add_aft_update(self._syn_id, _AlignPre(syn_cls, delay_cls)) - delay_cls: Delay = pre.get_aft_update(self._syn_id).delay + delay_cls, syn_cls = align_pre1_add_bef_update(syn, pre) delay_cls.register_entry(self.name, delay) # output initialization - if out_label is None: - out_name = self.name - else: - out_name = f'{out_label} // {self.name}' - post.add_inp_fun(out_name, out) + add_inp_fun(out_label, self.name, out, post) # references self.refs = dict() @@ -699,7 +698,7 @@ def __init__( self.refs['post'] = post self.refs['out'] = out self.refs['delay'] = delay_cls - self.refs['syn'] = pre.get_aft_update(self._syn_id).syn + self.refs['syn'] = syn_cls # unify the access self.refs['comm'] = comm @@ -710,9 +709,6 @@ def update(self, x=None): self.refs['out'].bind_cond(current) return current - def reset_state(self, *args, **kwargs): - pass - class ProjAlignPreMg2(Projection): """Synaptic projection which defines the synaptic computation with the dimension of presynaptic neuron group. @@ -784,7 +780,7 @@ def __init__( self, pre: JointType[DynamicalSystem, SupportAutoDelay], delay: Union[None, int, float], - syn: ParamDescInit[DynamicalSystem], + syn: ParamDescriber[DynamicalSystem], comm: DynamicalSystem, out: JointType[DynamicalSystem, BindCondData], post: DynamicalSystem, @@ -796,41 +792,27 @@ def __init__( # synaptic models check.is_instance(pre, JointType[DynamicalSystem, SupportAutoDelay]) - check.is_instance(syn, ParamDescInit[DynamicalSystem]) + check.is_instance(syn, ParamDescriber[DynamicalSystem]) check.is_instance(comm, DynamicalSystem) check.is_instance(out, JointType[DynamicalSystem, BindCondData]) check.is_instance(post, DynamicalSystem) self.comm = comm # delay initialization - if not pre.has_aft_update(delay_identifier): - delay_ins = init_delay_by_return(pre.return_info()) - pre.add_aft_update(delay_identifier, delay_ins) - delay_cls = pre.get_aft_update(delay_identifier) + delay_cls = register_delay_by_return(pre) # synapse initialization - self._syn_id = f'Delay({str(delay)}) // {syn.identifier}' - if not delay_cls.has_bef_update(self._syn_id): - # delay - delay_access = DelayAccess(delay_cls, delay) - # synapse - syn_cls = syn() - # add to "after_updates" - delay_cls.add_bef_update(self._syn_id, _AlignPreMg(delay_access, syn_cls)) + syn_cls = align_pre2_add_bef_update(syn, delay, delay_cls, self.name) # output initialization - if out_label is None: - out_name = self.name - else: - out_name = f'{out_label} // {self.name}' - post.add_inp_fun(out_name, out) + add_inp_fun(out_label, self.name, out, post) # references self.refs = dict() # invisible to `self.nodes()` self.refs['pre'] = pre self.refs['post'] = post - self.refs['syn'] = delay_cls.get_bef_update(self._syn_id).syn + self.refs['syn'] = syn_cls self.refs['out'] = out # unify the access self.refs['comm'] = comm @@ -841,9 +823,6 @@ def update(self): self.refs['out'].bind_cond(current) return current - def reset_state(self, *args, **kwargs): - pass - class ProjAlignPre1(Projection): """Synaptic projection which defines the synaptic computation with the dimension of presynaptic neuron group. @@ -939,11 +918,7 @@ def __init__( pre.add_aft_update(self.name, _AlignPre(syn, delay_cls)) # output initialization - if out_label is None: - out_name = self.name - else: - out_name = f'{out_label} // {self.name}' - post.add_inp_fun(out_name, out) + add_inp_fun(out_label, self.name, out, post) # references self.refs = dict() @@ -963,9 +938,6 @@ def update(self, x=None): self.refs['out'].bind_cond(current) return current - def reset_state(self, *args, **kwargs): - pass - class ProjAlignPre2(Projection): """Synaptic projection which defines the synaptic computation with the dimension of presynaptic neuron group. @@ -1057,18 +1029,11 @@ def __init__( self.syn = syn # delay initialization - if not pre.has_aft_update(delay_identifier): - delay_ins = init_delay_by_return(pre.return_info()) - pre.add_aft_update(delay_identifier, delay_ins) - delay_cls = pre.get_aft_update(delay_identifier) + delay_cls = register_delay_by_return(pre) delay_cls.register_entry(self.name, delay) # output initialization - if out_label is None: - out_name = self.name - else: - out_name = f'{out_label} // {self.name}' - post.add_inp_fun(out_name, out) + add_inp_fun(out_label, self.name, out, post) # references self.refs = dict() @@ -1076,7 +1041,7 @@ def __init__( self.refs['pre'] = pre self.refs['post'] = post self.refs['out'] = out - self.refs['delay'] = pre.get_aft_update(delay_identifier) + self.refs['delay'] = delay_cls # unify the access self.refs['syn'] = syn self.refs['comm'] = comm @@ -1086,6 +1051,3 @@ def update(self): g = self.comm(self.syn(spk)) self.refs['out'].bind_cond(g) return g - - def reset_state(self, *args, **kwargs): - pass diff --git a/brainpy/_src/dyn/projections/plasticity.py b/brainpy/_src/dyn/projections/plasticity.py index 7c176c125..e06037273 100644 --- a/brainpy/_src/dyn/projections/plasticity.py +++ b/brainpy/_src/dyn/projections/plasticity.py @@ -1,21 +1,38 @@ from typing import Optional, Callable, Union from brainpy import math as bm, check -from brainpy._src.delay import DelayAccess, delay_identifier, init_delay_by_return +from brainpy._src.delay import register_delay_by_return from brainpy._src.dyn.synapses.abstract_models import Expon from brainpy._src.dynsys import DynamicalSystem, Projection from brainpy._src.initialize import parameter -from brainpy._src.mixin import (JointType, ParamDescInit, SupportAutoDelay, BindCondData, AlignPost, SupportSTDP) +from brainpy._src.mixin import (JointType, ParamDescriber, SupportAutoDelay, + BindCondData, AlignPost, SupportSTDP) from brainpy.types import ArrayType -from .aligns import _AlignPost, _AlignPreMg, _get_return +from .aligns import (_get_return, align_post_init_bef_update, + align_pre2_add_bef_update, add_inp_fun) __all__ = [ 'STDP_Song2000', ] +def _init_trace_by_align_pre2( + target: DynamicalSystem, + delay: Union[None, int, float], + syn: ParamDescriber[DynamicalSystem], +): + """Calculate the trace of the target by reusing the existing connections.""" + check.is_instance(target, DynamicalSystem) + check.is_instance(syn, ParamDescriber[DynamicalSystem]) + # delay initialization + delay_cls = register_delay_by_return(target) + # synapse initialization + syn = align_pre2_add_bef_update(syn, delay, delay_cls) + return syn + + class STDP_Song2000(Projection): - r"""Synaptic output with spike-time-dependent plasticity. + r"""Spike-time-dependent plasticity proposed by (Song, et. al, 2000). This model filters the synaptic currents according to the variables: :math:`w`. @@ -93,15 +110,23 @@ def run(i, I_pre, I_post): tau_t: float, ArrayType, Callable. The time constant of :math:`A_{post}`. A1: float, ArrayType, Callable. The increment of :math:`A_{pre}` produced by a spike. A2: float, ArrayType, Callable. The increment of :math:`A_{post}` produced by a spike. + pre: DynamicalSystem. The pre-synaptic neuron group. + delay: int, float. The pre spike delay length. (ms) + syn: DynamicalSystem. The synapse model. + comm: DynamicalSystem. The communication model, for example, dense or sparse connection layers. + out: DynamicalSystem. The synaptic current output models. + post: DynamicalSystem. The post-synaptic neuron group. + out_label: str. The output label. + name: str. The model name. """ def __init__( self, pre: JointType[DynamicalSystem, SupportAutoDelay], delay: Union[None, int, float], - syn: ParamDescInit[DynamicalSystem], + syn: ParamDescriber[DynamicalSystem], comm: JointType[DynamicalSystem, SupportSTDP], - out: ParamDescInit[JointType[DynamicalSystem, BindCondData]], + out: ParamDescriber[JointType[DynamicalSystem, BindCondData]], post: DynamicalSystem, # synapse parameters tau_s: Union[float, ArrayType, Callable] = 16.8, @@ -117,9 +142,9 @@ def __init__( # synaptic models check.is_instance(pre, JointType[DynamicalSystem, SupportAutoDelay]) - check.is_instance(syn, ParamDescInit[DynamicalSystem]) + check.is_instance(syn, ParamDescriber[DynamicalSystem]) check.is_instance(comm, JointType[DynamicalSystem, SupportSTDP]) - check.is_instance(out, ParamDescInit[JointType[DynamicalSystem, BindCondData]]) + check.is_instance(out, ParamDescriber[JointType[DynamicalSystem, BindCondData]]) check.is_instance(post, DynamicalSystem) self.pre_num = pre.num self.post_num = post.num @@ -127,46 +152,33 @@ def __init__( self.syn = syn # delay initialization - if not pre.has_aft_update(delay_identifier): - delay_ins = init_delay_by_return(pre.return_info()) - pre.add_aft_update(delay_identifier, delay_ins) - delay_cls = pre.get_aft_update(delay_identifier) + delay_cls = register_delay_by_return(pre) delay_cls.register_entry(self.name, delay) if issubclass(syn.cls, AlignPost): # synapse and output initialization - self._post_repr = f'{out_label} // {syn.identifier} // {out.identifier}' - if not post.has_bef_update(self._post_repr): - syn_cls = syn() - out_cls = out() - out_name = self.name if out_label is None else f'{out_label} // {self.name}' - post.add_inp_fun(out_name, out_cls) - post.add_bef_update(self._post_repr, _AlignPost(syn_cls, out_cls)) + syn, out = align_post_init_bef_update(out_label, syn_desc=syn, out_desc=out, post=post, proj_name=self.name) # references self.refs = dict(pre=pre, post=post, out=out) # invisible to ``self.nodes()`` - self.refs['delay'] = pre.get_aft_update(delay_identifier) - self.refs['syn'] = post.get_bef_update(self._post_repr).syn # invisible to ``self.node()`` - self.refs['out'] = post.get_bef_update(self._post_repr).out # invisible to ``self.node()`` + self.refs['delay'] = delay_cls + self.refs['syn'] = syn # invisible to ``self.node()`` + self.refs['out'] = out # invisible to ``self.node()`` else: # synapse initialization - self._syn_id = f'Delay({str(delay)}) // {syn.identifier}' - if not delay_cls.has_bef_update(self._syn_id): - delay_access = DelayAccess(delay_cls, delay) - syn_cls = syn() - delay_cls.add_bef_update(self._syn_id, _AlignPreMg(delay_access, syn_cls)) + syn = align_pre2_add_bef_update(syn, delay, delay_cls, self.name) # output initialization - out_name = self.name if out_label is None else f'{out_label} // {self.name}' - post.add_inp_fun(out_name, out) + add_inp_fun(out_label, self.name, out(), post) # references self.refs = dict(pre=pre, post=post) # invisible to `self.nodes()` - self.refs['delay'] = delay_cls.get_bef_update(self._syn_id) - self.refs['syn'] = delay_cls.get_bef_update(self._syn_id).syn + self.refs['delay'] = delay_cls + self.refs['syn'] = syn self.refs['out'] = out - # trace initialization - self.refs['pre_trace'] = self._init_trace(pre, delay, Expon.desc(pre.num, tau=tau_s)) - self.refs['post_trace'] = self._init_trace(post, None, Expon.desc(post.num, tau=tau_t)) + # tracing pre-synaptic spikes using Exponential model + self.refs['pre_trace'] = _init_trace_by_align_pre2(pre, delay, Expon.desc(pre.num, tau=tau_s)) + # tracing post-synaptic spikes using Exponential model + self.refs['post_trace'] = _init_trace_by_align_pre2(post, None, Expon.desc(post.num, tau=tau_t)) # synapse parameters self.tau_s = parameter(tau_s, sizes=self.pre_num) @@ -174,48 +186,20 @@ def __init__( self.A1 = parameter(A1, sizes=self.pre_num) self.A2 = parameter(A2, sizes=self.post_num) - def reset_state(self, *args, **kwargs): - pass - - def _init_trace( - self, - target: DynamicalSystem, - delay: Union[None, int, float], - syn: ParamDescInit[DynamicalSystem], - ): - """Calculate the trace of the target.""" - check.is_instance(target, DynamicalSystem) - check.is_instance(syn, ParamDescInit[DynamicalSystem]) - - # delay initialization - if not target.has_aft_update(delay_identifier): - delay_ins = init_delay_by_return(target.return_info()) - target.add_aft_update(delay_identifier, delay_ins) - delay_cls = target.get_aft_update(delay_identifier) - delay_cls.register_entry(target.name, delay) - - # synapse initialization - _syn_id = f'Delay({str(delay)}) // {syn.identifier}' - if not delay_cls.has_bef_update(_syn_id): - # delay - delay_access = DelayAccess(delay_cls, delay) - # synapse - syn_cls = syn() - # add to "after_updates" - delay_cls.add_bef_update(_syn_id, _AlignPreMg(delay_access, syn_cls)) - - return delay_cls.get_bef_update(_syn_id).syn - def update(self): - # pre spikes, and pre-synaptic variables + # pre-synaptic spikes + pre_spike = self.refs['delay'].at(self.name) # spike + # pre-synaptic variables if issubclass(self.syn.cls, AlignPost): - pre_spike = self.refs['delay'].at(self.name) + # For AlignPost, we need "pre spikes @ comm matrix" for computing post-synaptic conductance x = pre_spike else: - pre_spike = self.refs['delay'].access() - x = _get_return(self.refs['syn'].return_info()) + # For AlignPre, we need the "pre synapse variable @ comm matrix" for computing post conductance + x = _get_return(self.refs['syn'].return_info()) # pre-synaptic variable # post spikes + if not hasattr(self.refs['post'], 'spike'): + raise AttributeError(f'{self} needs a "spike" variable for the post-synaptic neuron group.') post_spike = self.refs['post'].spike # weight updates diff --git a/brainpy/_src/dyn/projections/tests/test_aligns.py b/brainpy/_src/dyn/projections/tests/test_aligns.py new file mode 100644 index 000000000..600d82c8e --- /dev/null +++ b/brainpy/_src/dyn/projections/tests/test_aligns.py @@ -0,0 +1,410 @@ +import matplotlib.pyplot as plt +import numpy as np + +import brainpy as bp +import brainpy.math as bm + +neu_pars = dict(V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., + V_initializer=bp.init.Normal(-55., 2.)) + + +def test_ProjAlignPreMg1(): + class EICOBA_PreAlign(bp.DynamicalSystem): + def __init__(self, scale=1., inp=20.): + super().__init__() + + self.inp = inp + self.E = bp.dyn.LifRefLTC(int(3200 * scale), **neu_pars) + self.I = bp.dyn.LifRefLTC(int(800 * scale), **neu_pars) + + prob = 80 / (4000 * scale) + + self.E2I = bp.dyn.ProjAlignPreMg1( + pre=self.E, + syn=bp.dyn.Expon.desc(self.E.varshape, tau=5.), + delay=None, + comm=bp.dnn.CSRLinear(bp.conn.FixedProb(prob, pre=self.E.num, post=self.I.num), 0.6), + out=bp.dyn.COBA(E=0.), + post=self.I, + ) + self.E2E = bp.dyn.ProjAlignPreMg1( + pre=self.E, + syn=bp.dyn.Expon.desc(self.E.varshape, tau=5.), + delay=None, + comm=bp.dnn.CSRLinear(bp.conn.FixedProb(prob, pre=self.E.num, post=self.E.num), 0.6), + out=bp.dyn.COBA(E=0.), + post=self.E, + ) + self.I2E = bp.dyn.ProjAlignPreMg1( + pre=self.I, + syn=bp.dyn.Expon.desc(self.I.varshape, tau=10.), + delay=None, + comm=bp.dnn.CSRLinear(bp.conn.FixedProb(prob, pre=self.I.num, post=self.E.num), 6.7), + out=bp.dyn.COBA(E=-80.), + post=self.E, + ) + self.I2I = bp.dyn.ProjAlignPreMg1( + pre=self.I, + syn=bp.dyn.Expon.desc(self.I.varshape, tau=10.), + delay=None, + comm=bp.dnn.CSRLinear(bp.conn.FixedProb(prob, pre=self.I.num, post=self.I.num), 6.7), + out=bp.dyn.COBA(E=-80.), + post=self.I, + ) + + def update(self): + self.E2I() + self.I2I() + self.I2E() + self.E2E() + self.E(self.inp) + self.I(self.inp) + return self.E.spike.value + + net = EICOBA_PreAlign(0.5) + indices = np.arange(400) + spks = bm.for_loop(net.step_run, indices) + bp.visualize.raster_plot(indices * bm.dt, spks, show=True) + plt.close() + bm.clear_buffer_memory() + + +def test_ProjAlignPostMg2(): + class EICOBA_PostAlign(bp.DynamicalSystem): + def __init__(self, scale, inp=20., ltc=True): + super().__init__() + self.inp = inp + + if ltc: + self.E = bp.dyn.LifRefLTC(int(3200 * scale), **neu_pars) + self.I = bp.dyn.LifRefLTC(int(800 * scale), **neu_pars) + else: + self.E = bp.dyn.LifRef(int(3200 * scale), **neu_pars) + self.I = bp.dyn.LifRef(int(800 * scale), **neu_pars) + + prob = 80 / (4000 * scale) + + self.E2E = bp.dyn.ProjAlignPostMg2( + pre=self.E, + delay=None, + comm=bp.dnn.EventCSRLinear(bp.conn.FixedProb(prob, pre=self.E.num, post=self.E.num), 0.6), + syn=bp.dyn.Expon.desc(self.E.varshape, tau=5.), + out=bp.dyn.COBA.desc(E=0.), + post=self.E, + ) + self.E2I = bp.dyn.ProjAlignPostMg2( + pre=self.E, + delay=None, + comm=bp.dnn.EventCSRLinear(bp.conn.FixedProb(prob, pre=self.E.num, post=self.I.num), 0.6), + syn=bp.dyn.Expon.desc(self.I.varshape, tau=5.), + out=bp.dyn.COBA.desc(E=0.), + post=self.I, + ) + self.I2E = bp.dyn.ProjAlignPostMg2( + pre=self.I, + delay=None, + comm=bp.dnn.EventCSRLinear(bp.conn.FixedProb(prob, pre=self.I.num, post=self.E.num), 6.7), + syn=bp.dyn.Expon.desc(self.E.varshape, tau=10.), + out=bp.dyn.COBA.desc(E=-80.), + post=self.E, + ) + self.I2I = bp.dyn.ProjAlignPostMg2( + pre=self.I, + delay=None, + comm=bp.dnn.EventCSRLinear(bp.conn.FixedProb(prob, pre=self.I.num, post=self.I.num), 6.7), + syn=bp.dyn.Expon.desc(self.I.varshape, tau=10.), + out=bp.dyn.COBA.desc(E=-80.), + post=self.I, + ) + + def update(self): + self.E2I() + self.I2I() + self.I2E() + self.E2E() + self.E(self.inp) + self.I(self.inp) + return self.E.spike.value + + net = EICOBA_PostAlign(0.5) + indices = np.arange(400) + spks = bm.for_loop(net.step_run, indices) + bp.visualize.raster_plot(indices * bm.dt, spks, show=True) + + net = EICOBA_PostAlign(0.5, ltc=False) + indices = np.arange(400) + spks = bm.for_loop(net.step_run, indices) + bp.visualize.raster_plot(indices * bm.dt, spks, show=True) + + plt.close() + bm.clear_buffer_memory() + + +def test_ProjAlignPost1(): + class EINet(bp.DynSysGroup): + def __init__(self, scale=1.): + super().__init__() + num = int(4000 * scale) + self.num_exc = int(3200 * scale) + self.num_inh = num - self.num_exc + prob = 80 / num + + self.N = bp.dyn.LifRefLTC(num, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., + V_initializer=bp.init.Normal(-55., 2.)) + self.delay = bp.VarDelay(self.N.spike, entries={'I': None}) + self.E = bp.dyn.ProjAlignPost1(comm=bp.dnn.EventJitFPHomoLinear(self.num_exc, num, prob=prob, weight=0.6), + syn=bp.dyn.Expon(size=num, tau=5.), + out=bp.dyn.COBA(E=0.), + post=self.N) + self.I = bp.dyn.ProjAlignPost1(comm=bp.dnn.EventJitFPHomoLinear(self.num_inh, num, prob=prob, weight=6.7), + syn=bp.dyn.Expon(size=num, tau=10.), + out=bp.dyn.COBA(E=-80.), + post=self.N) + + def update(self, input): + spk = self.delay.at('I') + self.E(spk[:self.num_exc]) + self.I(spk[self.num_exc:]) + self.delay(self.N(input)) + return self.N.spike.value + + model = EINet(0.5) + indices = bm.arange(400) + spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices) + bp.visualize.raster_plot(indices, spks, show=True) + bm.clear_buffer_memory() + plt.close() + + +def test_ProjAlignPost2(): + class EINet(bp.DynSysGroup): + def __init__(self, scale): + super().__init__() + ne, ni = int(3200 * scale), int(800 * scale) + p = 80 / (ne + ni) + + self.E = bp.dyn.LifRefLTC(ne, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., + V_initializer=bp.init.Normal(-55., 2.)) + self.I = bp.dyn.LifRefLTC(ni, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., + V_initializer=bp.init.Normal(-55., 2.)) + self.E2E = bp.dyn.ProjAlignPost2(pre=self.E, + delay=0.1, + comm=bp.dnn.EventJitFPHomoLinear(ne, ne, prob=p, weight=0.6), + syn=bp.dyn.Expon(size=ne, tau=5.), + out=bp.dyn.COBA(E=0.), + post=self.E) + self.E2I = bp.dyn.ProjAlignPost2(pre=self.E, + delay=0.1, + comm=bp.dnn.EventJitFPHomoLinear(ne, ni, prob=p, weight=0.6), + syn=bp.dyn.Expon(size=ni, tau=5.), + out=bp.dyn.COBA(E=0.), + post=self.I) + self.I2E = bp.dyn.ProjAlignPost2(pre=self.I, + delay=0.1, + comm=bp.dnn.EventJitFPHomoLinear(ni, ne, prob=p, weight=6.7), + syn=bp.dyn.Expon(size=ne, tau=10.), + out=bp.dyn.COBA(E=-80.), + post=self.E) + self.I2I = bp.dyn.ProjAlignPost2(pre=self.I, + delay=0.1, + comm=bp.dnn.EventJitFPHomoLinear(ni, ni, prob=p, weight=6.7), + syn=bp.dyn.Expon(size=ni, tau=10.), + out=bp.dyn.COBA(E=-80.), + post=self.I) + + def update(self, inp): + self.E2E() + self.E2I() + self.I2E() + self.I2I() + self.E(inp) + self.I(inp) + return self.E.spike + + model = EINet(0.5) + indices = bm.arange(400) + spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices) + bp.visualize.raster_plot(indices, spks, show=True) + bm.clear_buffer_memory() + plt.close() + + +def test_VanillaProj(): + class EINet(bp.DynSysGroup): + def __init__(self, scale=0.5): + super().__init__() + num = int(4000 * scale) + self.ne = int(3200 * scale) + self.ni = num - self.ne + p = 80 / num + + self.N = bp.dyn.LifRefLTC(num, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., + V_initializer=bp.init.Normal(-55., 2.)) + self.delay = bp.VarDelay(self.N.spike, entries={'I': None}) + self.syn1 = bp.dyn.Expon(size=self.ne, tau=5.) + self.syn2 = bp.dyn.Expon(size=self.ni, tau=10.) + self.E = bp.dyn.VanillaProj(comm=bp.dnn.JitFPHomoLinear(self.ne, num, prob=p, weight=0.6), + out=bp.dyn.COBA(E=0.), + post=self.N) + self.I = bp.dyn.VanillaProj(comm=bp.dnn.JitFPHomoLinear(self.ni, num, prob=p, weight=6.7), + out=bp.dyn.COBA(E=-80.), + post=self.N) + + def update(self, input): + spk = self.delay.at('I') + self.E(self.syn1(spk[:self.ne])) + self.I(self.syn2(spk[self.ne:])) + self.delay(self.N(input)) + return self.N.spike.value + + model = EINet() + indices = bm.arange(400) + spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices) + bp.visualize.raster_plot(indices, spks, show=True) + bm.clear_buffer_memory() + plt.close() + + +def test_ProjAlignPreMg1_v2(): + class EINet(bp.DynSysGroup): + def __init__(self, scale=1.): + super().__init__() + ne, ni = int(3200 * scale), int(800 * scale) + p = 80 / (4000 * scale) + self.E = bp.dyn.LifRefLTC(ne, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., + V_initializer=bp.init.Normal(-55., 2.)) + self.I = bp.dyn.LifRefLTC(ni, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., + V_initializer=bp.init.Normal(-55., 2.)) + self.E2E = bp.dyn.ProjAlignPreMg1(pre=self.E, + syn=bp.dyn.Expon.desc(size=ne, tau=5.), + delay=0.1, + comm=bp.dnn.JitFPHomoLinear(ne, ne, prob=p, weight=0.6), + out=bp.dyn.COBA(E=0.), + post=self.E) + self.E2I = bp.dyn.ProjAlignPreMg1(pre=self.E, + syn=bp.dyn.Expon.desc(size=ne, tau=5.), + delay=0.1, + comm=bp.dnn.JitFPHomoLinear(ne, ni, prob=p, weight=0.6), + out=bp.dyn.COBA(E=0.), + post=self.I) + self.I2E = bp.dyn.ProjAlignPreMg1(pre=self.I, + syn=bp.dyn.Expon.desc(size=ni, tau=10.), + delay=0.1, + comm=bp.dnn.JitFPHomoLinear(ni, ne, prob=p, weight=6.7), + out=bp.dyn.COBA(E=-80.), + post=self.E) + self.I2I = bp.dyn.ProjAlignPreMg1(pre=self.I, + syn=bp.dyn.Expon.desc(size=ni, tau=10.), + delay=0.1, + comm=bp.dnn.JitFPHomoLinear(ni, ni, prob=p, weight=6.7), + out=bp.dyn.COBA(E=-80.), + post=self.I) + + def update(self, inp): + self.E2E() + self.E2I() + self.I2E() + self.I2I() + self.E(inp) + self.I(inp) + return self.E.spike + + model = EINet() + indices = bm.arange(400) + spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices) + bp.visualize.raster_plot(indices, spks, show=True) + bm.clear_buffer_memory() + plt.close() + + +def test_ProjAlignPreMg2(): + class EINet(bp.DynSysGroup): + def __init__(self, scale=1.): + super().__init__() + ne, ni = int(3200 * scale), int(800 * scale) + p = 80 / (4000 * scale) + self.E = bp.dyn.LifRefLTC(ne, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., + V_initializer=bp.init.Normal(-55., 2.)) + self.I = bp.dyn.LifRefLTC(ni, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., + V_initializer=bp.init.Normal(-55., 2.)) + self.E2E = bp.dyn.ProjAlignPreMg2(pre=self.E, + delay=0.1, + syn=bp.dyn.Expon.desc(size=ne, tau=5.), + comm=bp.dnn.JitFPHomoLinear(ne, ne, prob=p, weight=0.6), + out=bp.dyn.COBA(E=0.), + post=self.E) + self.E2I = bp.dyn.ProjAlignPreMg2(pre=self.E, + delay=0.1, + syn=bp.dyn.Expon.desc(size=ne, tau=5.), + comm=bp.dnn.JitFPHomoLinear(ne, ni, prob=p, weight=0.6), + out=bp.dyn.COBA(E=0.), + post=self.I) + self.I2E = bp.dyn.ProjAlignPreMg2(pre=self.I, + delay=0.1, + syn=bp.dyn.Expon.desc(size=ni, tau=10.), + comm=bp.dnn.JitFPHomoLinear(ni, ne, prob=p, weight=6.7), + out=bp.dyn.COBA(E=-80.), + post=self.E) + self.I2I = bp.dyn.ProjAlignPreMg2(pre=self.I, + delay=0.1, + syn=bp.dyn.Expon.desc(size=ni, tau=10.), + comm=bp.dnn.JitFPHomoLinear(ni, ni, prob=p, weight=6.7), + out=bp.dyn.COBA(E=-80.), + post=self.I) + + def update(self, inp): + self.E2E() + self.E2I() + self.I2E() + self.I2I() + self.E(inp) + self.I(inp) + return self.E.spike + + model = EINet() + indices = bm.arange(400) + spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices) + bp.visualize.raster_plot(indices, spks, show=True) + bm.clear_buffer_memory() + plt.close() + + +def test_vanalla_proj_v2(): + class EINet(bp.DynSysGroup): + def __init__(self, scale=1.): + super().__init__() + num = int(4000 * scale) + self.ne = int(3200 * scale) + self.ni = num - self.ne + p = 80 / num + + self.N = bp.dyn.LifRefLTC(num, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., + V_initializer=bp.init.Normal(-55., 1.)) + self.delay = bp.VarDelay(self.N.spike, entries={'delay': 2}) + self.syn1 = bp.dyn.Expon(size=self.ne, tau=5.) + self.syn2 = bp.dyn.Expon(size=self.ni, tau=10.) + self.E = bp.dyn.VanillaProj( + comm=bp.dnn.CSRLinear(bp.conn.FixedProb(p, pre=self.ne, post=num), weight=0.6), + out=bp.dyn.COBA(E=0.), + post=self.N + ) + self.I = bp.dyn.VanillaProj( + comm=bp.dnn.CSRLinear(bp.conn.FixedProb(p, pre=self.ni, post=num), weight=6.7), + out=bp.dyn.COBA(E=-80.), + post=self.N + ) + + def update(self, input): + spk = self.delay.at('delay') + self.E(self.syn1(spk[:self.ne])) + self.I(self.syn2(spk[self.ne:])) + self.delay(self.N(input)) + return self.N.spike.value + + model = EINet() + indices = bm.arange(400) + spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices, progress_bar=True) + bp.visualize.raster_plot(indices, spks, show=True) + plt.close() + bm.clear_buffer_memory() + diff --git a/brainpy/_src/dynsys.py b/brainpy/_src/dynsys.py index d85c16d9c..db3d574ae 100644 --- a/brainpy/_src/dynsys.py +++ b/brainpy/_src/dynsys.py @@ -556,6 +556,9 @@ def clear_input(self, *args, **kwargs): """Empty function of clearing inputs.""" pass + def reset_state(self, *args, **kwargs): + pass + class Dynamic(DynamicalSystem): """Base class to model dynamics. diff --git a/brainpy/_src/math/modes.py b/brainpy/_src/math/modes.py index 674035e18..d46afc248 100644 --- a/brainpy/_src/math/modes.py +++ b/brainpy/_src/math/modes.py @@ -52,6 +52,15 @@ def is_child_of(self, *modes): raise TypeError(f'The supported type must be a tuple/list of type. But we got {m_}') return isinstance(self, modes) + def is_batch_mode(self): + return isinstance(self, BatchingMode) + + def is_train_mode(self): + return isinstance(self, TrainingMode) + + def is_nonbatch_mode(self): + return isinstance(self, NonBatchingMode) + class NonBatchingMode(Mode): """Normal non-batching mode. diff --git a/brainpy/_src/math/object_transform/base.py b/brainpy/_src/math/object_transform/base.py index cea3414ab..f265093af 100644 --- a/brainpy/_src/math/object_transform/base.py +++ b/brainpy/_src/math/object_transform/base.py @@ -478,7 +478,7 @@ def unique_name(self, name=None, type_=None): check_name_uniqueness(name=name, obj=self) return name - def __save_state__(self, **kwargs) -> Dict[str, Variable]: + def __save_state__(self, **kwargs) -> Dict: """Save states. """ return self.vars(include_self=True, level=0).unique().dict() @@ -719,11 +719,12 @@ class NodeDict(dict): # raise TypeError(f'Element should be {BrainPyObject.__name__}, but got {type(elem)}.') # return elem - def __init__(self, *args, **kwargs): + def __init__(self, *args, check_unique: bool = False, **kwargs): super().__init__() self.update(*args, **kwargs) + self.check_unique = check_unique - def update(self, *args, **kwargs) -> 'VarDict': + def update(self, *args, **kwargs) -> 'NodeDict': for arg in args: if isinstance(arg, dict): for k, v in arg.items(): @@ -735,7 +736,11 @@ def update(self, *args, **kwargs) -> 'VarDict': self[k] = v return self - def __setitem__(self, key, value) -> 'VarDict': + def __setitem__(self, key, value) -> 'NodeDict': + if self.check_unique: + exist = self.get(key, None) + if id(exist) != id(value): + raise KeyError(f'Duplicate usage of key "{key}". "{key}" has been used for {value}.') super().__setitem__(key, value) return self diff --git a/brainpy/_src/mixin.py b/brainpy/_src/mixin.py index 39c3ace6b..177b60aa6 100644 --- a/brainpy/_src/mixin.py +++ b/brainpy/_src/mixin.py @@ -25,7 +25,7 @@ __all__ = [ 'MixIn', 'ParamDesc', - 'ParamDescInit', + 'ParamDescriber', 'DelayRegister', 'AlignPost', 'Container', @@ -74,11 +74,11 @@ class ParamDesc(MixIn): not_desc_params: Optional[Sequence[str]] = None @classmethod - def desc(cls, *args, **kwargs) -> 'ParamDescInit': - return ParamDescInit(cls, *args, **kwargs) + def desc(cls, *args, **kwargs) -> 'ParamDescriber': + return ParamDescriber(cls, *args, **kwargs) -class ParamDescInit(object): +class ParamDescriber(object): """Delayed initialization for parameter describers. """ @@ -115,7 +115,7 @@ def init(self, *args, **kwargs): return self.__call__(*args, **kwargs) def __instancecheck__(self, instance): - if not isinstance(instance, ParamDescInit): + if not isinstance(instance, ParamDescriber): return False if not issubclass(instance.cls, self.cls): return False @@ -123,7 +123,7 @@ def __instancecheck__(self, instance): @classmethod def __class_getitem__(cls, item: type): - return ParamDescInit(item) + return ParamDescriber(item) @property def identifier(self): diff --git a/brainpy/_src/tests/test_mixin.py b/brainpy/_src/tests/test_mixin.py index 5fbab7b9f..962b76cb9 100644 --- a/brainpy/_src/tests/test_mixin.py +++ b/brainpy/_src/tests/test_mixin.py @@ -7,13 +7,13 @@ class TestParamDesc(unittest.TestCase): def test1(self): a = bp.dyn.Expon(1) - self.assertTrue(not isinstance(a, bp.mixin.ParamDescInit[bp.dyn.Expon])) - self.assertTrue(not isinstance(a, bp.mixin.ParamDescInit[bp.DynamicalSystem])) + self.assertTrue(not isinstance(a, bp.mixin.ParamDescriber[bp.dyn.Expon])) + self.assertTrue(not isinstance(a, bp.mixin.ParamDescriber[bp.DynamicalSystem])) def test2(self): a = bp.dyn.Expon.desc(1) - self.assertTrue(isinstance(a, bp.mixin.ParamDescInit[bp.dyn.Expon])) - self.assertTrue(isinstance(a, bp.mixin.ParamDescInit[bp.DynamicalSystem])) + self.assertTrue(isinstance(a, bp.mixin.ParamDescriber[bp.dyn.Expon])) + self.assertTrue(isinstance(a, bp.mixin.ParamDescriber[bp.DynamicalSystem])) class TestJointType(unittest.TestCase): @@ -26,8 +26,8 @@ def test1(self): def test2(self): T = bp.mixin.JointType[bp.DynamicalSystem, bp.mixin.ParamDesc] - self.assertTrue(not isinstance(bp.dyn.Expon(1), bp.mixin.ParamDescInit[T])) - self.assertTrue(isinstance(bp.dyn.Expon.desc(1), bp.mixin.ParamDescInit[T])) + self.assertTrue(not isinstance(bp.dyn.Expon(1), bp.mixin.ParamDescriber[T])) + self.assertTrue(isinstance(bp.dyn.Expon.desc(1), bp.mixin.ParamDescriber[T])) class TestDelayRegister(unittest.TestCase): diff --git a/brainpy/mixin.py b/brainpy/mixin.py index 232fd744e..9b56befa9 100644 --- a/brainpy/mixin.py +++ b/brainpy/mixin.py @@ -3,7 +3,6 @@ MixIn as MixIn, AlignPost as AlignPost, ParamDesc as ParamDesc, - ParamDescInit as ParamDescInit, BindCondData as BindCondData, Container as Container, TreeNode as TreeNode, From 969848efd10ca4b30bd8bd97c619cf12c629e33f Mon Sep 17 00:00:00 2001 From: chaoming Date: Sat, 28 Oct 2023 18:43:00 +0800 Subject: [PATCH 5/7] fix bug --- brainpy/_src/math/object_transform/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/brainpy/_src/math/object_transform/base.py b/brainpy/_src/math/object_transform/base.py index f265093af..5ddbfad09 100644 --- a/brainpy/_src/math/object_transform/base.py +++ b/brainpy/_src/math/object_transform/base.py @@ -721,8 +721,8 @@ class NodeDict(dict): def __init__(self, *args, check_unique: bool = False, **kwargs): super().__init__() - self.update(*args, **kwargs) self.check_unique = check_unique + self.update(*args, **kwargs) def update(self, *args, **kwargs) -> 'NodeDict': for arg in args: From ae3c966f3c25b15d9de2cd467386689f1d9edc47 Mon Sep 17 00:00:00 2001 From: chaoming Date: Sat, 28 Oct 2023 19:17:17 +0800 Subject: [PATCH 6/7] fix bug --- brainpy/_src/dyn/projections/aligns.py | 6 ++-- brainpy/_src/dyn/projections/plasticity.py | 33 +++++++++------------- brainpy/mixin.py | 1 + 3 files changed, 18 insertions(+), 22 deletions(-) diff --git a/brainpy/_src/dyn/projections/aligns.py b/brainpy/_src/dyn/projections/aligns.py index d8c5a4d47..2616e928b 100644 --- a/brainpy/_src/dyn/projections/aligns.py +++ b/brainpy/_src/dyn/projections/aligns.py @@ -29,7 +29,7 @@ def add_inp_fun(out_label, proj_name, out, post): post.add_inp_fun(out_name, out) -def align_post_init_bef_update(out_label, syn_desc, out_desc, post, proj_name): +def align_post_add_bef_update(out_label, syn_desc, out_desc, post, proj_name): # synapse and output initialization _post_repr = get_post_repr(out_label, syn_desc, out_desc) if not post.has_bef_update(_post_repr): @@ -270,7 +270,7 @@ def __init__( self.comm = comm # synapse and output initialization - syn, out = align_post_init_bef_update(out_label, syn_desc=syn, out_desc=out, post=post, proj_name=self.name) + syn, out = align_post_add_bef_update(out_label, syn_desc=syn, out_desc=out, post=post, proj_name=self.name) # references self.refs = dict(post=post) # invisible to ``self.nodes()`` @@ -381,7 +381,7 @@ def __init__( delay_cls.register_entry(self.name, delay) # synapse and output initialization - syn, out = align_post_init_bef_update(out_label, syn_desc=syn, out_desc=out, post=post, proj_name=self.name) + syn, out = align_post_add_bef_update(out_label, syn_desc=syn, out_desc=out, post=post, proj_name=self.name) # references self.refs = dict(pre=pre, post=post) # invisible to ``self.nodes()`` diff --git a/brainpy/_src/dyn/projections/plasticity.py b/brainpy/_src/dyn/projections/plasticity.py index e06037273..29858f288 100644 --- a/brainpy/_src/dyn/projections/plasticity.py +++ b/brainpy/_src/dyn/projections/plasticity.py @@ -8,7 +8,7 @@ from brainpy._src.mixin import (JointType, ParamDescriber, SupportAutoDelay, BindCondData, AlignPost, SupportSTDP) from brainpy.types import ArrayType -from .aligns import (_get_return, align_post_init_bef_update, +from .aligns import (_get_return, align_post_add_bef_update, align_pre2_add_bef_update, add_inp_fun) __all__ = [ @@ -103,7 +103,7 @@ def run(i, I_pre, I_post): return pre_spike, post_spike, g, Apre, Apost, current, W indices = bm.arange(0, duration, bm.dt) - pre_spike, post_spike, g, Apre, Apost, current, W = bm.for_loop(run, [indices, I_pre, I_post], jit=True) + pre_spike, post_spike, g, Apre, Apost, current, W = bm.for_loop(run, [indices, I_pre, I_post]) Args: tau_s: float, ArrayType, Callable. The time constant of :math:`A_{pre}`. @@ -155,25 +155,20 @@ def __init__( delay_cls = register_delay_by_return(pre) delay_cls.register_entry(self.name, delay) + # synapse and output initialization if issubclass(syn.cls, AlignPost): - # synapse and output initialization - syn, out = align_post_init_bef_update(out_label, syn_desc=syn, out_desc=out, post=post, proj_name=self.name) - # references - self.refs = dict(pre=pre, post=post, out=out) # invisible to ``self.nodes()`` - self.refs['delay'] = delay_cls - self.refs['syn'] = syn # invisible to ``self.node()`` - self.refs['out'] = out # invisible to ``self.node()`` - + syn_cls, out_cls = align_post_add_bef_update(out_label, syn_desc=syn, out_desc=out, post=post, + proj_name=self.name) else: - # synapse initialization - syn = align_pre2_add_bef_update(syn, delay, delay_cls, self.name) - # output initialization - add_inp_fun(out_label, self.name, out(), post) - # references - self.refs = dict(pre=pre, post=post) # invisible to `self.nodes()` - self.refs['delay'] = delay_cls - self.refs['syn'] = syn - self.refs['out'] = out + syn_cls = align_pre2_add_bef_update(syn, delay, delay_cls, self.name) + out_cls = out() + add_inp_fun(out_label, self.name, out_cls, post) + + # references + self.refs = dict(pre=pre, post=post) # invisible to ``self.nodes()`` + self.refs['delay'] = delay_cls + self.refs['syn'] = syn_cls # invisible to ``self.node()`` + self.refs['out'] = out_cls # invisible to ``self.node()`` # tracing pre-synaptic spikes using Exponential model self.refs['pre_trace'] = _init_trace_by_align_pre2(pre, delay, Expon.desc(pre.num, tau=tau_s)) diff --git a/brainpy/mixin.py b/brainpy/mixin.py index 9b56befa9..3787e3cf5 100644 --- a/brainpy/mixin.py +++ b/brainpy/mixin.py @@ -3,6 +3,7 @@ MixIn as MixIn, AlignPost as AlignPost, ParamDesc as ParamDesc, + ParamDescriber as ParamDescriber, BindCondData as BindCondData, Container as Container, TreeNode as TreeNode, From 6e57e2be2023452f3da4a259cdfe6c0818005775 Mon Sep 17 00:00:00 2001 From: chaoming Date: Sat, 28 Oct 2023 19:59:55 +0800 Subject: [PATCH 7/7] fix bug --- brainpy/_src/delay.py | 5 +- brainpy/_src/dyn/projections/plasticity.py | 17 ++-- .../_src/dyn/projections/tests/test_aligns.py | 83 +++++++++++++------ 3 files changed, 71 insertions(+), 34 deletions(-) diff --git a/brainpy/_src/delay.py b/brainpy/_src/delay.py index 086a1ba87..cc1fb7204 100644 --- a/brainpy/_src/delay.py +++ b/brainpy/_src/delay.py @@ -467,11 +467,12 @@ def __init__( super().__init__(mode=delay.mode) self.refs = {'delay': delay} assert isinstance(delay, Delay) - delay.register_entry(delay_entry or self.name, time) + self._delay_entry = delay_entry or self.name + delay.register_entry(self._delay_entry, time) self.indices = indices def update(self): - return self.refs['delay'].at(self.name, *self.indices) + return self.refs['delay'].at(self._delay_entry, *self.indices) def reset_state(self, *args, **kwargs): pass diff --git a/brainpy/_src/dyn/projections/plasticity.py b/brainpy/_src/dyn/projections/plasticity.py index 29858f288..5894a1452 100644 --- a/brainpy/_src/dyn/projections/plasticity.py +++ b/brainpy/_src/dyn/projections/plasticity.py @@ -106,10 +106,11 @@ def run(i, I_pre, I_post): pre_spike, post_spike, g, Apre, Apost, current, W = bm.for_loop(run, [indices, I_pre, I_post]) Args: - tau_s: float, ArrayType, Callable. The time constant of :math:`A_{pre}`. - tau_t: float, ArrayType, Callable. The time constant of :math:`A_{post}`. - A1: float, ArrayType, Callable. The increment of :math:`A_{pre}` produced by a spike. - A2: float, ArrayType, Callable. The increment of :math:`A_{post}` produced by a spike. + tau_s: float. The time constant of :math:`A_{pre}`. + tau_t: float. The time constant of :math:`A_{post}`. + A1: float. The increment of :math:`A_{pre}` produced by a spike. Must be a positive value. + A2: float. The increment of :math:`A_{post}` produced by a spike. Must be a positive value. + W_max: float. The maximum weight. pre: DynamicalSystem. The pre-synaptic neuron group. delay: int, float. The pre spike delay length. (ms) syn: DynamicalSystem. The synapse model. @@ -133,6 +134,7 @@ def __init__( tau_t: Union[float, ArrayType, Callable] = 33.7, A1: Union[float, ArrayType, Callable] = 0.96, A2: Union[float, ArrayType, Callable] = 0.53, + W_max: Optional[float] = None, # others out_label: Optional[str] = None, name: Optional[str] = None, @@ -176,6 +178,7 @@ def __init__( self.refs['post_trace'] = _init_trace_by_align_pre2(post, None, Expon.desc(post.num, tau=tau_t)) # synapse parameters + self.W_max = W_max self.tau_s = parameter(tau_s, sizes=self.pre_num) self.tau_t = parameter(tau_t, sizes=self.post_num) self.A1 = parameter(A1, sizes=self.pre_num) @@ -201,7 +204,7 @@ def update(self): Apre = self.refs['pre_trace'].g Apost = self.refs['post_trace'].g delta_w = - bm.outer(pre_spike, Apost * self.A2) + bm.outer(Apre * self.A1, post_spike) - self.comm.update_STDP(delta_w) + self.comm.update_STDP(delta_w, constraints=self._weight_clip) # currents current = self.comm(x) @@ -210,3 +213,7 @@ def update(self): else: self.refs['out'].bind_cond(current) # align pre return current + + def _weight_clip(self, w): + return w if self.W_max is None else bm.minimum(w, self.W_max) + diff --git a/brainpy/_src/dyn/projections/tests/test_aligns.py b/brainpy/_src/dyn/projections/tests/test_aligns.py index 600d82c8e..32b072e5a 100644 --- a/brainpy/_src/dyn/projections/tests/test_aligns.py +++ b/brainpy/_src/dyn/projections/tests/test_aligns.py @@ -10,7 +10,7 @@ def test_ProjAlignPreMg1(): class EICOBA_PreAlign(bp.DynamicalSystem): - def __init__(self, scale=1., inp=20.): + def __init__(self, scale=1., inp=20., delay=None): super().__init__() self.inp = inp @@ -22,7 +22,7 @@ def __init__(self, scale=1., inp=20.): self.E2I = bp.dyn.ProjAlignPreMg1( pre=self.E, syn=bp.dyn.Expon.desc(self.E.varshape, tau=5.), - delay=None, + delay=delay, comm=bp.dnn.CSRLinear(bp.conn.FixedProb(prob, pre=self.E.num, post=self.I.num), 0.6), out=bp.dyn.COBA(E=0.), post=self.I, @@ -30,7 +30,7 @@ def __init__(self, scale=1., inp=20.): self.E2E = bp.dyn.ProjAlignPreMg1( pre=self.E, syn=bp.dyn.Expon.desc(self.E.varshape, tau=5.), - delay=None, + delay=delay, comm=bp.dnn.CSRLinear(bp.conn.FixedProb(prob, pre=self.E.num, post=self.E.num), 0.6), out=bp.dyn.COBA(E=0.), post=self.E, @@ -38,7 +38,7 @@ def __init__(self, scale=1., inp=20.): self.I2E = bp.dyn.ProjAlignPreMg1( pre=self.I, syn=bp.dyn.Expon.desc(self.I.varshape, tau=10.), - delay=None, + delay=delay, comm=bp.dnn.CSRLinear(bp.conn.FixedProb(prob, pre=self.I.num, post=self.E.num), 6.7), out=bp.dyn.COBA(E=-80.), post=self.E, @@ -46,7 +46,7 @@ def __init__(self, scale=1., inp=20.): self.I2I = bp.dyn.ProjAlignPreMg1( pre=self.I, syn=bp.dyn.Expon.desc(self.I.varshape, tau=10.), - delay=None, + delay=delay, comm=bp.dnn.CSRLinear(bp.conn.FixedProb(prob, pre=self.I.num, post=self.I.num), 6.7), out=bp.dyn.COBA(E=-80.), post=self.I, @@ -65,13 +65,19 @@ def update(self): indices = np.arange(400) spks = bm.for_loop(net.step_run, indices) bp.visualize.raster_plot(indices * bm.dt, spks, show=True) + + net = EICOBA_PreAlign(0.5, delay=1.) + indices = np.arange(400) + spks = bm.for_loop(net.step_run, indices) + bp.visualize.raster_plot(indices * bm.dt, spks, show=True) + plt.close() bm.clear_buffer_memory() def test_ProjAlignPostMg2(): class EICOBA_PostAlign(bp.DynamicalSystem): - def __init__(self, scale, inp=20., ltc=True): + def __init__(self, scale, inp=20., ltc=True, delay=None): super().__init__() self.inp = inp @@ -86,7 +92,7 @@ def __init__(self, scale, inp=20., ltc=True): self.E2E = bp.dyn.ProjAlignPostMg2( pre=self.E, - delay=None, + delay=delay, comm=bp.dnn.EventCSRLinear(bp.conn.FixedProb(prob, pre=self.E.num, post=self.E.num), 0.6), syn=bp.dyn.Expon.desc(self.E.varshape, tau=5.), out=bp.dyn.COBA.desc(E=0.), @@ -94,7 +100,7 @@ def __init__(self, scale, inp=20., ltc=True): ) self.E2I = bp.dyn.ProjAlignPostMg2( pre=self.E, - delay=None, + delay=delay, comm=bp.dnn.EventCSRLinear(bp.conn.FixedProb(prob, pre=self.E.num, post=self.I.num), 0.6), syn=bp.dyn.Expon.desc(self.I.varshape, tau=5.), out=bp.dyn.COBA.desc(E=0.), @@ -102,7 +108,7 @@ def __init__(self, scale, inp=20., ltc=True): ) self.I2E = bp.dyn.ProjAlignPostMg2( pre=self.I, - delay=None, + delay=delay, comm=bp.dnn.EventCSRLinear(bp.conn.FixedProb(prob, pre=self.I.num, post=self.E.num), 6.7), syn=bp.dyn.Expon.desc(self.E.varshape, tau=10.), out=bp.dyn.COBA.desc(E=-80.), @@ -110,7 +116,7 @@ def __init__(self, scale, inp=20., ltc=True): ) self.I2I = bp.dyn.ProjAlignPostMg2( pre=self.I, - delay=None, + delay=delay, comm=bp.dnn.EventCSRLinear(bp.conn.FixedProb(prob, pre=self.I.num, post=self.I.num), 6.7), syn=bp.dyn.Expon.desc(self.I.varshape, tau=10.), out=bp.dyn.COBA.desc(E=-80.), @@ -131,6 +137,11 @@ def update(self): spks = bm.for_loop(net.step_run, indices) bp.visualize.raster_plot(indices * bm.dt, spks, show=True) + net = EICOBA_PostAlign(0.5, delay=1.) + indices = np.arange(400) + spks = bm.for_loop(net.step_run, indices) + bp.visualize.raster_plot(indices * bm.dt, spks, show=True) + net = EICOBA_PostAlign(0.5, ltc=False) indices = np.arange(400) spks = bm.for_loop(net.step_run, indices) @@ -178,7 +189,7 @@ def update(self, input): def test_ProjAlignPost2(): class EINet(bp.DynSysGroup): - def __init__(self, scale): + def __init__(self, scale, delay=None): super().__init__() ne, ni = int(3200 * scale), int(800 * scale) p = 80 / (ne + ni) @@ -188,25 +199,25 @@ def __init__(self, scale): self.I = bp.dyn.LifRefLTC(ni, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., V_initializer=bp.init.Normal(-55., 2.)) self.E2E = bp.dyn.ProjAlignPost2(pre=self.E, - delay=0.1, + delay=delay, comm=bp.dnn.EventJitFPHomoLinear(ne, ne, prob=p, weight=0.6), syn=bp.dyn.Expon(size=ne, tau=5.), out=bp.dyn.COBA(E=0.), post=self.E) self.E2I = bp.dyn.ProjAlignPost2(pre=self.E, - delay=0.1, + delay=delay, comm=bp.dnn.EventJitFPHomoLinear(ne, ni, prob=p, weight=0.6), syn=bp.dyn.Expon(size=ni, tau=5.), out=bp.dyn.COBA(E=0.), post=self.I) self.I2E = bp.dyn.ProjAlignPost2(pre=self.I, - delay=0.1, + delay=delay, comm=bp.dnn.EventJitFPHomoLinear(ni, ne, prob=p, weight=6.7), syn=bp.dyn.Expon(size=ne, tau=10.), out=bp.dyn.COBA(E=-80.), post=self.E) self.I2I = bp.dyn.ProjAlignPost2(pre=self.I, - delay=0.1, + delay=delay, comm=bp.dnn.EventJitFPHomoLinear(ni, ni, prob=p, weight=6.7), syn=bp.dyn.Expon(size=ni, tau=10.), out=bp.dyn.COBA(E=-80.), @@ -221,10 +232,16 @@ def update(self, inp): self.I(inp) return self.E.spike - model = EINet(0.5) + model = EINet(0.5, delay=1.) + indices = bm.arange(400) + spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices) + bp.visualize.raster_plot(indices, spks, show=True) + + model = EINet(0.5, delay=None) indices = bm.arange(400) spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices) bp.visualize.raster_plot(indices, spks, show=True) + bm.clear_buffer_memory() plt.close() @@ -267,7 +284,7 @@ def update(self, input): def test_ProjAlignPreMg1_v2(): class EINet(bp.DynSysGroup): - def __init__(self, scale=1.): + def __init__(self, scale=1., delay=None): super().__init__() ne, ni = int(3200 * scale), int(800 * scale) p = 80 / (4000 * scale) @@ -277,25 +294,25 @@ def __init__(self, scale=1.): V_initializer=bp.init.Normal(-55., 2.)) self.E2E = bp.dyn.ProjAlignPreMg1(pre=self.E, syn=bp.dyn.Expon.desc(size=ne, tau=5.), - delay=0.1, + delay=delay, comm=bp.dnn.JitFPHomoLinear(ne, ne, prob=p, weight=0.6), out=bp.dyn.COBA(E=0.), post=self.E) self.E2I = bp.dyn.ProjAlignPreMg1(pre=self.E, syn=bp.dyn.Expon.desc(size=ne, tau=5.), - delay=0.1, + delay=delay, comm=bp.dnn.JitFPHomoLinear(ne, ni, prob=p, weight=0.6), out=bp.dyn.COBA(E=0.), post=self.I) self.I2E = bp.dyn.ProjAlignPreMg1(pre=self.I, syn=bp.dyn.Expon.desc(size=ni, tau=10.), - delay=0.1, + delay=delay, comm=bp.dnn.JitFPHomoLinear(ni, ne, prob=p, weight=6.7), out=bp.dyn.COBA(E=-80.), post=self.E) self.I2I = bp.dyn.ProjAlignPreMg1(pre=self.I, syn=bp.dyn.Expon.desc(size=ni, tau=10.), - delay=0.1, + delay=delay, comm=bp.dnn.JitFPHomoLinear(ni, ni, prob=p, weight=6.7), out=bp.dyn.COBA(E=-80.), post=self.I) @@ -313,13 +330,19 @@ def update(self, inp): indices = bm.arange(400) spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices) bp.visualize.raster_plot(indices, spks, show=True) + + model = EINet(delay=1.) + indices = bm.arange(400) + spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices) + bp.visualize.raster_plot(indices, spks, show=True) + bm.clear_buffer_memory() plt.close() def test_ProjAlignPreMg2(): class EINet(bp.DynSysGroup): - def __init__(self, scale=1.): + def __init__(self, scale=1., delay=None): super().__init__() ne, ni = int(3200 * scale), int(800 * scale) p = 80 / (4000 * scale) @@ -328,25 +351,25 @@ def __init__(self, scale=1.): self.I = bp.dyn.LifRefLTC(ni, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., V_initializer=bp.init.Normal(-55., 2.)) self.E2E = bp.dyn.ProjAlignPreMg2(pre=self.E, - delay=0.1, + delay=delay, syn=bp.dyn.Expon.desc(size=ne, tau=5.), comm=bp.dnn.JitFPHomoLinear(ne, ne, prob=p, weight=0.6), out=bp.dyn.COBA(E=0.), post=self.E) self.E2I = bp.dyn.ProjAlignPreMg2(pre=self.E, - delay=0.1, + delay=delay, syn=bp.dyn.Expon.desc(size=ne, tau=5.), comm=bp.dnn.JitFPHomoLinear(ne, ni, prob=p, weight=0.6), out=bp.dyn.COBA(E=0.), post=self.I) self.I2E = bp.dyn.ProjAlignPreMg2(pre=self.I, - delay=0.1, + delay=delay, syn=bp.dyn.Expon.desc(size=ni, tau=10.), comm=bp.dnn.JitFPHomoLinear(ni, ne, prob=p, weight=6.7), out=bp.dyn.COBA(E=-80.), post=self.E) self.I2I = bp.dyn.ProjAlignPreMg2(pre=self.I, - delay=0.1, + delay=delay, syn=bp.dyn.Expon.desc(size=ni, tau=10.), comm=bp.dnn.JitFPHomoLinear(ni, ni, prob=p, weight=6.7), out=bp.dyn.COBA(E=-80.), @@ -361,10 +384,16 @@ def update(self, inp): self.I(inp) return self.E.spike - model = EINet() + model = EINet(scale=0.2, delay=None) indices = bm.arange(400) spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices) bp.visualize.raster_plot(indices, spks, show=True) + + model = EINet(scale=0.2, delay=1.) + indices = bm.arange(400) + spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices) + bp.visualize.raster_plot(indices, spks, show=True) + bm.clear_buffer_memory() plt.close()