diff --git a/brainpy/dyn/base.py b/brainpy/dyn/base.py index 57df3a401..ce13f0d45 100644 --- a/brainpy/dyn/base.py +++ b/brainpy/dyn/base.py @@ -483,9 +483,10 @@ def register_delay( def get_delay( self, name: str, - delay_step: Union[int, bm.JaxArray, bm.ndarray] + delay_step: Union[int, bm.JaxArray, bm.ndarray], + indices=None, ): - """Get delay data according to the delay times. + """Get delay data according to the provided delay steps. Parameters ---------- @@ -493,6 +494,8 @@ def get_delay( The delay variable name. delay_step: int, JaxArray, ndarray The delay length. + indices: optional, JaxArray, ndarray + The indices of the delay. Returns ------- @@ -501,14 +504,18 @@ def get_delay( """ if name in self.global_delay_vars: if isinstance(delay_step, int): - return self.global_delay_vars[name](delay_step) + return self.global_delay_vars[name](delay_step, indices) else: - return self.global_delay_vars[name](delay_step, jnp.arange(delay_step.size)) + if indices is None: + indices = jnp.arange(delay_step.size) + return self.global_delay_vars[name](delay_step, indices) elif name in self.local_delay_vars: if isinstance(delay_step, int): return self.local_delay_vars[name](delay_step) else: - return self.local_delay_vars[name](delay_step, jnp.arange(delay_step.size)) + if indices is None: + indices = jnp.arange(delay_step.size) + return self.local_delay_vars[name](delay_step, indices) else: raise ValueError(f'{name} is not defined in delay variables.') diff --git a/brainpy/dyn/synapses/abstract_models.py b/brainpy/dyn/synapses/abstract_models.py index 3750eedd9..6a7c59dc3 100644 --- a/brainpy/dyn/synapses/abstract_models.py +++ b/brainpy/dyn/synapses/abstract_models.py @@ -5,10 +5,9 @@ import brainpy.math as bm from brainpy.connect import TwoEndConnector, All2All, One2One from brainpy.dyn.base import NeuGroup, TwoEndConn -from brainpy.dyn.utils import init_delay from brainpy.initialize import Initializer, init_param from brainpy.integrators import odeint, JointEq -from brainpy.types import Tensor, Parameter +from brainpy.types import Tensor __all__ = [ 'DeltaSynapse', @@ -246,27 +245,30 @@ class ExpCUBA(TwoEndConn): >>> plt.legend() >>> plt.show() - - **Model Parameters** - - ============= ============== ======== =================================================================================== - **Parameter** **Init Value** **Unit** **Explanation** - ------------- -------------- -------- ----------------------------------------------------------------------------------- - delay 0 ms The decay length of the pre-synaptic spikes. - tau_decay 8 ms The time constant of decay. - g_max 1 µmho(µS) The maximum conductance. - ============= ============== ======== =================================================================================== - - **Model Variables** - - ================ ================== ========================================================= - **Member name** **Initial values** **Explanation** - ---------------- ------------------ --------------------------------------------------------- - g 0 Gating variable. - pre_spike False The history spiking states of the pre-synaptic neurons. - ================ ================== ========================================================= - - **References** + Parameters + ---------- + pre: NeuGroup + The pre-synaptic neuron group. + post: NeuGroup + The post-synaptic neuron group. + conn: optional, ndarray, JaxArray, dict of (str, ndarray), TwoEndConnector + The synaptic connections. + conn_type: str + The connection type used for model speed optimization. It can be + `sparse` and `dense`. The default is `sparse`. + delay_step: int, ndarray, JaxArray, Initializer, Callable + The delay length. It should be the value of :math:`\mathrm{delay\_time / dt}`. + tau: float + The time constant of decay. [ms] + g_max: float, ndarray, JaxArray, Initializer, Callable + The synaptic strength (the maximum conductance). Default is 1. + name: str + The name of this synaptic projection. + method: str + The numerical integration methods. + + References + ---------- .. [1] Sterratt, David, Bruce Graham, Andrew Gillies, and David Willshaw. "The Synapse." Principles of Computational Modelling in Neuroscience. @@ -280,10 +282,10 @@ def __init__( conn: Union[TwoEndConnector, Tensor, Dict[str, Tensor]], conn_type: str = 'sparse', g_max: Union[float, Tensor, Initializer, Callable] = 1., - tau: Union[float, Tensor, Initializer, Callable] = 8.0, - method: str = 'exp_auto', + tau: float = 8.0, delay_step: Union[int, Tensor, Initializer, Callable] = None, - name: str = None + name: str = None, + method: str = 'exp_auto', ): super(ExpCUBA, self).__init__(pre=pre, post=post, conn=conn, name=name) self.check_pre_attrs('spike') @@ -291,19 +293,33 @@ def __init__( # parameters self.tau = tau - self.g_max = g_max - # connection + # connections and weights self.conn_type = conn_type if conn_type not in ['sparse', 'dense']: raise ValueError(f'"conn_type" must be in "sparse" and "dense", but we got {conn_type}') if self.conn is None: raise ValueError(f'Must provide "conn" when initialize the model {self.name}') - if not isinstance(self.conn, (All2All, One2One)): + if isinstance(self.conn, One2One): + self.g_max = init_param(g_max, (self.pre.num,), allow_none=False) + self.weight_type = 'heter' if bm.size(self.g_max) != 1 else 'homo' + elif isinstance(self.conn, All2All): + self.g_max = init_param(g_max, (self.pre.num, self.post.num), allow_none=False) + if bm.size(self.g_max) != 1: + self.weight_type = 'heter' + bm.fill_diagonal(self.g_max, 0.) + else: + self.weight_type = 'homo' + else: if conn_type == 'sparse': self.pre2post = self.conn.require('pre2post') + self.g_max = init_param(g_max, self.pre2post[1].shape, allow_none=False) + self.weight_type = 'heter' if bm.size(self.g_max) != 1 else 'homo' elif conn_type == 'dense': - self.conn_mat = self.conn.require('conn_mat') + self.g_max = init_param(g_max, (self.pre.num, self.post.num), allow_none=False) + self.weight_type = 'heter' if bm.size(self.g_max) != 1 else 'homo' + if self.weight_type == 'homo': + self.conn_mat = self.conn.require('conn_mat') else: raise ValueError(f'Unknown connection type: {conn_type}') @@ -317,31 +333,43 @@ def __init__( def update(self, _t, _dt): # delays if self.delay_step is None: - delayed_pre_spike = self.pre.spike + pre_spike = self.pre.spike else: - delayed_pre_spike = self.get_delay(self.pre.name + '.spike', self.delay_step) + pre_spike = self.get_delay(self.pre.name + '.spike', self.delay_step) self.update_delay(self.pre.name + '.spike', self.pre.spike) # post values + assert self.weight_type in ['homo', 'heter'] + assert self.conn_type in ['sparse', 'dense'] if isinstance(self.conn, All2All): - post_vs = bm.sum(delayed_pre_spike) - if not self.conn.include_self: - post_vs = post_vs - delayed_pre_spike - post_vs *= self.g_max + if self.weight_type == 'homo': + post_vs = bm.sum(pre_spike) + if not self.conn.include_self: + post_vs = post_vs - pre_spike + post_vs = self.g_max * post_vs + else: + post_vs = bm.expand_dims(pre_spike, 1) * self.g_max + post_vs = post_vs.sum(axis=0) elif isinstance(self.conn, One2One): - post_vs = delayed_pre_spike * self.g_max + post_vs = pre_spike * self.g_max else: if self.conn_type == 'sparse': - post_vs = bm.pre2post_event_sum(delayed_pre_spike, + post_vs = bm.pre2post_event_sum(pre_spike, self.pre2post, self.post.num, self.g_max) else: - post_vs = delayed_pre_spike @ self.conn_mat + if self.weight_type == 'homo': + post_vs = self.g_max * (pre_spike @ self.conn_mat) + else: + post_vs = pre_spike @ self.g_max # updates self.g.value = self.integral(self.g.value, _t, dt=_dt) + post_vs - self.post.input += self.g + self.post.input += self.output(self.g) + + def output(self, g_post): + return g_post class ExpCOBA(ExpCUBA): @@ -391,27 +419,32 @@ class ExpCOBA(ExpCUBA): >>> plt.legend() >>> plt.show() - **Model Parameters** - - ============= ============== ======== =================================================================================== - **Parameter** **Init Value** **Unit** **Explanation** - ------------- -------------- -------- ----------------------------------------------------------------------------------- - delay 0 ms The decay length of the pre-synaptic spikes. - tau_decay 8 ms The time constant of decay. - g_max 1 µmho(µS) The maximum conductance. - E 0 mV The reversal potential for the synaptic current. - ============= ============== ======== =================================================================================== - - **Model Variables** - - ================ ================== ========================================================= - **Member name** **Initial values** **Explanation** - ---------------- ------------------ --------------------------------------------------------- - g 0 Gating variable. - pre_spike False The history spiking states of the pre-synaptic neurons. - ================ ================== ========================================================= - - **References** + Parameters + ---------- + pre: NeuGroup + The pre-synaptic neuron group. + post: NeuGroup + The post-synaptic neuron group. + conn: optional, ndarray, JaxArray, dict of (str, ndarray), TwoEndConnector + The synaptic connections. + conn_type: str + The connection type used for model speed optimization. It can be + `sparse` and `dense`. The default is `sparse`. + delay_step: int, ndarray, JaxArray, Initializer, Callable + The delay length. It should be the value of :math:`\mathrm{delay\_time / dt}`. + E: float + The reversal potential for the synaptic current. [mV] + tau: float + The time constant of decay. [ms] + g_max: float, ndarray, JaxArray, Initializer, Callable + The synaptic strength (the maximum conductance). Default is 1. + name: str + The name of this synaptic projection. + method: str + The numerical integration methods. + + References + ---------- .. [1] Sterratt, David, Bruce Graham, Andrew Gillies, and David Willshaw. "The Synapse." Principles of Computational Modelling in Neuroscience. @@ -444,34 +477,8 @@ def __init__( # parameter self.E = E - def update(self, _t, _dt): - # delays - if self.delay_step is None: - delayed_spike = self.pre.spike - else: - delayed_spike = self.get_delay(self.pre.name + '.spike', self.delay_step) - self.update_delay(self.pre.name + '.spike', self.pre.spike) - - # post values - if isinstance(self.conn, All2All): - post_vs = bm.sum(delayed_spike) - if not self.conn.include_self: - post_vs = post_vs - delayed_spike - post_vs *= self.g_max - elif isinstance(self.conn, One2One): - post_vs = delayed_spike * self.g_max - else: - if self.conn_type == 'sparse': - post_vs = bm.pre2post_event_sum(delayed_spike, - self.pre2post, - self.post.num, - self.g_max) - else: - post_vs = delayed_spike @ self.conn_mat - - # updates - self.g.value = self.integral(self.g, _t, dt=_dt) + post_vs - self.post.input += self.g * (self.E - self.post.V) + def output(self, g_post): + return g_post * (self.E - self.post.V) class DualExpCUBA(TwoEndConn): @@ -539,30 +546,32 @@ class DualExpCUBA(TwoEndConn): >>> plt.legend() >>> plt.show() - - **Model Parameters** - - ============= ============== ======== =================================================================================== - **Parameter** **Init Value** **Unit** **Explanation** - ------------- -------------- -------- ----------------------------------------------------------------------------------- - delay 0 ms The decay length of the pre-synaptic spikes. - tau_decay 10 ms The time constant of the synaptic decay phase. - tau_rise 1 ms The time constant of the synaptic rise phase. - g_max 1 µmho(µS) The maximum conductance. - ============= ============== ======== =================================================================================== - - - **Model Variables** - - ================ ================== ========================================================= - **Member name** **Initial values** **Explanation** - ---------------- ------------------ --------------------------------------------------------- - g 0 Synapse conductance on the post-synaptic neuron. - s 0 Gating variable. - pre_spike False The history spiking states of the pre-synaptic neurons. - ================ ================== ========================================================= - - **References** + Parameters + ---------- + pre: NeuGroup + The pre-synaptic neuron group. + post: NeuGroup + The post-synaptic neuron group. + conn: optional, ndarray, JaxArray, dict of (str, ndarray), TwoEndConnector + The synaptic connections. + conn_type: str + The connection type used for model speed optimization. It can be + `sparse` and `dense`. The default is `sparse`. + delay_step: int, ndarray, JaxArray, Initializer, Callable + The delay length. It should be the value of :math:`\mathrm{delay\_time / dt}`. + tau_decay: float + The time constant of the synaptic decay phase. [ms] + tau_rise: float + The time constant of the synaptic rise phase. [ms] + g_max: float, ndarray, JaxArray, Initializer, Callable + The synaptic strength (the maximum conductance). Default is 1. + name: str + The name of this synaptic projection. + method: str + The numerical integration methods. + + References + ---------- .. [1] Sterratt, David, Bruce Graham, Andrew Gillies, and David Willshaw. "The Synapse." Principles of Computational Modelling in Neuroscience. @@ -577,9 +586,10 @@ def __init__( pre: NeuGroup, post: NeuGroup, conn: Union[TwoEndConnector, Tensor, Dict[str, Tensor]], - g_max: Parameter = 1., - tau_decay: Parameter = 10.0, - tau_rise: Parameter = 1., + conn_type: str = 'dense', + g_max: Union[float, Tensor, Initializer, Callable] = 1., + tau_decay: float = 10.0, + tau_rise: float = 1., delay_step: Union[int, Tensor, Initializer, Callable] = None, method: str = 'exp_auto', name: str = None @@ -591,11 +601,35 @@ def __init__( # parameters self.tau_rise = tau_rise self.tau_decay = tau_decay - self.g_max = g_max # connections - if not isinstance(self.conn, (One2One, All2All)): - self.conn_mat = self.conn.require('conn_mat') + self.conn_type = conn_type + if conn_type not in ['sparse', 'dense']: + raise ValueError(f'"conn_type" must be in "sparse" and "dense", but we got {conn_type}') + if self.conn is None: + raise ValueError(f'Must provide "conn" when initialize the model {self.name}') + if isinstance(self.conn, One2One): + self.g_max = init_param(g_max, (self.pre.num,), allow_none=False) + self.weight_type = 'heter' if bm.size(self.g_max) != 1 else 'homo' + elif isinstance(self.conn, All2All): + self.g_max = init_param(g_max, (self.pre.num, self.post.num), allow_none=False) + if bm.size(self.g_max) != 1: + self.weight_type = 'heter' + bm.fill_diagonal(self.g_max, 0.) + else: + self.weight_type = 'homo' + else: + if conn_type == 'sparse': + self.pre_ids, self.post_ids = self.conn.require('pre_ids', 'post_ids') + self.g_max = init_param(g_max, self.post_ids.shape, allow_none=False) + self.weight_type = 'heter' if bm.size(self.g_max) != 1 else 'homo' + elif conn_type == 'dense': + self.g_max = init_param(g_max, (self.pre.num, self.post.num), allow_none=False) + self.weight_type = 'heter' if bm.size(self.g_max) != 1 else 'homo' + if self.weight_type == 'homo': + self.conn_mat = self.conn.require('conn_mat') + else: + raise ValueError(f'Unknown connection type: {conn_type}') # variables self.h = bm.Variable(bm.zeros(self.pre.num, dtype=bm.float_)) @@ -603,7 +637,7 @@ def __init__( self.delay_step = self.register_delay(self.pre.name + '.spike', delay_step, self.pre.spike) # integral - self.integral = odeint(method=method, f=self.derivative) + self.integral = odeint(method=method, f=JointEq([self.dg, self.dh])) def dh(self, h, t): return -h / self.tau_rise @@ -611,30 +645,40 @@ def dh(self, h, t): def dg(self, g, t, h): return -g / self.tau_decay + h - @property - def derivative(self): - return JointEq([self.dg, self.dh]) - def update(self, _t, _dt): # delays if self.delay_step is None: - delayed_pre_spike = self.pre.spike + pre_spike = self.pre.spike else: - delayed_pre_spike = self.get_delay(self.pre.name + '.spike', self.delay_step) + pre_spike = self.get_delay(self.pre.name + '.spike', self.delay_step) self.update_delay(self.pre.name + '.spike', self.pre.spike) - # post-synaptic values + # update synaptic variables self.g.value, self.h.value = self.integral(self.g, self.h, _t, _dt) - self.h += delayed_pre_spike + self.h += pre_spike + + # post-synaptic values + assert self.weight_type in ['homo', 'heter'] + assert self.conn_type in ['sparse', 'dense'] if isinstance(self.conn, All2All): - post_vs = self.g.sum() - if not self.conn.include_self: - post_vs = post_vs - self.g - post_vs = self.g_max * post_vs + if self.weight_type == 'homo': + post_vs = bm.sum(self.g) + if not self.conn.include_self: + post_vs = post_vs - self.g + post_vs = self.g_max * post_vs + else: + post_vs = bm.expand_dims(self.g, 1) * self.g_max + post_vs = post_vs.sum(axis=0) elif isinstance(self.conn, One2One): post_vs = self.g_max * self.g else: - post_vs = self.g_max * self.g @ self.conn_mat + if self.conn_type == 'sparse': + post_vs = bm.pre2post_sum(self.g, self.post.num, self.post_ids, self.pre_ids) + else: + if self.weight_type == 'homo': + post_vs = (self.g_max * self.g) @ self.conn_mat + else: + post_vs = self.g @ self.g_max # output self.post.input += self.output(post_vs) @@ -687,31 +731,34 @@ class DualExpCOBA(DualExpCUBA): >>> plt.legend() >>> plt.show() - - **Model Parameters** - - ============= ============== ======== =================================================================================== - **Parameter** **Init Value** **Unit** **Explanation** - ------------- -------------- -------- ----------------------------------------------------------------------------------- - delay 0 ms The decay length of the pre-synaptic spikes. - tau_decay 10 ms The time constant of the synaptic decay phase. - tau_rise 1 ms The time constant of the synaptic rise phase. - g_max 1 µmho(µS) The maximum conductance. - E 0 mV The reversal potential for the synaptic current. - ============= ============== ======== =================================================================================== - - - **Model Variables** - - ================ ================== ========================================================= - **Member name** **Initial values** **Explanation** - ---------------- ------------------ --------------------------------------------------------- - g 0 Synapse conductance on the post-synaptic neuron. - s 0 Gating variable. - pre_spike False The history spiking states of the pre-synaptic neurons. - ================ ================== ========================================================= - - **References** + Parameters + ---------- + pre: NeuGroup + The pre-synaptic neuron group. + post: NeuGroup + The post-synaptic neuron group. + conn: optional, ndarray, JaxArray, dict of (str, ndarray), TwoEndConnector + The synaptic connections. + conn_type: str + The connection type used for model speed optimization. It can be + `sparse` and `dense`. The default is `sparse`. + delay_step: int, ndarray, JaxArray, Initializer, Callable + The delay length. It should be the value of :math:`\mathrm{delay\_time / dt}`. + E: float + The reversal potential for the synaptic current. [mV] + tau_decay: float + The time constant of the synaptic decay phase. [ms] + tau_rise: float + The time constant of the synaptic rise phase. [ms] + g_max: float, ndarray, JaxArray, Initializer, Callable + The synaptic strength (the maximum conductance). Default is 1. + name: str + The name of this synaptic projection. + method: str + The numerical integration methods. + + References + ---------- .. [1] Sterratt, David, Bruce Graham, Andrew Gillies, and David Willshaw. "The Synapse." Principles of Computational Modelling in Neuroscience. @@ -724,15 +771,16 @@ def __init__( pre: NeuGroup, post: NeuGroup, conn: Union[TwoEndConnector, Tensor, Dict[str, Tensor]], - g_max: Parameter = 1., - tau_decay: Parameter = 10.0, - tau_rise: Parameter = 1., - E: Parameter = 0., + conn_type: str = 'dense', + g_max: Union[float, Tensor, Initializer, Callable] = 1., + tau_decay: float = 10.0, + tau_rise: float = 1., + E: float = 0., delay_step: Union[int, Tensor, Initializer, Callable] = None, method: str = 'exp_auto', name: str = None ): - super(DualExpCOBA, self).__init__(pre, post, conn, + super(DualExpCOBA, self).__init__(pre, post, conn, conn_type=conn_type, delay_step=delay_step, g_max=g_max, tau_decay=tau_decay, tau_rise=tau_rise, method=method, name=name) @@ -801,27 +849,30 @@ class AlphaCUBA(DualExpCUBA): >>> plt.legend() >>> plt.show() - **Model Parameters** - - ============= ============== ======== =================================================================================== - **Parameter** **Init Value** **Unit** **Explanation** - ------------- -------------- -------- ----------------------------------------------------------------------------------- - delay 0 ms The decay length of the pre-synaptic spikes. - tau_decay 2 ms The decay time constant of the synaptic state. - g_max .2 µmho(µS) The maximum conductance. - ============= ============== ======== =================================================================================== - - **Model Variables** - - ================== ================= ========================================================= - **Variables name** **Initial Value** **Explanation** - ------------------ ----------------- --------------------------------------------------------- - g 0 Synapse conductance on the post-synaptic neuron. - h 0 Gating variable. - pre_spike False The history spiking states of the pre-synaptic neurons. - ================== ================= ========================================================= - - **References** + Parameters + ---------- + pre: NeuGroup + The pre-synaptic neuron group. + post: NeuGroup + The post-synaptic neuron group. + conn: optional, ndarray, JaxArray, dict of (str, ndarray), TwoEndConnector + The synaptic connections. + conn_type: str + The connection type used for model speed optimization. It can be + `sparse` and `dense`. The default is `sparse`. + delay_step: int, ndarray, JaxArray, Initializer, Callable + The delay length. It should be the value of :math:`\mathrm{delay\_time / dt}`. + tau_decay: float + The time constant of the synaptic decay phase. [ms] + g_max: float, ndarray, JaxArray, Initializer, Callable + The synaptic strength (the maximum conductance). Default is 1. + name: str + The name of this synaptic projection. + method: str + The numerical integration methods. + + References + ---------- .. [1] Sterratt, David, Bruce Graham, Andrew Gillies, and David Willshaw. "The Synapse." Principles of Computational Modelling in Neuroscience. @@ -833,13 +884,15 @@ def __init__( pre: NeuGroup, post: NeuGroup, conn: Union[TwoEndConnector, Tensor, Dict[str, Tensor]], - g_max: Parameter = 1., - tau_decay: Parameter = 10.0, + conn_type: str = 'dense', + g_max: Union[float, Tensor, Initializer, Callable] = 1., + tau_decay: float = 10.0, delay_step: Union[int, Tensor, Initializer, Callable] = None, method: str = 'exp_auto', name: str = None ): super(AlphaCUBA, self).__init__(pre=pre, post=post, conn=conn, + conn_type=conn_type, delay_step=delay_step, g_max=g_max, tau_decay=tau_decay, @@ -892,30 +945,32 @@ class AlphaCOBA(DualExpCOBA): >>> plt.legend() >>> plt.show() - - **Model Parameters** - - ============= ============== ======== =================================================================================== - **Parameter** **Init Value** **Unit** **Explanation** - ------------- -------------- -------- ----------------------------------------------------------------------------------- - delay 0 ms The decay length of the pre-synaptic spikes. - tau_decay 2 ms The decay time constant of the synaptic state. - g_max .2 µmho(µS) The maximum conductance. - E 0 mV The reversal potential for the synaptic current. - ============= ============== ======== =================================================================================== - - - **Model Variables** - - ================== ================= ========================================================= - **Variables name** **Initial Value** **Explanation** - ------------------ ----------------- --------------------------------------------------------- - g 0 Synapse conductance on the post-synaptic neuron. - h 0 Gating variable. - pre_spike False The history spiking states of the pre-synaptic neurons. - ================== ================= ========================================================= - - **References** + Parameters + ---------- + pre: NeuGroup + The pre-synaptic neuron group. + post: NeuGroup + The post-synaptic neuron group. + conn: optional, ndarray, JaxArray, dict of (str, ndarray), TwoEndConnector + The synaptic connections. + conn_type: str + The connection type used for model speed optimization. It can be + `sparse` and `dense`. The default is `dense`. + delay_step: int, ndarray, JaxArray, Initializer, Callable + The delay length. It should be the value of :math:`\mathrm{delay\_time / dt}`. + E: float + The reversal potential for the synaptic current. [mV] + tau_decay: float + The time constant of the synaptic decay phase. [ms] + g_max: float, ndarray, JaxArray, Initializer, Callable + The synaptic strength (the maximum conductance). Default is 1. + name: str + The name of this synaptic projection. + method: str + The numerical integration methods. + + References + ---------- .. [1] Sterratt, David, Bruce Graham, Andrew Gillies, and David Willshaw. "The Synapse." Principles of Computational Modelling in Neuroscience. @@ -928,14 +983,16 @@ def __init__( pre: NeuGroup, post: NeuGroup, conn: Union[TwoEndConnector, Tensor, Dict[str, Tensor]], - g_max: Parameter = 1., - tau_decay: Parameter = 10.0, - E: Parameter = 0., + conn_type: str = 'dense', + g_max: Union[float, Tensor, Callable, Initializer] = 1., + tau_decay: float = 10.0, + E: float = 0., delay_step: Union[int, Tensor, Initializer, Callable] = None, method: str = 'exp_auto', name: str = None ): super(AlphaCOBA, self).__init__(pre=pre, post=post, conn=conn, + conn_type=conn_type, delay_step=delay_step, g_max=g_max, E=E, tau_decay=tau_decay, @@ -1028,34 +1085,42 @@ class NMDA(TwoEndConn): >>> plt.legend() >>> plt.show() - **Model Parameters** - - ============= ============== =============== ================================================ - **Parameter** **Init Value** **Unit** **Explanation** - ------------- -------------- --------------- ------------------------------------------------ - delay 0 ms The decay length of the pre-synaptic spikes. - g_max .15 µmho(µS) The synaptic maximum conductance. - E 0 mV The reversal potential for the synaptic current. - alpha .062 \ Binding constant. - beta 3.57 \ Unbinding constant. - cc_Mg 1.2 mM Concentration of Magnesium ion. - tau_decay 100 ms The time constant of the synaptic decay phase. - tau_rise 2 ms The time constant of the synaptic rise phase. - a .5 1/ms - ============= ============== =============== ================================================ - - - **Model Variables** - - =============== ================== ========================================================= - **Member name** **Initial values** **Explanation** - --------------- ------------------ --------------------------------------------------------- - g 0 Synaptic conductance. - x 0 Synaptic gating variable. - pre_spike False The history spiking states of the pre-synaptic neurons. - =============== ================== ========================================================= - - **References** + Parameters + ---------- + pre: NeuGroup + The pre-synaptic neuron group. + post: NeuGroup + The post-synaptic neuron group. + conn: optional, ndarray, JaxArray, dict of (str, ndarray), TwoEndConnector + The synaptic connections. + conn_type: str + The connection type used for model speed optimization. It can be + `sparse` and `dense`. The default is `dense`. + delay_step: int, ndarray, JaxArray, Initializer, Callable + The delay length. It should be the value of :math:`\mathrm{delay\_time / dt}`. + E: float + The reversal potential for the synaptic current. [mV] + g_max: float, ndarray, JaxArray, Initializer, Callable + The synaptic strength (the maximum conductance). Default is 1. + alpha: float + Binding constant. Default 0.062 + beta: float + Unbinding constant. Default 3.57 + cc_Mg: float + Concentration of Magnesium ion. Default 1.2 [mM]. + tau_decay: float + The time constant of the synaptic decay phase. Default 100 [ms] + tau_rise: float + The time constant of the synaptic rise phase. Default 2 [ms] + a: float + Default 0.5 ms^-1. + name: str + The name of this synaptic projection. + method: str + The numerical integration methods. + + References + ---------- .. [1] Brunel N, Wang X J. Effects of neuromodulation in a cortical network model of object working memory dominated @@ -1075,14 +1140,15 @@ def __init__( pre: NeuGroup, post: NeuGroup, conn: Union[TwoEndConnector, Tensor, Dict[str, Tensor]], - g_max: Parameter = 0.15, - E: Parameter = 0., - cc_Mg: Parameter = 1.2, - alpha: Parameter = 0.062, - beta: Parameter = 3.57, - tau_decay: Parameter = 100., - a: Parameter = 0.5, - tau_rise: Parameter = 2., + conn_type: str = 'dense', + g_max: Union[float, Tensor, Initializer, Callable] = 0.15, + E: float = 0., + cc_Mg: float = 1.2, + alpha: float = 0.062, + beta: float = 3.57, + tau_decay: float = 100., + a: float = 0.5, + tau_rise: float = 2., delay_step: Union[int, Tensor, Initializer, Callable] = None, method: str = 'exp_auto', name: str = None, @@ -1092,7 +1158,6 @@ def __init__( self.check_post_attrs('input', 'V') # parameters - self.g_max = g_max self.E = E self.alpha = alpha self.beta = beta @@ -1102,8 +1167,33 @@ def __init__( self.a = a # connections - if not isinstance(self.conn, (All2All, One2One)): - self.conn_mat = self.conn.require('conn_mat') + self.conn_type = conn_type + if conn_type not in ['sparse', 'dense']: + raise ValueError(f'"conn_type" must be in "sparse" and "dense", but we got {conn_type}') + if self.conn is None: + raise ValueError(f'Must provide "conn" when initialize the model {self.name}') + if isinstance(self.conn, One2One): + self.g_max = init_param(g_max, (self.pre.num,), allow_none=False) + self.weight_type = 'heter' if bm.size(self.g_max) != 1 else 'homo' + elif isinstance(self.conn, All2All): + self.g_max = init_param(g_max, (self.pre.num, self.post.num), allow_none=False) + if bm.size(self.g_max) != 1: + self.weight_type = 'heter' + bm.fill_diagonal(self.g_max, 0.) + else: + self.weight_type = 'homo' + else: + if conn_type == 'sparse': + self.pre_ids, self.post_ids = self.conn.require('pre_ids', 'post_ids') + self.g_max = init_param(g_max, self.post_ids.shape, allow_none=False) + self.weight_type = 'heter' if bm.size(self.g_max) != 1 else 'homo' + elif conn_type == 'dense': + self.g_max = init_param(g_max, (self.pre.num, self.post.num), allow_none=False) + self.weight_type = 'heter' if bm.size(self.g_max) != 1 else 'homo' + if self.weight_type == 'homo': + self.conn_mat = self.conn.require('conn_mat') + else: + raise ValueError(f'Unknown connection type: {conn_type}') # variables self.g = bm.Variable(bm.zeros(self.pre.num, dtype=bm.float_)) @@ -1111,7 +1201,7 @@ def __init__( self.delay_step = self.register_delay(self.pre.name + '.spike', delay_step, self.pre.spike) # integral - self.integral = odeint(method=method, f=self.derivative) + self.integral = odeint(method=method, f=JointEq([self.dg, self.dx])) def dg(self, g, t, x): return -g / self.tau_decay + self.a * x * (1 - g) @@ -1119,10 +1209,6 @@ def dg(self, g, t, x): def dx(self, x, t): return -x / self.tau_rise - @property - def derivative(self): - return JointEq([self.dg, self.dx]) - def update(self, _t, _dt): # delays if self.delay_step is None: @@ -1131,18 +1217,32 @@ def update(self, _t, _dt): delayed_pre_spike = self.get_delay(self.pre.name + '.spike', self.delay_step) self.update_delay(self.pre.name + '.spike', self.pre.spike) - # post-synaptic value + # update synapse variables self.g.value, self.x.value = self.integral(self.g, self.x, _t, dt=_dt) self.x += delayed_pre_spike + + # post-synaptic value + assert self.weight_type in ['homo', 'heter'] + assert self.conn_type in ['sparse', 'dense'] if isinstance(self.conn, All2All): - post_g = self.g.sum() - if not self.conn.include_self: - post_g = post_g - self.g + if self.weight_type == 'homo': + post_g = bm.sum(self.g) + if not self.conn.include_self: + post_g = post_g - self.g + else: + post_g = bm.expand_dims(self.g, 1) * self.g_max + post_g = post_g.sum(axis=0) elif isinstance(self.conn, One2One): - post_g = self.g + post_g = self.g_max * self.g else: - post_g = self.g @ self.conn_mat + if self.conn_type == 'sparse': + post_g = bm.pre2post_sum(self.g, self.post.num, self.post_ids, self.pre_ids) + else: + if self.weight_type == 'homo': + post_g = (self.g_max * self.g) @ self.conn_mat + else: + post_g = self.g @ self.g_max # output g_inf = 1 + self.cc_Mg / self.beta * bm.exp(-self.alpha * self.post.V) - self.post.input -= self.g_max * post_g * (self.post.V - self.E) / g_inf + self.post.input -= post_g * (self.post.V - self.E) / g_inf diff --git a/brainpy/dyn/synapses/biological_models.py b/brainpy/dyn/synapses/biological_models.py index ce5c3e50f..5726e7f14 100644 --- a/brainpy/dyn/synapses/biological_models.py +++ b/brainpy/dyn/synapses/biological_models.py @@ -4,11 +4,10 @@ import brainpy.math as bm from brainpy.connect import TwoEndConnector, All2All, One2One -from brainpy.initialize import Initializer from brainpy.dyn.base import NeuGroup, TwoEndConn -from brainpy.dyn.utils import init_delay +from brainpy.initialize import Initializer, init_param from brainpy.integrators import odeint -from brainpy.types import Tensor, Parameter +from brainpy.types import Tensor __all__ = [ 'AMPA', @@ -121,7 +120,8 @@ def __init__( pre: NeuGroup, post: NeuGroup, conn: Union[TwoEndConnector, Tensor, Dict[str, Tensor]], - g_max: Union[float, Tensor, Initializer] = 0.42, + conn_type: str = 'dense', + g_max: Union[float, Tensor, Initializer, Callable] = 0.42, E: float = 0., alpha: float = 0.98, beta: float = 0.18, @@ -144,24 +144,43 @@ def __init__( self.T_duration = T_duration # connection - assert self.conn is not None - if not isinstance(self.conn, (All2All, One2One)): - self.conn_mat = self.conn.require('conn_mat') + self.conn_type = conn_type + if conn_type not in ['sparse', 'dense']: + raise ValueError(f'"conn_type" must be in "sparse" and "dense", but we got {conn_type}') + if self.conn is None: + raise ValueError(f'Must provide "conn" when initialize the model {self.name}') + if isinstance(self.conn, One2One): + self.g_max = init_param(g_max, (self.pre.num,), allow_none=False) + self.weight_type = 'heter' if bm.size(self.g_max) != 1 else 'homo' + elif isinstance(self.conn, All2All): + self.g_max = init_param(g_max, (self.pre.num, self.post.num), allow_none=False) + if bm.size(self.g_max) != 1: + self.weight_type = 'heter' + bm.fill_diagonal(self.g_max, 0.) + else: + self.weight_type = 'homo' + else: + if conn_type == 'sparse': + self.pre_ids, self.post_ids = self.conn.require('pre_ids', 'post_ids') + self.g_max = init_param(g_max, self.post_ids.shape, allow_none=False) + self.weight_type = 'heter' if bm.size(self.g_max) != 1 else 'homo' + elif conn_type == 'dense': + self.g_max = init_param(g_max, (self.pre.num, self.post.num), allow_none=False) + self.weight_type = 'heter' if bm.size(self.g_max) != 1 else 'homo' + if self.weight_type == 'homo': + self.conn_mat = self.conn.require('conn_mat') + else: + raise ValueError(f'Unknown connection type: {conn_type}') # variables - if isinstance(self.conn, All2All): - self.g = bm.Variable(bm.zeros(self.pre.num)) - elif isinstance(self.conn, One2One): - self.g = bm.Variable(bm.zeros(self.post.num)) - else: - self.g = bm.Variable(bm.zeros(self.pre.num)) + self.g = bm.Variable(bm.zeros(self.pre.num)) self.spike_arrival_time = bm.Variable(bm.ones(self.pre.num) * -1e7) self.delay_step = self.register_delay(self.pre.name + '.spike', delay_step, self.pre.spike) # functions - self.integral = odeint(method=method, f=self.derivative) + self.integral = odeint(method=method, f=self.dg) - def derivative(self, g, t, TT): + def dg(self, g, t, TT): dg = self.alpha * TT * (1 - g) - self.beta * g return dg @@ -177,23 +196,29 @@ def update(self, _t, _dt): self.spike_arrival_time.value = bm.where(pre_spike, _t, self.spike_arrival_time) # post-synaptic values + TT = ((_t - self.spike_arrival_time) < self.T_duration) * self.T + self.g.value = self.integral(self.g, _t, TT, dt=_dt) if isinstance(self.conn, One2One): - TT = ((_t - self.spike_arrival_time) < self.T_duration) * self.T - self.g.value = self.integral(self.g, _t, TT, dt=_dt) - g_post = self.g + post_g = self.g_max * self.g elif isinstance(self.conn, All2All): - TT = ((_t - self.spike_arrival_time) < self.T_duration) * self.T - self.g.value = self.integral(self.g, _t, TT, dt=_dt) - g_post = self.g.sum() - if not self.conn.include_self: - g_post = g_post - self.g + if self.weight_type == 'homo': + post_g = self.g.sum() * self.g_max + if not self.conn.include_self: + post_g = post_g - self.g + else: + post_g = bm.expand_dims(self.g, 1) * self.g_max + post_g = post_g.sum(axis=0) else: - TT = ((_t - self.spike_arrival_time) < self.T_duration) * self.T - self.g.value = self.integral(self.g, _t, TT, dt=_dt) - g_post = self.g @ self.conn_mat + if self.conn_type == 'sparse': + post_g = bm.pre2post_sum(self.g, self.post.num, self.post_ids, self.pre_ids) + else: + if self.weight_type == 'homo': + post_g = (self.g_max * self.g) @ self.conn_mat + else: + post_g = self.g @ self.g_max # output - self.post.input -= self.g_max * g_post * (self.post.V - self.E) + self.post.input -= post_g * (self.post.V - self.E) class GABAa(AMPA): @@ -220,31 +245,40 @@ class GABAa(AMPA): - `Gamma oscillation network model `_ - **Model Parameters** - - ============= ============== ======== ======================================= - **Parameter** **Init Value** **Unit** **Explanation** - ------------- -------------- -------- --------------------------------------- - delay 0 ms The decay length of the pre-synaptic spikes. - g_max 0.04 µmho(µS) Maximum synapse conductance. - E -80 mV Reversal potential of synapse. - alpha 0.53 \ Activating rate constant of G protein catalyzed by activated GABAb receptor. - beta 0.18 \ De-activating rate constant of G protein. - T 1 mM Transmitter concentration when synapse is triggered by a pre-synaptic spike. - T_duration 1 ms Transmitter concentration duration time after being triggered. - ============= ============== ======== ======================================= - - **Model Variables** - - ================== ================== ================================================== - **Member name** **Initial values** **Explanation** - ------------------ ------------------ -------------------------------------------------- - g 0 Synapse gating variable. - pre_spike False The history of pre-synaptic neuron spikes. - spike_arrival_time -1e7 The arrival time of the pre-synaptic neuron spike. - ================== ================== ================================================== - **References** + Parameters + ---------- + pre: NeuGroup + The pre-synaptic neuron group. + post: NeuGroup + The post-synaptic neuron group. + conn: optional, ndarray, JaxArray, dict of (str, ndarray), TwoEndConnector + The synaptic connections. + conn_type: str + The connection type used for model speed optimization. It can be + `sparse` and `dense`. The default is `dense`. + delay_step: int, ndarray, JaxArray, Initializer, Callable + The delay length. It should be the value of :math:`\mathrm{delay\_time / dt}`. + E: float + The reversal potential for the synaptic current. [mV] + g_max: float, ndarray, JaxArray, Initializer, Callable + The synaptic strength (the maximum conductance). Default is 1. + alpha: float + Binding constant. Default 0.062 + beta: float + Unbinding constant. Default 3.57 + T: float + Transmitter concentration when synapse is triggered by + a pre-synaptic spike.. Default 1 [mM]. + T_duration: float + Transmitter concentration duration time after being triggered. Default 1 [ms] + name: str + The name of this synaptic projection. + method: str + The numerical integration methods. + + References + ---------- .. [1] Destexhe, Alain, and Denis Paré. "Impact of network activity on the integrative properties of neocortical pyramidal neurons @@ -256,19 +290,25 @@ def __init__( pre: NeuGroup, post: NeuGroup, conn: Union[TwoEndConnector, Tensor, Dict[str, Tensor]], - g_max: Parameter = 0.04, - E: Parameter = -80., - alpha: Parameter = 0.53, - beta: Parameter = 0.18, - T: Parameter = 1., - T_duration: Parameter = 1., + conn_type: str = 'dense', + g_max: Union[float, Tensor, Initializer, Callable] = 0.04, + E: float = -80., + alpha: float = 0.53, + beta: float = 0.18, + T: float = 1., + T_duration: float = 1., delay_step: Union[int, Tensor, Initializer, Callable] = None, method: str = 'exp_auto', name: str = None ): super(GABAa, self).__init__(pre, post, conn, - delay_step=delay_step, g_max=g_max, E=E, - alpha=alpha, beta=beta, T=T, + conn_type=conn_type, + delay_step=delay_step, + g_max=g_max, + E=E, + alpha=alpha, + beta=beta, + T=T, T_duration=T_duration, method=method, name=name) diff --git a/brainpy/dyn/synapses/learning_rules.py b/brainpy/dyn/synapses/learning_rules.py index 93130a1e6..9699b89fa 100644 --- a/brainpy/dyn/synapses/learning_rules.py +++ b/brainpy/dyn/synapses/learning_rules.py @@ -5,7 +5,7 @@ import brainpy.math as bm from brainpy.connect import TwoEndConnector from brainpy.dyn.base import NeuGroup, TwoEndConn -from brainpy.initialize import init_param, Initializer +from brainpy.initialize import Initializer from brainpy.dyn.utils import init_delay from brainpy.integrators import odeint, JointEq from brainpy.types import Tensor, Parameter @@ -178,11 +178,11 @@ def __init__( pre: NeuGroup, post: NeuGroup, conn: Union[TwoEndConnector, Tensor, Dict[str, Tensor]], - U: Parameter = 0.15, - tau_f: Parameter = 1500., - tau_d: Parameter = 200., - tau: Parameter = 8., - A: Parameter = 1., + U: float = 0.15, + tau_f: float = 1500., + tau_d: float = 200., + tau: float = 8., + A: float = 1., delay_step: Union[int, Tensor, Initializer, Callable] = None, method: str = 'exp_auto', name: str = None diff --git a/brainpy/math/jaxarray.py b/brainpy/math/jaxarray.py index d8494d4ac..6696adc5b 100644 --- a/brainpy/math/jaxarray.py +++ b/brainpy/math/jaxarray.py @@ -467,36 +467,15 @@ def argsort(self, axis=-1, kind=None, order=None): """Returns the indices that would sort this array.""" return JaxArray(self.value.argsort(axis=axis, kind=kind, order=order)) - def astype(self, dtype, order='K', casting='unsafe', subok=True, copy=True): + def astype(self, dtype): """Copy of the array, cast to a specified type. Parameters ---------- dtype: str, dtype Typecode or data-type to which the array is cast. - order : {‘C’, ‘F’, ‘A’, ‘K’}, optional - Controls the memory layout order of the result. - ‘C’ means C order, ‘F’ means Fortran order, ‘A’ means - ‘F’ order if all the arrays are Fortran contiguous, - ‘C’ order otherwise, and ‘K’ means as close to the order - the array elements appear in memory as possible. Default is ‘K’. - casting: {‘no’, ‘equiv’, ‘safe’, ‘same_kind’, ‘unsafe’}, optional - Controls what kind of data casting may occur. - Defaults to ‘unsafe’ for backwards compatibility. - - ‘no’ means the data types should not be cast at all. - - ‘equiv’ means only byte-order changes are allowed. - - ‘safe’ means only casts which can preserve values are allowed. - - ‘same_kind’ means only safe casts or casts within a kind, like float64 to float32, are allowed. - - ‘unsafe’ means any data conversions may be done. - subok: bool, optional - If True, then sub-classes will be passed-through (default), otherwise - the returned array will be forced to be a base-class array. - copy: bool, optional - By default, astype always returns a newly allocated array. - If this is set to false, and the dtype, order, and subok - requirements are satisfied, the input array is returned instead of a copy. """ - return JaxArray(self.value.astype(dtype=dtype, order=order, casting=casting, subok=subok, copy=copy)) + return JaxArray(self.value.astype(dtype=dtype)) def byteswap(self, inplace=False): """Swap the bytes of the array elements @@ -1143,9 +1122,15 @@ def argsort(self, axis=-1, kind=None, order=None): """Returns the indices that would sort this array.""" return self.value.argsort(axis=axis, kind=kind, order=order) - def astype(self, dtype, order='K', casting='unsafe', subok=True, copy=True): - """Copy of the array, cast to a specified type.""" - return self.value.astype(dtype=dtype, order=order, casting=casting, subok=subok, copy=copy) + def astype(self, dtype): + """Copy of the array, cast to a specified type. + + Parameters + ---------- + dtype: str, dtype + Typecode or data-type to which the array is cast. + """ + return JaxArray(self.value.astype(dtype=dtype)) def byteswap(self, inplace=False): """Swap the bytes of the array elements