From 7718a542decb3335ccba16b428a18a576f3adf29 Mon Sep 17 00:00:00 2001 From: chaoming Date: Thu, 1 Jun 2023 10:28:03 +0800 Subject: [PATCH 01/14] update package structure --- brainpy/__init__.py | 4 +- brainpy/_src/analysis/highdim/slow_points.py | 2 +- brainpy/_src/analysis/utils/model.py | 2 +- brainpy/_src/{dyn => }/channels/Ca.py | 0 brainpy/_src/{dyn => }/channels/IH.py | 0 brainpy/_src/{dyn => }/channels/K.py | 0 brainpy/_src/{dyn => }/channels/KCa.py | 0 brainpy/_src/{dyn => }/channels/Na.py | 0 brainpy/_src/{dyn => }/channels/__init__.py | 0 brainpy/_src/{dyn => }/channels/base.py | 0 brainpy/_src/{dyn => }/channels/leaky.py | 0 brainpy/_src/dyn/__init__.py | 9 -- brainpy/_src/dyn/_utils.py | 2 +- brainpy/_src/dyn/networks/__init__.py | 1 - brainpy/_src/dyn/networks/cann.py | 25 ------ brainpy/_src/dyn/neurons_v2/lif.py | 85 ------------------- brainpy/_src/dyn/synapses_v2/__init__.py | 0 brainpy/_src/{dyn => }/neurons/__init__.py | 0 .../{dyn => }/neurons/biological_models.py | 0 brainpy/_src/{dyn => }/neurons/compat.py | 0 .../{dyn => }/neurons/fractional_models.py | 0 .../_src/{dyn => }/neurons/input_groups.py | 0 .../_src/{dyn => }/neurons/noise_groups.py | 0 .../_src/{dyn => }/neurons/reduced_models.py | 0 .../neurons/tests/test_reduced_models.py | 0 brainpy/_src/{dyn => }/rates/__init__.py | 0 brainpy/_src/{dyn => }/rates/populations.py | 0 brainpy/_src/{dyn => }/runners.py | 0 .../neurons_v2 => synapses_v2}/__init__.py | 0 .../synapses_v2/abstract_synapses.py | 0 brainpy/_src/{dyn => }/synapses_v2/base.py | 0 brainpy/_src/{dyn => }/synapses_v2/others.py | 0 .../_src/{dyn => }/synapses_v2/syn_outs.py | 0 .../{dyn => }/synapses_v2/syn_plasticity.py | 0 brainpy/_src/train/base.py | 2 +- brainpy/_src/{dyn => }/transform.py | 0 brainpy/channels.py | 14 +-- brainpy/experimental.py | 8 +- brainpy/neurons.py | 10 +-- brainpy/rates.py | 2 +- 40 files changed, 23 insertions(+), 143 deletions(-) rename brainpy/_src/{dyn => }/channels/Ca.py (100%) rename brainpy/_src/{dyn => }/channels/IH.py (100%) rename brainpy/_src/{dyn => }/channels/K.py (100%) rename brainpy/_src/{dyn => }/channels/KCa.py (100%) rename brainpy/_src/{dyn => }/channels/Na.py (100%) rename brainpy/_src/{dyn => }/channels/__init__.py (100%) rename brainpy/_src/{dyn => }/channels/base.py (100%) rename brainpy/_src/{dyn => }/channels/leaky.py (100%) delete mode 100644 brainpy/_src/dyn/networks/__init__.py delete mode 100644 brainpy/_src/dyn/networks/cann.py delete mode 100644 brainpy/_src/dyn/neurons_v2/lif.py delete mode 100644 brainpy/_src/dyn/synapses_v2/__init__.py rename brainpy/_src/{dyn => }/neurons/__init__.py (100%) rename brainpy/_src/{dyn => }/neurons/biological_models.py (100%) rename brainpy/_src/{dyn => }/neurons/compat.py (100%) rename brainpy/_src/{dyn => }/neurons/fractional_models.py (100%) rename brainpy/_src/{dyn => }/neurons/input_groups.py (100%) rename brainpy/_src/{dyn => }/neurons/noise_groups.py (100%) rename brainpy/_src/{dyn => }/neurons/reduced_models.py (100%) rename brainpy/_src/{dyn => }/neurons/tests/test_reduced_models.py (100%) rename brainpy/_src/{dyn => }/rates/__init__.py (100%) rename brainpy/_src/{dyn => }/rates/populations.py (100%) rename brainpy/_src/{dyn => }/runners.py (100%) rename brainpy/_src/{dyn/neurons_v2 => synapses_v2}/__init__.py (100%) rename brainpy/_src/{dyn => }/synapses_v2/abstract_synapses.py (100%) rename brainpy/_src/{dyn => }/synapses_v2/base.py (100%) rename brainpy/_src/{dyn => }/synapses_v2/others.py (100%) rename brainpy/_src/{dyn => 
}/synapses_v2/syn_outs.py (100%) rename brainpy/_src/{dyn => }/synapses_v2/syn_plasticity.py (100%) rename brainpy/_src/{dyn => }/transform.py (100%) diff --git a/brainpy/__init__.py b/brainpy/__init__.py index 1d5055480..72595c984 100644 --- a/brainpy/__init__.py +++ b/brainpy/__init__.py @@ -78,8 +78,8 @@ from brainpy._src.context import share from brainpy._src.dynsys import not_pass_shared # running -from brainpy._src.dyn.runners import (DSRunner as DSRunner) -from brainpy._src.dyn.transform import (LoopOverTime as LoopOverTime,) +from brainpy._src.runners import (DSRunner as DSRunner) +from brainpy._src.transform import (LoopOverTime as LoopOverTime, ) # DynamicalSystem base classes from brainpy._src.dynsys import (DynamicalSystemNS as DynamicalSystemNS, NeuGroupNS as NeuGroupNS, diff --git a/brainpy/_src/analysis/highdim/slow_points.py b/brainpy/_src/analysis/highdim/slow_points.py index bfb703720..4c0b82a87 100644 --- a/brainpy/_src/analysis/highdim/slow_points.py +++ b/brainpy/_src/analysis/highdim/slow_points.py @@ -14,7 +14,7 @@ from brainpy import optim, losses from brainpy._src.analysis import utils, base, constants from brainpy._src.dynsys import DynamicalSystem -from brainpy._src.dyn.runners import check_and_format_inputs, _f_ops +from brainpy._src.runners import check_and_format_inputs, _f_ops from brainpy._src.tools.dicts import DotDict from brainpy.errors import AnalyzerError, UnsupportedError from brainpy.types import ArrayType diff --git a/brainpy/_src/analysis/utils/model.py b/brainpy/_src/analysis/utils/model.py index 8295de1e9..a2c92fc97 100644 --- a/brainpy/_src/analysis/utils/model.py +++ b/brainpy/_src/analysis/utils/model.py @@ -5,7 +5,7 @@ from brainpy._src.math.environment import get_float from brainpy._src.math.interoperability import as_jax from brainpy._src.dynsys import DynamicalSystem -from brainpy._src.dyn.runners import DSRunner +from brainpy._src.runners import DSRunner from brainpy._src.integrators.base import Integrator from brainpy._src.integrators.joint_eq import JointEq from brainpy._src.integrators.ode.base import ODEIntegrator diff --git a/brainpy/_src/dyn/channels/Ca.py b/brainpy/_src/channels/Ca.py similarity index 100% rename from brainpy/_src/dyn/channels/Ca.py rename to brainpy/_src/channels/Ca.py diff --git a/brainpy/_src/dyn/channels/IH.py b/brainpy/_src/channels/IH.py similarity index 100% rename from brainpy/_src/dyn/channels/IH.py rename to brainpy/_src/channels/IH.py diff --git a/brainpy/_src/dyn/channels/K.py b/brainpy/_src/channels/K.py similarity index 100% rename from brainpy/_src/dyn/channels/K.py rename to brainpy/_src/channels/K.py diff --git a/brainpy/_src/dyn/channels/KCa.py b/brainpy/_src/channels/KCa.py similarity index 100% rename from brainpy/_src/dyn/channels/KCa.py rename to brainpy/_src/channels/KCa.py diff --git a/brainpy/_src/dyn/channels/Na.py b/brainpy/_src/channels/Na.py similarity index 100% rename from brainpy/_src/dyn/channels/Na.py rename to brainpy/_src/channels/Na.py diff --git a/brainpy/_src/dyn/channels/__init__.py b/brainpy/_src/channels/__init__.py similarity index 100% rename from brainpy/_src/dyn/channels/__init__.py rename to brainpy/_src/channels/__init__.py diff --git a/brainpy/_src/dyn/channels/base.py b/brainpy/_src/channels/base.py similarity index 100% rename from brainpy/_src/dyn/channels/base.py rename to brainpy/_src/channels/base.py diff --git a/brainpy/_src/dyn/channels/leaky.py b/brainpy/_src/channels/leaky.py similarity index 100% rename from brainpy/_src/dyn/channels/leaky.py rename to 
brainpy/_src/channels/leaky.py diff --git a/brainpy/_src/dyn/__init__.py b/brainpy/_src/dyn/__init__.py index a134e98fc..3ba7d2774 100644 --- a/brainpy/_src/dyn/__init__.py +++ b/brainpy/_src/dyn/__init__.py @@ -4,17 +4,8 @@ Module for brain dynamics model building. """ -from . import ( - channels, neurons, rates, # neuron related - synapses, synouts, synplast, # synapse related - networks, - runners, - transform, -) from .neurons.compat import * -from .runners import * from .synapses.compat import * -from .transform import * diff --git a/brainpy/_src/dyn/_utils.py b/brainpy/_src/dyn/_utils.py index 5cbbdc748..62cab8b79 100644 --- a/brainpy/_src/dyn/_utils.py +++ b/brainpy/_src/dyn/_utils.py @@ -2,8 +2,8 @@ from typing import Optional -from brainpy._src.math.ndarray import Variable from brainpy._src.math.object_transform.base import BrainPyObject +from brainpy._src.math.object_transform.variables import Variable __all__ = [ 'get_output_var', diff --git a/brainpy/_src/dyn/networks/__init__.py b/brainpy/_src/dyn/networks/__init__.py deleted file mode 100644 index 40a96afc6..000000000 --- a/brainpy/_src/dyn/networks/__init__.py +++ /dev/null @@ -1 +0,0 @@ -# -*- coding: utf-8 -*- diff --git a/brainpy/_src/dyn/networks/cann.py b/brainpy/_src/dyn/networks/cann.py deleted file mode 100644 index a5cc812a5..000000000 --- a/brainpy/_src/dyn/networks/cann.py +++ /dev/null @@ -1,25 +0,0 @@ -# -*- coding: utf-8 -*- - - -from brainpy._src.dynsys import NeuGroup - -__all__ = [ - 'WuCANN1D', - 'WuCANN2D', -] - - -class WuCANN1D(NeuGroup): - pass - - -class WuCANN2D(NeuGroup): - pass - - -class ACANN_1D(NeuGroup): - pass - - -class ACANN_2D(NeuGroup): - pass diff --git a/brainpy/_src/dyn/neurons_v2/lif.py b/brainpy/_src/dyn/neurons_v2/lif.py deleted file mode 100644 index 7249f7620..000000000 --- a/brainpy/_src/dyn/neurons_v2/lif.py +++ /dev/null @@ -1,85 +0,0 @@ - -from functools import partial -from typing import Union, Callable, Optional -from jax.sharding import Sharding - - -import brainpy.math as bm -from brainpy._src.dynsys import NeuGroupNS -from brainpy._src.context import share -from brainpy._src.initialize import (ZeroInit, - OneInit, - Initializer, - parameter, - variable_, - noise as init_noise) -from brainpy._src.integrators import sdeint, odeint, JointEq -from brainpy.check import is_initializer, is_callable, is_subclass -from brainpy.types import Shape, ArrayType - -__all__ = [ - 'Leaky', -] - -class Leaky(NeuGroupNS): - r"""Leaky Integrator Model. - - **Model Descriptions** - - This class implements a leaky model, in which its dynamics is - given by: - - .. math:: - - x(t + \Delta t) = \exp{-1/\tau \Delta t} x(t) + I - - Parameters - ---------- - size: sequence of int, int - The size of the neuron group. - tau: float, ArrayType, Initializer, callable - Membrane time constant. - method: str - The numerical integration method. - name: str - The group name. 
- """ - - def __init__( - self, - size: Shape, - keep_size: bool = False, - tau: Union[float, ArrayType, Initializer, Callable] = 10., - name: str = None, - mode: bm.Mode = None, - method: str = 'exp_auto', - ): - super().__init__(size=size, - mode=mode, - keep_size=keep_size, - name=name) - assert self.mode.is_parent_of(bm.TrainingMode, bm.NonBatchingMode) - - # parameters - self.tau = parameter(tau, self.varshape, allow_none=False) - - # integral - self.integral = odeint(method=method, f=self.derivative) - - # variables - self.reset_state(self.mode) - - def derivative(self, x, t): - return -x / self.tau - - def reset_state(self, batch_size=None): - self.x = variable_(bm.zeros, self.varshape, batch_size) - - def update(self, x=None): - t = share.load('t') - dt = share.load('dt') - r = self.integral(self.x.value, t, dt) - if x is not None: - r += x - self.x.value = r - return r diff --git a/brainpy/_src/dyn/synapses_v2/__init__.py b/brainpy/_src/dyn/synapses_v2/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/brainpy/_src/dyn/neurons/__init__.py b/brainpy/_src/neurons/__init__.py similarity index 100% rename from brainpy/_src/dyn/neurons/__init__.py rename to brainpy/_src/neurons/__init__.py diff --git a/brainpy/_src/dyn/neurons/biological_models.py b/brainpy/_src/neurons/biological_models.py similarity index 100% rename from brainpy/_src/dyn/neurons/biological_models.py rename to brainpy/_src/neurons/biological_models.py diff --git a/brainpy/_src/dyn/neurons/compat.py b/brainpy/_src/neurons/compat.py similarity index 100% rename from brainpy/_src/dyn/neurons/compat.py rename to brainpy/_src/neurons/compat.py diff --git a/brainpy/_src/dyn/neurons/fractional_models.py b/brainpy/_src/neurons/fractional_models.py similarity index 100% rename from brainpy/_src/dyn/neurons/fractional_models.py rename to brainpy/_src/neurons/fractional_models.py diff --git a/brainpy/_src/dyn/neurons/input_groups.py b/brainpy/_src/neurons/input_groups.py similarity index 100% rename from brainpy/_src/dyn/neurons/input_groups.py rename to brainpy/_src/neurons/input_groups.py diff --git a/brainpy/_src/dyn/neurons/noise_groups.py b/brainpy/_src/neurons/noise_groups.py similarity index 100% rename from brainpy/_src/dyn/neurons/noise_groups.py rename to brainpy/_src/neurons/noise_groups.py diff --git a/brainpy/_src/dyn/neurons/reduced_models.py b/brainpy/_src/neurons/reduced_models.py similarity index 100% rename from brainpy/_src/dyn/neurons/reduced_models.py rename to brainpy/_src/neurons/reduced_models.py diff --git a/brainpy/_src/dyn/neurons/tests/test_reduced_models.py b/brainpy/_src/neurons/tests/test_reduced_models.py similarity index 100% rename from brainpy/_src/dyn/neurons/tests/test_reduced_models.py rename to brainpy/_src/neurons/tests/test_reduced_models.py diff --git a/brainpy/_src/dyn/rates/__init__.py b/brainpy/_src/rates/__init__.py similarity index 100% rename from brainpy/_src/dyn/rates/__init__.py rename to brainpy/_src/rates/__init__.py diff --git a/brainpy/_src/dyn/rates/populations.py b/brainpy/_src/rates/populations.py similarity index 100% rename from brainpy/_src/dyn/rates/populations.py rename to brainpy/_src/rates/populations.py diff --git a/brainpy/_src/dyn/runners.py b/brainpy/_src/runners.py similarity index 100% rename from brainpy/_src/dyn/runners.py rename to brainpy/_src/runners.py diff --git a/brainpy/_src/dyn/neurons_v2/__init__.py b/brainpy/_src/synapses_v2/__init__.py similarity index 100% rename from brainpy/_src/dyn/neurons_v2/__init__.py rename to 
brainpy/_src/synapses_v2/__init__.py diff --git a/brainpy/_src/dyn/synapses_v2/abstract_synapses.py b/brainpy/_src/synapses_v2/abstract_synapses.py similarity index 100% rename from brainpy/_src/dyn/synapses_v2/abstract_synapses.py rename to brainpy/_src/synapses_v2/abstract_synapses.py diff --git a/brainpy/_src/dyn/synapses_v2/base.py b/brainpy/_src/synapses_v2/base.py similarity index 100% rename from brainpy/_src/dyn/synapses_v2/base.py rename to brainpy/_src/synapses_v2/base.py diff --git a/brainpy/_src/dyn/synapses_v2/others.py b/brainpy/_src/synapses_v2/others.py similarity index 100% rename from brainpy/_src/dyn/synapses_v2/others.py rename to brainpy/_src/synapses_v2/others.py diff --git a/brainpy/_src/dyn/synapses_v2/syn_outs.py b/brainpy/_src/synapses_v2/syn_outs.py similarity index 100% rename from brainpy/_src/dyn/synapses_v2/syn_outs.py rename to brainpy/_src/synapses_v2/syn_outs.py diff --git a/brainpy/_src/dyn/synapses_v2/syn_plasticity.py b/brainpy/_src/synapses_v2/syn_plasticity.py similarity index 100% rename from brainpy/_src/dyn/synapses_v2/syn_plasticity.py rename to brainpy/_src/synapses_v2/syn_plasticity.py diff --git a/brainpy/_src/train/base.py b/brainpy/_src/train/base.py index 443f61823..eb19d24d1 100644 --- a/brainpy/_src/train/base.py +++ b/brainpy/_src/train/base.py @@ -4,7 +4,7 @@ import brainpy.math as bm from brainpy._src.dynsys import DynamicalSystem -from brainpy._src.dyn.runners import DSRunner +from brainpy._src.runners import DSRunner from brainpy._src.running import constants as c from brainpy.errors import NoLongerSupportError from brainpy.types import ArrayType, Output diff --git a/brainpy/_src/dyn/transform.py b/brainpy/_src/transform.py similarity index 100% rename from brainpy/_src/dyn/transform.py rename to brainpy/_src/transform.py diff --git a/brainpy/channels.py b/brainpy/channels.py index 16769e2f1..6a19f7f55 100644 --- a/brainpy/channels.py +++ b/brainpy/channels.py @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- -from brainpy._src.dyn.channels.base import ( +from brainpy._src.channels.base import ( Ion as Ion, IonChannel as IonChannel, Calcium as Calcium, @@ -11,7 +11,7 @@ LeakyChannel as LeakyChannel, ) -from brainpy._src.dyn.channels.Ca import ( +from brainpy._src.channels.Ca import ( CalciumFixed as CalciumFixed, CalciumDyna as CalciumDyna, CalciumDetailed as CalciumDetailed, @@ -23,12 +23,12 @@ ICaL_IS2008 as ICaL_IS2008, ) -from brainpy._src.dyn.channels.IH import ( +from brainpy._src.channels.IH import ( Ih_HM1992 as Ih_HM1992, Ih_De1996 as Ih_De1996, ) -from brainpy._src.dyn.channels.K import ( +from brainpy._src.channels.K import ( IKDR_Ba2002 as IKDR_Ba2002, IK_TM1991 as IK_TM1991, IK_HH1952 as IK_HH1952, @@ -39,16 +39,16 @@ IKNI_Ya1989 as IKNI_Ya1989, ) -from brainpy._src.dyn.channels.KCa import ( +from brainpy._src.channels.KCa import ( IAHP_De1994 as IAHP_De1994, ) -from brainpy._src.dyn.channels.leaky import ( +from brainpy._src.channels.leaky import ( IL as IL, IKL as IKL, ) -from brainpy._src.dyn.channels.Na import ( +from brainpy._src.channels.Na import ( INa_Ba2002 as INa_Ba2002, INa_TM1991 as INa_TM1991, INa_HH1952 as INa_HH1952, diff --git a/brainpy/experimental.py b/brainpy/experimental.py index 7d182a4a2..68d8ff5bd 100644 --- a/brainpy/experimental.py +++ b/brainpy/experimental.py @@ -1,18 +1,18 @@ -from brainpy._src.dyn.synapses_v2.syn_plasticity import ( +from brainpy._src.synapses_v2.syn_plasticity import ( STD as STD, STP as STP, ) -from brainpy._src.dyn.synapses_v2.syn_outs import ( +from brainpy._src.synapses_v2.syn_outs 
import ( CUBA as CUBA, COBA as COBA, ) -from brainpy._src.dyn.synapses_v2.abstract_synapses import ( +from brainpy._src.synapses_v2.abstract_synapses import ( Exponential, DualExponential, Alpha, ) -from brainpy._src.dyn.synapses_v2.others import ( +from brainpy._src.synapses_v2.others import ( PoissonInput, ) diff --git a/brainpy/neurons.py b/brainpy/neurons.py index ddc784bd4..0fa154538 100644 --- a/brainpy/neurons.py +++ b/brainpy/neurons.py @@ -1,30 +1,30 @@ # -*- coding: utf-8 -*- -from brainpy._src.dyn.neurons.biological_models import ( +from brainpy._src.neurons.biological_models import ( HH as HH, MorrisLecar as MorrisLecar, PinskyRinzelModel as PinskyRinzelModel, WangBuzsakiModel as WangBuzsakiModel, ) -from brainpy._src.dyn.neurons.fractional_models import ( +from brainpy._src.neurons.fractional_models import ( FractionalNeuron as FractionalNeuron, FractionalFHR as FractionalFHR, FractionalIzhikevich as FractionalIzhikevich, ) -from brainpy._src.dyn.neurons.input_groups import ( +from brainpy._src.neurons.input_groups import ( InputGroup as InputGroup, OutputGroup as OutputGroup, SpikeTimeGroup as SpikeTimeGroup, PoissonGroup as PoissonGroup, ) -from brainpy._src.dyn.neurons.noise_groups import ( +from brainpy._src.neurons.noise_groups import ( OUProcess as OUProcess, ) -from brainpy._src.dyn.neurons.reduced_models import ( +from brainpy._src.neurons.reduced_models import ( Leaky as Leaky, Integrator as Integrator, LeakyIntegrator as LeakyIntegrator, diff --git a/brainpy/rates.py b/brainpy/rates.py index 92dce36e5..7dedee342 100644 --- a/brainpy/rates.py +++ b/brainpy/rates.py @@ -1,7 +1,7 @@ # -*- coding: utf-8 -*- -from brainpy._src.dyn.rates.populations import ( +from brainpy._src.rates.populations import ( RateModel as RateModel, FHN as FHN, FeedbackFHN as FeedbackFHN, From 8bca090bb0ea55e71402e668fe48e70c05d8024a Mon Sep 17 00:00:00 2001 From: chaoming Date: Thu, 1 Jun 2023 10:45:34 +0800 Subject: [PATCH 02/14] update package structure2 --- brainpy/_src/checkpoints/serialization.py | 6 +- brainpy/_src/context.py | 2 +- brainpy/_src/dyn/__init__.py | 4 +- brainpy/_src/dyn/tests/test_access_methods.py | 123 ---------------- brainpy/_src/dyn/tests/test_base_classes.py | 20 --- brainpy/_src/dyn/tests/test_dyn_runner.py | 133 ------------------ brainpy/_src/dyn/tests/test_network.py | 51 ------- brainpy/_src/dyn/tests/test_pickle.py | 22 --- brainpy/_src/dyn/tests/test_slice_view.py | 51 ------- brainpy/_src/{dyn => }/synapses/__init__.py | 0 .../{dyn => }/synapses/abstract_models.py | 2 +- .../{dyn => }/synapses/biological_models.py | 2 +- brainpy/_src/{dyn => }/synapses/compat.py | 2 +- .../{dyn => }/synapses/delay_couplings.py | 2 +- .../_src/{dyn => }/synapses/gap_junction.py | 0 .../_src/{dyn => }/synapses/learning_rules.py | 0 brainpy/_src/{dyn => }/synouts/__init__.py | 0 .../_src/{dyn => }/synouts/conductances.py | 0 brainpy/_src/{dyn => }/synouts/ions.py | 0 brainpy/_src/{dyn => }/synplast/__init__.py | 0 .../synplast/long_term_plasticity.py | 0 .../synplast/short_term_plasticity.py | 0 brainpy/_src/test_check.py | 51 ------- brainpy/synapses/dynamics.py | 8 +- brainpy/synapses/synouts.py | 4 +- brainpy/synapses/synplast.py | 2 +- 26 files changed, 16 insertions(+), 469 deletions(-) delete mode 100644 brainpy/_src/dyn/tests/test_access_methods.py delete mode 100644 brainpy/_src/dyn/tests/test_base_classes.py delete mode 100644 brainpy/_src/dyn/tests/test_dyn_runner.py delete mode 100644 brainpy/_src/dyn/tests/test_network.py delete mode 100644 
brainpy/_src/dyn/tests/test_pickle.py delete mode 100644 brainpy/_src/dyn/tests/test_slice_view.py rename brainpy/_src/{dyn => }/synapses/__init__.py (100%) rename brainpy/_src/{dyn => }/synapses/abstract_models.py (99%) rename brainpy/_src/{dyn => }/synapses/biological_models.py (99%) rename brainpy/_src/{dyn => }/synapses/compat.py (99%) rename brainpy/_src/{dyn => }/synapses/delay_couplings.py (99%) rename brainpy/_src/{dyn => }/synapses/gap_junction.py (100%) rename brainpy/_src/{dyn => }/synapses/learning_rules.py (100%) rename brainpy/_src/{dyn => }/synouts/__init__.py (100%) rename brainpy/_src/{dyn => }/synouts/conductances.py (100%) rename brainpy/_src/{dyn => }/synouts/ions.py (100%) rename brainpy/_src/{dyn => }/synplast/__init__.py (100%) rename brainpy/_src/{dyn => }/synplast/long_term_plasticity.py (100%) rename brainpy/_src/{dyn => }/synplast/short_term_plasticity.py (100%) delete mode 100644 brainpy/_src/test_check.py diff --git a/brainpy/_src/checkpoints/serialization.py b/brainpy/_src/checkpoints/serialization.py index 935fbe631..d12f5a1c8 100644 --- a/brainpy/_src/checkpoints/serialization.py +++ b/brainpy/_src/checkpoints/serialization.py @@ -32,10 +32,8 @@ try: from jax.experimental.array_serialization import get_tensorstore_spec, GlobalAsyncCheckpointManager # noqa except (ModuleNotFoundError, ImportError): - try: - from jax.experimental.gda_serialization import get_tensorstore_spec, GlobalAsyncCheckpointManager # noqa - except (ModuleNotFoundError, ImportError): - get_tensorstore_spec = None + get_tensorstore_spec = None + GlobalAsyncCheckpointManager = None from brainpy._src.math.ndarray import Array from brainpy.errors import (AlreadyExistsError, diff --git a/brainpy/_src/context.py b/brainpy/_src/context.py index 9c953364a..24ace7f80 100644 --- a/brainpy/_src/context.py +++ b/brainpy/_src/context.py @@ -52,7 +52,7 @@ def load(self, key, value: Any = None): return self._arguments[key] if value is None: raise KeyError(f'Cannot found shared data of {key}. ' - f'Please define it with "brainpy.share.save()". ') + f'Please define it with "brainpy.share.save({key}=)". ') else: return value diff --git a/brainpy/_src/dyn/__init__.py b/brainpy/_src/dyn/__init__.py index 3ba7d2774..ec4f94333 100644 --- a/brainpy/_src/dyn/__init__.py +++ b/brainpy/_src/dyn/__init__.py @@ -4,8 +4,8 @@ Module for brain dynamics model building. 
""" -from .neurons.compat import * -from .synapses.compat import * +from brainpy._src.neurons.compat import * +from brainpy._src.synapses.compat import * diff --git a/brainpy/_src/dyn/tests/test_access_methods.py b/brainpy/_src/dyn/tests/test_access_methods.py deleted file mode 100644 index 1e361ffbd..000000000 --- a/brainpy/_src/dyn/tests/test_access_methods.py +++ /dev/null @@ -1,123 +0,0 @@ -# -*- coding: utf-8 -*- - -import jax.numpy as jnp -import brainpy as bp - -bp.ode.set_default_odeint('rk4') - - -class GABAa(bp.TwoEndConn): - def __init__(self, pre, post, conn, delay=0., g_max=0.1, E=-75., - alpha=12., beta=0.1, T=1.0, T_duration=1.0, **kwargs): - super(GABAa, self).__init__(pre=pre, post=post, conn=conn, **kwargs) - - # parameters - self.g_max = g_max - self.E = E - self.alpha = alpha - self.beta = beta - self.T = T - self.T_duration = T_duration - self.delay = delay - - # connections - self.conn_mat = self.conn.requires('conn_mat') - self.size = jnp.shape(self.conn_mat) - - # variables - self.t_last_pre_spike = jnp.ones(self.size) * -1e7 - self.s = jnp.zeros(self.size) - - self.int_s = bp.odeint(self.dev) - - def dev(self, s, t, TT, alpha, beta): - return alpha * TT * (1 - s) - beta * s - - def update(self, t, dt, **kwargs): - spike = jnp.reshape(self.pre.spikes, (self.pre.num, 1)) * self.conn_mat - self.t_last_pre_spike = jnp.where(spike, t, self.t_last_pre_spike) - TT = ((t - self.t_last_pre_spike) < self.T_duration) * self.T - self.s = self.int_s(self.s, t, TT, self.alpha, self.beta) - self.post.inputs -= jnp.sum(self.g_max * self.s, axis=0) * (self.post.V - self.E) - - -class HH(bp.dyn.NeuGroup): - def __init__(self, size, ENa=55., EK=-90., EL=-65, - C=1.0, gNa=35., gK=9., gL=0.1, V_th=20., - phi=5.0, **kwargs): - super(HH, self).__init__(size=size, **kwargs) - - # parameters - self.ENa = ENa - self.EK = EK - self.EL = EL - self.C = C - self.gNa = gNa - self.gK = gK - self.gL = gL - self.V_th = V_th - self.phi = phi - - # variables - self.V = jnp.ones(self.num) * -65. - self.h = jnp.ones(self.num) * 0.6 - self.n = jnp.ones(self.num) * 0.32 - self.spikes = jnp.zeros(self.num) - self.inputs = jnp.zeros(self.num) - - self.integral = bp.odeint(self.dev) - - def dev(self, V, h, n, t, Iext): - alpha = 0.07 * jnp.exp(-(V + 58) / 20) - beta = 1 / (jnp.exp(-0.1 * (V + 28)) + 1) - dhdt = alpha * (1 - h) - beta * h - - alpha = -0.01 * (V + 34) / (jnp.exp(-0.1 * (V + 34)) - 1) - beta = 0.125 * jnp.exp(-(V + 44) / 80) - dndt = alpha * (1 - n) - beta * n - - m_alpha = -0.1 * (V + 35) / (jnp.exp(-0.1 * (V + 35)) - 1) - m_beta = 4 * jnp.exp(-(V + 60) / 18) - m = m_alpha / (m_alpha + m_beta) - INa = self.gNa * m ** 3 * h * (V - self.ENa) - IK = self.gK * n ** 4 * (V - self.EK) - IL = self.gL * (V - self.EL) - dVdt = (- INa - IK - IL + Iext) / self.C - - return dVdt, self.phi * dhdt, self.phi * dndt - - def update(self, t, _i, **kwargs): - V, h, n = self.integral(self.V, self.h, self.n, t, self.inputs) - self.spikes[:] = (self.V < self.V_th) * (V >= self.V_th) - self.V[:] = V - self.h[:] = h - self.n[:] = n - self.inputs[:] = 0 - - -def test1(): - bp.math.random.seed(123) - num = 10 - neu = HH(num) - neu.V = -70. 
+ bp.math.random.normal(size=num) * 20 - - syn = GABAa(pre=neu, post=neu, conn=bp.connect.All2All(include_self=False)) - syn.g_max = 0.1 / num - - net = bp.Network(neu=neu, syn=syn) - - for method in ['relative', 'absolute']: - print(f'Method: {method}\n') - print('vars:') - print('-----') - print('neu.vars()', list(neu.vars(method).keys())) - print('syn.vars()', list(syn.vars(method).keys())) - print('net.vars()', list(net.vars(method).keys())) - print() - - print('nodes:') - print('------') - print('neu.nodes()', list(neu.nodes(method).keys())) - print('syn.nodes()', list(syn.nodes(method).keys())) - print('net.nodes()', list(net.nodes(method).keys())) - print() diff --git a/brainpy/_src/dyn/tests/test_base_classes.py b/brainpy/_src/dyn/tests/test_base_classes.py deleted file mode 100644 index 9c095a30e..000000000 --- a/brainpy/_src/dyn/tests/test_base_classes.py +++ /dev/null @@ -1,20 +0,0 @@ -# -*- coding: utf-8 -*- - -import unittest - -import brainpy as bp - - -class TestDynamicalSystem(unittest.TestCase): - def test_delay(self): - A = bp.neurons.LIF(1) - B = bp.neurons.LIF(1) - C = bp.neurons.LIF(1) - A2B = bp.synapses.Exponential(A, B, bp.conn.All2All(), delay_step=1) - A2C = bp.synapses.Exponential(A, C, bp.conn.All2All(), delay_step=None) - net = bp.Network(A, B, C, A2B, A2C) - - runner = bp.DSRunner(net,) - runner.run(10.) - - diff --git a/brainpy/_src/dyn/tests/test_dyn_runner.py b/brainpy/_src/dyn/tests/test_dyn_runner.py deleted file mode 100644 index e311a664e..000000000 --- a/brainpy/_src/dyn/tests/test_dyn_runner.py +++ /dev/null @@ -1,133 +0,0 @@ -# -*- coding: utf-8 -*- - - -import unittest -import brainpy as bp -import brainpy.math as bm - - -class TestDSRunner(unittest.TestCase): - def test1(self): - class ExampleDS(bp.DynamicalSystem): - def __init__(self): - super(ExampleDS, self).__init__() - self.i = bm.Variable(bm.zeros(1)) - - def update(self, tdi): - self.i += 1 - - ds = ExampleDS() - runner = bp.DSRunner(ds, dt=1., monitors=['i'], progress_bar=False) - runner.run(100.) - - def test_t_and_dt(self): - class ExampleDS(bp.DynamicalSystem): - def __init__(self): - super(ExampleDS, self).__init__() - self.i = bm.Variable(bm.zeros(1)) - - def update(self, tdi): - self.i += 1 * tdi.dt - - runner = bp.DSRunner(ExampleDS(), dt=1., monitors=['i'], progress_bar=False) - runner.run(100.) - - def test_DSView(self): - class EINet(bp.Network): - def __init__(self, scale=1.0, method='exp_auto'): - super(EINet, self).__init__() - - # network size - num_exc = int(800 * scale) - num_inh = int(200 * scale) - - # neurons - pars = dict(V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5.) - self.E = bp.neurons.LIF(num_exc, **pars, method=method) - self.I = bp.neurons.LIF(num_inh, **pars, method=method) - self.E.V[:] = bm.random.randn(num_exc) * 2 - 55. - self.I.V[:] = bm.random.randn(num_inh) * 2 - 55. 
- - # synapses - we = 0.6 / scale # excitatory synaptic weight (voltage) - wi = 6.7 / scale # inhibitory synaptic weight - self.E2E = bp.synapses.Exponential(self.E, self.E[:100], bp.conn.FixedProb(0.02), - output=bp.synouts.COBA(E=0.), g_max=we, - tau=5., method=method) - self.E2I = bp.synapses.Exponential(self.E, self.I[:100], bp.conn.FixedProb(0.02), - output=bp.synouts.COBA(E=0.), g_max=we, - tau=5., method=method) - self.I2E = bp.synapses.Exponential(self.I, self.E[:100], bp.conn.FixedProb(0.02), - output=bp.synouts.COBA(E=-80.), g_max=wi, - tau=10., method=method) - self.I2I = bp.synapses.Exponential(self.I, self.I[:100], bp.conn.FixedProb(0.02), - output=bp.synouts.COBA(E=-80.), g_max=wi, - tau=10., method=method) - - net = EINet(scale=1., method='exp_auto') - # with JIT - runner = bp.DSRunner(net, monitors={'E.spike': net.E.spike}, - inputs=[(net.E.input, 20.), (net.I.input, 20.)]).run(1.) - - # without JIT - runner = bp.DSRunner(net, monitors={'E.spike': net.E.spike}, - inputs=[(net.E.input, 20.), (net.I.input, 20.)], - jit=False).run(0.2) - - - -class TestMemoryEfficient(unittest.TestCase): - pass - - - - - - -# class TestMonitor(TestCase): -# def test_1d_array(self): -# try1 = TryGroup(monitors=['a']) -# try1.a = np.ones(1) -# try1.run(100.) -# -# assert np.ndim(try1.mon.a) == 2 and np.shape(try1.mon.a)[1] == 1 -# assert np.allclose(np.arange(2, 1002).reshape((-1, 1)), try1.mon.a) -# -# def test_2d_array(): -# set(dt=0.1) -# try1 = TryGroup(monitors=['a']) -# try1.a = np.ones((2, 2)) -# try1.run(100.) -# -# assert np.ndim(try1.mon.a) == 2 and np.shape(try1.mon.a)[1] == 4 -# series = np.arange(2, 1002).reshape((-1, 1)) -# series = np.repeat(series, 4, axis=1) -# assert np.allclose(series, try1.mon.a) -# -# def test_monitor_with_every(): -# set(dt=0.1) -# -# # try1: 2d array -# try1 = TryGroup(monitors=Monitor(variables=['a'], every=[1.])) -# try1.run(100.) -# assert np.ndim(try1.mon.a) == 2 and np.shape(try1.mon.a)[1] == 4 -# series = np.arange(2, 1002, 1. / 0.1).reshape((-1, 1)) -# series = np.repeat(series, 4, axis=1) -# assert np.allclose(series, try1.mon.a) -# -# # try2: 1d array -# try2 = TryGroup(monitors=Monitor(variables=['a'], every=[1.])) -# try2.a = np.array([1., 1.]) -# try2.run(100.) -# assert np.ndim(try2.mon.a) == 2 and np.shape(try2.mon.a)[1] == 2 -# series = np.arange(2, 1002, 1. / 0.1).reshape((-1, 1)) -# series = np.repeat(series, 2, axis=1) -# assert np.allclose(series, try2.mon.a) -# -# # try2: scalar -# try3 = TryGroup(monitors=Monitor(variables=['a'], every=[1.])) -# try3.a = 1. -# try3.run(100.) -# assert np.ndim(try3.mon.a) == 2 and np.shape(try3.mon.a)[1] == 1 -# series = np.arange(2, 1002, 1. 
/ 0.1).reshape((-1, 1)) -# assert np.allclose(series, try3.mon.a) diff --git a/brainpy/_src/dyn/tests/test_network.py b/brainpy/_src/dyn/tests/test_network.py deleted file mode 100644 index 3c3afe310..000000000 --- a/brainpy/_src/dyn/tests/test_network.py +++ /dev/null @@ -1,51 +0,0 @@ -import brainpy as bp -import unittest - - -class TestNetDefinition(unittest.TestCase): - def test_define_net1(self): - E = bp.neurons.LIF(3200, V_rest=-60., V_th=-50., V_reset=-60., - tau=20., tau_ref=5., method='exp_auto', - V_initializer=bp.init.Normal(-60., 2.)) - - I = bp.neurons.LIF(800, V_rest=-60., V_th=-50., V_reset=-60., - tau=20., tau_ref=5., method='exp_auto', - V_initializer=bp.init.Normal(-60., 2.)) - - E2E = bp.synapses.Exponential(E, E, bp.conn.FixedProb(prob=0.02), g_max=0.6, - tau=5., output=bp.synouts.COBA(E=0.), - method='exp_auto') - - E2I = bp.synapses.Exponential(E, I, bp.conn.FixedProb(prob=0.02), g_max=0.6, - tau=5., output=bp.synouts.COBA(E=0.), - method='exp_auto') - - I2E = bp.synapses.Exponential(I, E, bp.conn.FixedProb(prob=0.02), g_max=6.7, - tau=10., output=bp.synouts.COBA(E=-80.), - method='exp_auto') - - I2I = bp.synapses.Exponential(I, I, bp.conn.FixedProb(prob=0.02), g_max=6.7, - tau=10., output=bp.synouts.COBA(E=-80.), - method='exp_auto') - - net = bp.Network(E2E, E2I, I2E, I2I, E=E, I=I) - - runner1 = bp.DSRunner(net, - monitors=['E.spike', 'I.spike'], - inputs=[('E.input', 20.), ('I.input', 20.)]) - - runner2 = bp.DSRunner(net, - monitors=[('E.spike', E.spike), ('I.spike', I.spike)], - inputs=[(E.input, 20.), (I.input, 20.)]) - - runner3 = bp.DSRunner(net, - monitors=[('E.spike', E.spike), 'I.spike'], - inputs=[(E.input, 20.), (I.input, 20.)]) - - runner4 = bp.DSRunner(net, - monitors={'E.spike': E.spike, 'I.spike': I.spike}, - inputs=[(E.input, 20.), (I.input, 20.)]) - - bp.math.clear_buffer_memory() - - diff --git a/brainpy/_src/dyn/tests/test_pickle.py b/brainpy/_src/dyn/tests/test_pickle.py deleted file mode 100644 index 2ae6a1345..000000000 --- a/brainpy/_src/dyn/tests/test_pickle.py +++ /dev/null @@ -1,22 +0,0 @@ -# -*- coding: utf-8 -*- - -import brainpy as bp - -import unittest - -import pickle - - -class TestPickle(unittest.TestCase): - def __init__(self, *args, **kwargs): - super(TestPickle, self).__init__(*args, **kwargs) - - self.pre = bp.neurons.LIF(10) - self.post = bp.neurons.LIF(20) - self.syn = bp.TwoEndConn(self.pre, self.post, bp.conn.FixedProb(0.2)) - self.net = bp.Network(self.pre, self.post, self.syn) - - def test_net(self): - self.skipTest('Currently do not support') - with open('data/net.pickle', 'wb') as f: - pickle.dump(self.net, f) diff --git a/brainpy/_src/dyn/tests/test_slice_view.py b/brainpy/_src/dyn/tests/test_slice_view.py deleted file mode 100644 index a952528fb..000000000 --- a/brainpy/_src/dyn/tests/test_slice_view.py +++ /dev/null @@ -1,51 +0,0 @@ -# -*- coding: utf-8 -*- - - -import brainpy as bp -import brainpy.math as bm -import unittest - - -class TestSliceView(unittest.TestCase): - def test_lif(self): - lif = bp.neurons.LIF(10) - lif_tile = lif[5:] - print(lif_tile.V.shape) - print(lif_tile.varshape) - - print('Before modification: ') - print(lif.V) - lif_tile.V += 10. - - self.assertTrue(bm.allclose(lif.V, bm.concatenate([bm.zeros(5), bm.ones(5) * 10.]))) - print('After modification 1: ') - print(lif.V) - - lif_tile.V.value = bm.ones(5) * 40. 
- self.assertTrue(bm.allclose(lif.V, bm.concatenate([bm.zeros(5), bm.ones(5) * 40.]))) - print('After modification 2: ') - print(lif.V) - - def test_lif_train_mode(self): - lif = bp.neurons.LIF(10, mode=bm.training_mode) - lif_tile = lif[5:] - print(lif_tile.V.shape) - print(lif_tile.varshape) - - print('Before modification: ') - print(lif.V) - lif_tile.V += 10. - - self.assertTrue(bm.allclose(lif.V, bm.hstack([bm.zeros((1, 5)), bm.ones((1, 5)) * 10.]))) - print('After modification 1: ') - print(lif.V) - - lif_tile.V.value = bm.ones((1, 5)) * 40. - self.assertTrue(bm.allclose(lif.V, bm.hstack([bm.zeros((1, 5)), bm.ones((1, 5)) * 40.]))) - print('After modification 2: ') - print(lif.V) - - - - - diff --git a/brainpy/_src/dyn/synapses/__init__.py b/brainpy/_src/synapses/__init__.py similarity index 100% rename from brainpy/_src/dyn/synapses/__init__.py rename to brainpy/_src/synapses/__init__.py diff --git a/brainpy/_src/dyn/synapses/abstract_models.py b/brainpy/_src/synapses/abstract_models.py similarity index 99% rename from brainpy/_src/dyn/synapses/abstract_models.py rename to brainpy/_src/synapses/abstract_models.py index 63dd89236..26e922534 100644 --- a/brainpy/_src/dyn/synapses/abstract_models.py +++ b/brainpy/_src/synapses/abstract_models.py @@ -7,7 +7,7 @@ import brainpy.math as bm from brainpy._src.connect import TwoEndConnector, All2All, One2One -from brainpy._src.dyn.synouts import CUBA, MgBlock +from brainpy._src.synouts import CUBA, MgBlock from brainpy._src.dynsys import NeuGroup, SynOut, SynSTP, TwoEndConn, SynConn from brainpy._src.initialize import Initializer, variable_ from brainpy._src.integrators import odeint, JointEq diff --git a/brainpy/_src/dyn/synapses/biological_models.py b/brainpy/_src/synapses/biological_models.py similarity index 99% rename from brainpy/_src/dyn/synapses/biological_models.py rename to brainpy/_src/synapses/biological_models.py index 4078871a4..c4b126c68 100644 --- a/brainpy/_src/dyn/synapses/biological_models.py +++ b/brainpy/_src/synapses/biological_models.py @@ -7,7 +7,7 @@ import brainpy.math as bm from brainpy._src.dynsys import NeuGroup, TwoEndConn, SynSTP, SynOut -from brainpy._src.dyn.synouts import COBA, MgBlock +from brainpy._src.synouts import COBA, MgBlock from brainpy._src.initialize import Initializer, variable from brainpy._src.integrators import odeint, JointEq from brainpy._src.connect import TwoEndConnector, All2All, One2One diff --git a/brainpy/_src/dyn/synapses/compat.py b/brainpy/_src/synapses/compat.py similarity index 99% rename from brainpy/_src/dyn/synapses/compat.py rename to brainpy/_src/synapses/compat.py index eef7a2108..40b66b5c7 100644 --- a/brainpy/_src/dyn/synapses/compat.py +++ b/brainpy/_src/synapses/compat.py @@ -6,7 +6,7 @@ import brainpy._src.math as bm from brainpy._src.connect import TwoEndConnector from brainpy._src.dynsys import NeuGroup, SynSTP -from brainpy._src.dyn.synouts import COBA, CUBA, MgBlock +from brainpy._src.synouts import COBA, CUBA, MgBlock from brainpy._src.initialize import Initializer from brainpy.types import ArrayType from .abstract_models import Delta, Exponential, DualExponential, NMDA as NewNMDA diff --git a/brainpy/_src/dyn/synapses/delay_couplings.py b/brainpy/_src/synapses/delay_couplings.py similarity index 99% rename from brainpy/_src/dyn/synapses/delay_couplings.py rename to brainpy/_src/synapses/delay_couplings.py index 07794bc2c..c1fd8513b 100644 --- a/brainpy/_src/dyn/synapses/delay_couplings.py +++ b/brainpy/_src/synapses/delay_couplings.py @@ -7,7 +7,7 @@ import 
brainpy.math as bm from brainpy._src.dynsys import SynConn -from brainpy._src.dyn.neurons.input_groups import InputGroup, OutputGroup +from brainpy._src.neurons.input_groups import InputGroup, OutputGroup from brainpy._src.initialize import Initializer from brainpy.check import is_sequence from brainpy.types import ArrayType diff --git a/brainpy/_src/dyn/synapses/gap_junction.py b/brainpy/_src/synapses/gap_junction.py similarity index 100% rename from brainpy/_src/dyn/synapses/gap_junction.py rename to brainpy/_src/synapses/gap_junction.py diff --git a/brainpy/_src/dyn/synapses/learning_rules.py b/brainpy/_src/synapses/learning_rules.py similarity index 100% rename from brainpy/_src/dyn/synapses/learning_rules.py rename to brainpy/_src/synapses/learning_rules.py diff --git a/brainpy/_src/dyn/synouts/__init__.py b/brainpy/_src/synouts/__init__.py similarity index 100% rename from brainpy/_src/dyn/synouts/__init__.py rename to brainpy/_src/synouts/__init__.py diff --git a/brainpy/_src/dyn/synouts/conductances.py b/brainpy/_src/synouts/conductances.py similarity index 100% rename from brainpy/_src/dyn/synouts/conductances.py rename to brainpy/_src/synouts/conductances.py diff --git a/brainpy/_src/dyn/synouts/ions.py b/brainpy/_src/synouts/ions.py similarity index 100% rename from brainpy/_src/dyn/synouts/ions.py rename to brainpy/_src/synouts/ions.py diff --git a/brainpy/_src/dyn/synplast/__init__.py b/brainpy/_src/synplast/__init__.py similarity index 100% rename from brainpy/_src/dyn/synplast/__init__.py rename to brainpy/_src/synplast/__init__.py diff --git a/brainpy/_src/dyn/synplast/long_term_plasticity.py b/brainpy/_src/synplast/long_term_plasticity.py similarity index 100% rename from brainpy/_src/dyn/synplast/long_term_plasticity.py rename to brainpy/_src/synplast/long_term_plasticity.py diff --git a/brainpy/_src/dyn/synplast/short_term_plasticity.py b/brainpy/_src/synplast/short_term_plasticity.py similarity index 100% rename from brainpy/_src/dyn/synplast/short_term_plasticity.py rename to brainpy/_src/synplast/short_term_plasticity.py diff --git a/brainpy/_src/test_check.py b/brainpy/_src/test_check.py deleted file mode 100644 index a04105486..000000000 --- a/brainpy/_src/test_check.py +++ /dev/null @@ -1,51 +0,0 @@ -# -*- coding: utf-8 -*- - - -import unittest - -from brainpy import check as checking - - -class TestUtils(unittest.TestCase): - def test_check_shape(self): - all_shapes = [ - (1, 2, 3), - (1, 4), - (10, 2, 4) - ] - free_shape, fixed_shapes = checking.check_shape(all_shapes, free_axes=-1) - self.assertEqual(free_shape, [3, 4, 4]) - self.assertEqual(fixed_shapes, [10, 2]) - - def test_check_shape2(self): - all_shapes = [ - (1, 2, 3, 8,), - (10, 1, 4, 10), - (10, 2, 4, 100) - ] - free_shape, fixed_shapes = checking.check_shape(all_shapes, free_axes=[2, -1]) - print(free_shape) - print(fixed_shapes) - self.assertEqual(free_shape, [[3, 8], [4, 10], [4, 100]]) - self.assertEqual(fixed_shapes, [10, 2]) - - def test_check_shape3(self): - all_shapes = [ - (1, 2, 3, 8,), - (10, 1, 4, 10), - (10, 2, 4, 100) - ] - free_shape, fixed_shapes = checking.check_shape(all_shapes, free_axes=[0, 2, -1]) - print(free_shape) - print(fixed_shapes) - self.assertEqual(free_shape, [[1, 3, 8], [10, 4, 10], [10, 4, 100]]) - self.assertEqual(fixed_shapes, [2]) - - def test_check_shape4(self): - all_shapes = [ - (1, 2, 3, 8,), - (10, 1, 4, 10), - (10, 2, 4, 100) - ] - with self.assertRaises(ValueError): - free_shape, fixed_shapes = checking.check_shape(all_shapes, free_axes=[0, -1]) diff --git 
a/brainpy/synapses/dynamics.py b/brainpy/synapses/dynamics.py index 77e339982..59a8d41b5 100644 --- a/brainpy/synapses/dynamics.py +++ b/brainpy/synapses/dynamics.py @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- -from brainpy._src.dyn.synapses.abstract_models import ( +from brainpy._src.synapses.abstract_models import ( Delta as Delta, Exponential as Exponential, DualExponential as DualExponential, @@ -8,17 +8,17 @@ NMDA as NMDA, PoissonInput as PoissonInput, ) -from brainpy._src.dyn.synapses.biological_models import ( +from brainpy._src.synapses.biological_models import ( AMPA as AMPA, GABAa as GABAa, BioNMDA as BioNMDA, ) -from brainpy._src.dyn.synapses.delay_couplings import ( +from brainpy._src.synapses.delay_couplings import ( DelayCoupling as DelayCoupling, DiffusiveCoupling as DiffusiveCoupling, AdditiveCoupling as AdditiveCoupling, ) -from brainpy._src.dyn.synapses.gap_junction import ( +from brainpy._src.synapses.gap_junction import ( GapJunction as GapJunction, ) diff --git a/brainpy/synapses/synouts.py b/brainpy/synapses/synouts.py index 5f66035b2..c8be34142 100644 --- a/brainpy/synapses/synouts.py +++ b/brainpy/synapses/synouts.py @@ -1,10 +1,10 @@ # -*- coding: utf-8 -*- -from brainpy._src.dyn.synouts.conductances import ( +from brainpy._src.synouts.conductances import ( COBA as COBA, CUBA as CUBA, ) -from brainpy._src.dyn.synouts.ions import ( +from brainpy._src.synouts.ions import ( MgBlock as MgBlock, ) diff --git a/brainpy/synapses/synplast.py b/brainpy/synapses/synplast.py index fc32a4286..fed0ab8b3 100644 --- a/brainpy/synapses/synplast.py +++ b/brainpy/synapses/synplast.py @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- -from brainpy._src.dyn.synplast.short_term_plasticity import ( +from brainpy._src.synplast.short_term_plasticity import ( STD as STD, STP as STP, ) From caed1d4744b8cfc5062acc6e2505c40e09142a51 Mon Sep 17 00:00:00 2001 From: chaoming Date: Sun, 4 Jun 2023 10:52:49 +0800 Subject: [PATCH 03/14] [random] unify the usage of random --- brainpy/_src/encoding/stateless_encoding.py | 6 ++---- brainpy/_src/inputs/currents.py | 4 ++-- brainpy/_src/integrators/sde/base.py | 2 +- brainpy/_src/layers/dropout.py | 3 +-- brainpy/_src/layers/reservoir.py | 9 ++++----- brainpy/_src/neurons/input_groups.py | 7 +------ brainpy/_src/synapses/abstract_models.py | 10 ++++------ brainpy/_src/synapses_v2/abstract_synapses.py | 2 +- brainpy/_src/synapses_v2/others.py | 10 ++++------ 9 files changed, 20 insertions(+), 33 deletions(-) diff --git a/brainpy/_src/encoding/stateless_encoding.py b/brainpy/_src/encoding/stateless_encoding.py index c32ecdd52..700a6330c 100644 --- a/brainpy/_src/encoding/stateless_encoding.py +++ b/brainpy/_src/encoding/stateless_encoding.py @@ -33,13 +33,11 @@ class PoissonEncoder(Encoder): def __init__(self, min_val: Optional[float] = None, - max_val: Optional[float] = None, - seed: Union[int, ArrayType] = None): + max_val: Optional[float] = None): super().__init__() self.min_val = check.is_float(min_val, 'min_val', allow_none=True) self.max_val = check.is_float(max_val, 'max_val', allow_none=True) - self.rng = bm.random.default_rng(seed) def __call__(self, x: ArrayType, num_step: int = None): """ @@ -66,5 +64,5 @@ def __call__(self, x: ArrayType, num_step: int = None): if not (self.min_val is None or self.max_val is None): x = (x - self.min_val) / (self.max_val - self.min_val) shape = x.shape if (num_step is None) else ((num_step,) + x.shape) - d = bm.as_jax(self.rng.rand(*shape)) < x + d = bm.as_jax(bm.random.rand(*shape)) < x return d.astype(x.dtype) diff --git 
a/brainpy/_src/inputs/currents.py b/brainpy/_src/inputs/currents.py index e91149572..c63e4d760 100644 --- a/brainpy/_src/inputs/currents.py +++ b/brainpy/_src/inputs/currents.py @@ -260,7 +260,7 @@ def wiener_process(duration, dt=None, n=1, t_start=0., t_end=None, seed=None): dt = bm.get_dt() if dt is None else dt is_float(dt, 'dt', allow_none=False, min_bound=0.) is_integer(n, 'n', allow_none=False, min_bound=0) - rng = bm.random.default_rng(seed) + rng = bm.random.default_rng(seed, clone=False) t_end = duration if t_end is None else t_end i_start = int(t_start / dt) i_end = int(t_end / dt) @@ -302,7 +302,7 @@ def ou_process(mean, sigma, tau, duration, dt=None, n=1, t_start=0., t_end=None, dt_sqrt = jnp.sqrt(dt) is_float(dt, 'dt', allow_none=False, min_bound=0.) is_integer(n, 'n', allow_none=False, min_bound=0) - rng = bm.random.default_rng(seed) + rng = bm.random.default_rng(seed, clone=False) x = bm.Variable(jnp.ones(n) * mean) def _f(t): diff --git a/brainpy/_src/integrators/sde/base.py b/brainpy/_src/integrators/sde/base.py index d624dcfb7..504e70073 100644 --- a/brainpy/_src/integrators/sde/base.py +++ b/brainpy/_src/integrators/sde/base.py @@ -75,7 +75,7 @@ def __init__( self.wiener_type = wiener_type # wiener process type # random seed - self.rng = bm.random.default_rng() + self.rng = bm.random.default_rng(clone=False) # code scope self.code_scope = {constants.F: f, constants.G: g, 'math': jnp, 'random': self.rng} diff --git a/brainpy/_src/layers/dropout.py b/brainpy/_src/layers/dropout.py index 051b0038c..9be212e18 100644 --- a/brainpy/_src/layers/dropout.py +++ b/brainpy/_src/layers/dropout.py @@ -46,11 +46,10 @@ def __init__( ): super(Dropout, self).__init__(mode=mode, name=name) self.prob = check.is_float(prob, min_bound=0., max_bound=1.) - self.rng = bm.random.default_rng(seed) def update(self, x): if share.load('fit'): - keep_mask = self.rng.bernoulli(self.prob, x.shape) + keep_mask = bm.random.bernoulli(self.prob, x.shape) return bm.where(keep_mask, x / self.prob, 0.) else: return x diff --git a/brainpy/_src/layers/reservoir.py b/brainpy/_src/layers/reservoir.py index cc11fc053..657a13b13 100644 --- a/brainpy/_src/layers/reservoir.py +++ b/brainpy/_src/layers/reservoir.py @@ -127,7 +127,6 @@ def __init__( check.is_callable(self.activation, allow_none=False) self.activation_type = activation_type check.is_string(activation_type, 'activation_type', ['internal', 'external']) - self.rng = bm.random.default_rng(seed) check.is_float(spectral_radius, 'spectral_radius', allow_none=True) self.spectral_radius = spectral_radius @@ -160,7 +159,7 @@ def __init__( self.Wff_shape = weight_shape self.Win = parameter(self._Win_initializer, weight_shape) if self.ff_connectivity < 1.: - conn_mat = self.rng.random(weight_shape) > self.ff_connectivity + conn_mat = bm.random.random(weight_shape) > self.ff_connectivity self.Win[conn_mat] = 0. if self.comp_type == 'sparse' and self.ff_connectivity < 1.: self.ff_pres, self.ff_posts = jnp.where(jnp.logical_not(bm.as_jax(conn_mat))) @@ -172,7 +171,7 @@ def __init__( recurrent_shape = (self.num_unit, self.num_unit) self.Wrec = parameter(self._Wrec_initializer, recurrent_shape) if self.rec_connectivity < 1.: - conn_mat = self.rng.random(recurrent_shape) > self.rec_connectivity + conn_mat = bm.random.random(recurrent_shape) > self.rec_connectivity self.Wrec[conn_mat] = 0. 
if self.spectral_radius is not None: current_sr = max(abs(jnp.linalg.eig(bm.as_jax(self.Wrec))[0])) @@ -196,7 +195,7 @@ def update(self, x): # inputs x = bm.as_jax(x) if self.noise_ff > 0: - x += self.noise_ff * self.rng.uniform(-1, 1, x.shape) + x += self.noise_ff * bm.random.uniform(-1, 1, x.shape) if self.comp_type == 'sparse' and self.ff_connectivity < 1.: sparse = {'data': self.Win, 'index': (self.ff_pres, self.ff_posts), @@ -215,7 +214,7 @@ def update(self, x): if self.activation_type == 'internal': hidden = self.activation(hidden) if self.noise_rec > 0.: - hidden += self.noise_rec * self.rng.uniform(-1, -1, self.state.shape) + hidden += self.noise_rec * bm.random.uniform(-1, -1, self.state.shape) # new state/output state = (1 - self.leaky_rate) * self.state + self.leaky_rate * hidden if self.activation_type == 'external': diff --git a/brainpy/_src/neurons/input_groups.py b/brainpy/_src/neurons/input_groups.py index eaae570bd..b6240c1bc 100644 --- a/brainpy/_src/neurons/input_groups.py +++ b/brainpy/_src/neurons/input_groups.py @@ -188,18 +188,13 @@ def __init__( self.freqs = parameter(freqs, self.num, allow_none=False) # variables - self.rng = bm.random.default_rng(seed) self.reset_state(self.mode) def update(self): - spikes = self.rng.rand_like(self.spike) <= (self.freqs * share.dt / 1000.) + spikes = bm.random.rand_like(self.spike) <= (self.freqs * share.dt / 1000.) self.spike.value = spikes return spikes - def reset(self, batch_size=None): - self.rng.value = bm.random.default_rng(self.seed) - self.reset_state(batch_size) - def reset_state(self, batch_size=None): self.spike = variable_(lambda s: jnp.zeros(s, dtype=bool), self.varshape, batch_size) diff --git a/brainpy/_src/synapses/abstract_models.py b/brainpy/_src/synapses/abstract_models.py index 26e922534..e7b5bda1a 100644 --- a/brainpy/_src/synapses/abstract_models.py +++ b/brainpy/_src/synapses/abstract_models.py @@ -960,7 +960,6 @@ def __init__( self.freq = freq self.weight = weight self.seed = seed - self.rng = bm.random.default_rng(seed) def update(self, tdi): p = self.freq * tdi.dt / 1e3 @@ -968,14 +967,14 @@ def update(self, tdi): b = self.num_input * (1 - p) if isinstance(tdi.dt, (int, float)): # dt is not in tracing if (a > 5) and (b > 5): - inp = self.rng.normal(a, b * p, self.target_var.shape) + inp = bm.random.normal(a, b * p, self.target_var.shape) else: - inp = self.rng.binomial(self.num_input, p, self.target_var.shape) + inp = bm.random.binomial(self.num_input, p, self.target_var.shape) else: # dt is in tracing inp = bm.cond((a > 5) * (b > 5), - lambda _: self.rng.normal(a, b * p, self.target_var.shape), - lambda _: self.rng.binomial(self.num_input, p, self.target_var.shape), + lambda _: bm.random.normal(a, b * p, self.target_var.shape), + lambda _: bm.random.binomial(self.num_input, p, self.target_var.shape), None) self.target_var += inp * self.weight @@ -987,5 +986,4 @@ def reset_state(self, batch_size=None): pass def reset(self, batch_size=None): - self.rng.seed(self.seed) self.reset_state(batch_size) diff --git a/brainpy/_src/synapses_v2/abstract_synapses.py b/brainpy/_src/synapses_v2/abstract_synapses.py index bc455502d..249e091e0 100644 --- a/brainpy/_src/synapses_v2/abstract_synapses.py +++ b/brainpy/_src/synapses_v2/abstract_synapses.py @@ -7,7 +7,7 @@ import brainpy.math as bm from brainpy._src.connect import TwoEndConnector, All2All, One2One from brainpy._src.context import share -from brainpy._src.dyn.synapses_v2.base import SynConnNS, SynOutNS, SynSTPNS +from brainpy._src.synapses_v2.base import 
SynConnNS, SynOutNS, SynSTPNS from brainpy._src.initialize import Initializer, variable_ from brainpy._src.integrators import odeint, JointEq from brainpy.check import is_float diff --git a/brainpy/_src/synapses_v2/others.py b/brainpy/_src/synapses_v2/others.py index 463ed255d..0dfb2b105 100644 --- a/brainpy/_src/synapses_v2/others.py +++ b/brainpy/_src/synapses_v2/others.py @@ -52,7 +52,6 @@ def __init__( self.freq = freq self.weight = weight self.seed = seed - self.rng = bm.random.default_rng(seed) def update(self): p = self.freq * share.dt / 1e3 @@ -60,14 +59,14 @@ def update(self): b = self.num_input * (1 - p) if isinstance(share.dt, (int, float)): # dt is not in tracing if (a > 5) and (b > 5): - inp = self.rng.normal(a, b * p, self.target_shape) + inp = bm.random.normal(a, b * p, self.target_shape) else: - inp = self.rng.binomial(self.num_input, p, self.target_shape) + inp = bm.random.binomial(self.num_input, p, self.target_shape) else: # dt is in tracing inp = bm.cond((a > 5) * (b > 5), - lambda _: self.rng.normal(a, b * p, self.target_shape), - lambda _: self.rng.binomial(self.num_input, p, self.target_shape), + lambda _: bm.random.normal(a, b * p, self.target_shape), + lambda _: bm.random.binomial(self.num_input, p, self.target_shape), None) return inp * self.weight @@ -79,7 +78,6 @@ def reset_state(self, batch_size=None): pass def reset(self, batch_size=None): - self.rng.seed(self.seed) self.reset_state(batch_size) From 6c3dd3e8a0dfc7d8db187d78f5141de7668b3f1a Mon Sep 17 00:00:00 2001 From: chaoming Date: Mon, 5 Jun 2023 10:51:58 +0800 Subject: [PATCH 04/14] [fix] --- brainpy/_src/layers/reservoir.py | 4 ++-- brainpy/_src/math/sparse/__init__.py | 2 ++ brainpy/math/sparse.py | 6 ++++-- 3 files changed, 8 insertions(+), 4 deletions(-) diff --git a/brainpy/_src/layers/reservoir.py b/brainpy/_src/layers/reservoir.py index 657a13b13..6cab48a29 100644 --- a/brainpy/_src/layers/reservoir.py +++ b/brainpy/_src/layers/reservoir.py @@ -200,7 +200,7 @@ def update(self, x): sparse = {'data': self.Win, 'index': (self.ff_pres, self.ff_posts), 'shape': self.Wff_shape} - hidden = bm.sparse_matmul(x, sparse) + hidden = bm.sparse.seg_matmul(x, sparse) else: hidden = x @ self.Win # recurrent @@ -208,7 +208,7 @@ def update(self, x): sparse = {'data': self.Wrec, 'index': (self.rec_pres, self.rec_posts), 'shape': (self.num_unit, self.num_unit)} - hidden += bm.sparse_matmul(self.state, sparse) + hidden += bm.sparse.seg_matmul(self.state, sparse) else: hidden += self.state @ self.Wrec if self.activation_type == 'internal': diff --git a/brainpy/_src/math/sparse/__init__.py b/brainpy/_src/math/sparse/__init__.py index 572ed6fb7..d45f2c80b 100644 --- a/brainpy/_src/math/sparse/__init__.py +++ b/brainpy/_src/math/sparse/__init__.py @@ -4,4 +4,6 @@ from ._utils import * from ._bsr_mv import * from ._bsr_mm import * +from ._jax_prim import * + diff --git a/brainpy/math/sparse.py b/brainpy/math/sparse.py index 1610fe13a..1380a9e9c 100644 --- a/brainpy/math/sparse.py +++ b/brainpy/math/sparse.py @@ -1,6 +1,8 @@ from brainpy._src.math.sparse import ( - csrmv as csrmv, - coomv as coomv, + csrmv, + coomv, + + seg_matmul, csr_to_dense as csr_to_dense, csr_to_coo as csr_to_coo, From 5b7f22e88453e67423ccc9e2beb4a5be863b9f6a Mon Sep 17 00:00:00 2001 From: chaoming Date: Mon, 5 Jun 2023 10:52:17 +0800 Subject: [PATCH 05/14] [fix] sparse --- brainpy/_src/math/sparse/_jax_prim.py | 166 ++++++++++++++++++++++++++ 1 file changed, 166 insertions(+) create mode 100644 brainpy/_src/math/sparse/_jax_prim.py diff --git 
a/brainpy/_src/math/sparse/_jax_prim.py b/brainpy/_src/math/sparse/_jax_prim.py new file mode 100644 index 000000000..d60c3ef4c --- /dev/null +++ b/brainpy/_src/math/sparse/_jax_prim.py @@ -0,0 +1,166 @@ +from typing import Union, Dict + +import jax.numpy as jnp +from jax import ops + +from brainpy._src.math.interoperability import as_jax +from brainpy._src.math.ndarray import Array + + +__all__ = [ + 'seg_matmul', +] + + +def _matmul_with_left_sparse( + sparse: Dict, + dense: Union[Array, jnp.ndarray] +): + r"""Matrix multiplication with sparse matrix on the left. + + .. math:: + + Y = M_{\mathrm{sparse}} @ M_{\mathrm{dense}} + + Parameters + ---------- + sparse: dict + The sparse matrix with shape of :math:`(N, M)`. + dense: ArrayType + The dense matrix with the shape of :math:`(M, K)`. + + Returns + ------- + matrix + A tensor the the shape of :math:`(N, K)`. + """ + assert dense.ndim in [1, 2], 'Dense matrix must be a one- or two-dimensional matrix.' + values = sparse['data'] + rows, cols = sparse['index'] + shape = sparse['shape'] + if len(shape) != 2: + raise ValueError(f'Sparse matrix must be a two-dimensional matrix. But we got {shape}') + values = as_jax(values) + rows = as_jax(rows) + cols = as_jax(cols) + dense = as_jax(dense) + B = dense.take(cols, axis=0) + if B.ndim == 2: + prod = B * jnp.reshape(values, (-1, 1)) + else: + prod = B * values + return ops.segment_sum(prod, rows, shape[0]) + + +def _matmul_with_right_sparse( + dense: Union[Array, jnp.ndarray], + sparse: Dict +): + r"""Matrix multiplication with sparse matrix on the left. + + .. math:: + + Y = M_{\mathrm{dense}} @ M_{\mathrm{sparse}} + + Parameters + ---------- + dense: ArrayType + The dense matrix with the shape of :math:`(N, M)`. + sparse: dict + The sparse matrix with shape of :math:`(M, K)`. + + Returns + ------- + matrix + A tensor the the shape of :math:`(N, K)`. + """ + assert dense.ndim in [1, 2], 'Dense matrix must be a one- or two-dimensional matrix.' + values = sparse['data'] + rows, cols = sparse['index'] + shape = sparse['shape'] + if len(shape) != 2: + raise ValueError(f'Sparse matrix must be a two-dimensional matrix. But we got {shape}') + values = as_jax(values) + rows = as_jax(rows) + cols = as_jax(cols) + dense = as_jax(dense) + if dense.ndim == 2: + A = dense[:, rows] + prod = (A * values).T + res = ops.segment_sum(prod, cols, shape[1]).T + else: + prod = dense[rows] * values + res = ops.segment_sum(prod, cols, shape[1]) + return res + + +def seg_matmul(A, B): + r"""Sparse matrix multiplication. + + .. math:: + + y = A @ B + + where :math:`A` or :math:`B` is a sparse matrix. + :math:`A` and :math:`B` cannot be both sparse. + + Examples + -------- + + >>> import brainpy.math as bm + + 1. when the left matrix :math:`A` is a sparse matrix with the shape of :math:`(N, M)`, + + >>> # A is a sparse matrix (3, 4): + >>> # [[0, 2, 0, 4], + >>> # [1, 0, 0, 0], + >>> # [0, 3, 0, 2]] + >>> values = bm.asarray([2, 4, 1, 3, 2]) + >>> rows = bm.asarray([0, 0, 1, 2, 2]) + >>> cols = bm.asarray([1, 3, 0, 1, 3]) + >>> sparse = {'data': values, 'index': (rows, cols), 'shape': (3, 4)} + >>> B = bm.arange(4) + >>> bm.sparse.sparse_matmul(sparse, B) + ArrayType([14, 0, 9], dtype=int32) + >>> B = bm.random.rand(4, 3) + >>> bm.sparse.sparse_matmul(sparse, B) + ArrayType([[3.8331761 , 1.3708692 , 4.510223 ], + [0.9960836 , 0.37550318, 0.7370341 ], + [2.3700516 , 0.7574289 , 4.1124535 ]], dtype=float32) + + 2. 
when the right matrix :math:`B` is a sparse matrix with the shape of :math:`(M, K)`, + + >>> A = bm.arange(3) + >>> bm.sparse.seg_matmul(A, sparse) + ArrayType([1, 6, 0, 4], dtype=int32) + >>> A = bm.random.rand(2, 3) + >>> bm.sparse.seg_matmul(A, sparse) + ArrayType([[0.438388 , 1.4346815 , 0. , 2.361964 ], + [0.9171978 , 1.1214957 , 0. , 0.90534496]], dtype=float32) + + Parameters + ---------- + A: tensor, sequence + The dense or sparse matrix with the shape of :math:`(N, M)`. + B: tensor, sequence + The dense or sparse matrix with the shape of :math:`(M, K)`. + + Returns + ------- + results: ArrayType + The tensor with the shape of :math:`(N, K)`. + """ + if isinstance(A, dict): + if not isinstance(B, (Array, jnp.ndarray)): + raise ValueError('A and B cannot be both sparse. \n' + f'A:\n{A}\n' + f'B:\n{B}') + return _matmul_with_left_sparse(A, B) + else: + if not isinstance(B, dict): + raise ValueError('A and B cannot be both dense. \n' + f'A:\n{A}\n' + f'B:\n{B}') + return _matmul_with_right_sparse(A, B) + + From 4bfe7973b7669c91db97b5642c8aa01771abe01e Mon Sep 17 00:00:00 2001 From: chaoming Date: Sun, 11 Jun 2023 11:26:32 +0800 Subject: [PATCH 06/14] [math] jit connectivity methods has more data checking --- brainpy/_src/math/jitconn/_event_matvec.py | 24 +++++++++++++++++++++ brainpy/_src/math/jitconn/_matvec.py | 25 +++++++++++++++++++++- 2 files changed, 48 insertions(+), 1 deletion(-) diff --git a/brainpy/_src/math/jitconn/_event_matvec.py b/brainpy/_src/math/jitconn/_event_matvec.py index 6731bd66d..1af2a3aeb 100644 --- a/brainpy/_src/math/jitconn/_event_matvec.py +++ b/brainpy/_src/math/jitconn/_event_matvec.py @@ -10,6 +10,7 @@ from jax.interpreters import xla, ad from jax.lib import xla_client +from brainpy._src.math.ndarray import _get_dtype from brainpy._src.math.interoperability import as_jax from brainpy._src.math.jitconn._matvec import (mv_prob_homo_p, mv_prob_uniform_p, @@ -132,6 +133,11 @@ def event_mv_prob_normal( def _event_matvec_prob_homo_abstract( events, weight, clen, seed, *, shape, transpose, outdim_parallel ): + assert _get_dtype(events) in [jnp.bool_, jnp.float32, jnp.float64] + assert _get_dtype(weight) in [jnp.float32, jnp.float64], '"weight" must be float valued.' + assert _get_dtype(clen) in [jnp.int32, jnp.int64, jnp.uint32, jnp.uint64] + assert _get_dtype(seed) in [jnp.int32, jnp.int64, jnp.uint32, jnp.uint64] + if events.ndim != 1: raise ValueError('events should be a 1D vector.') if len(shape) != 2: @@ -317,6 +323,15 @@ def _event_matvec_prob_homo_transpose( def _event_matvec_prob_uniform_abstract( events, w_low, w_high, clen, seed, *, shape, transpose, outdim_parallel ): + assert _get_dtype(events) in [jnp.bool_, jnp.float32, jnp.float64] + _w_low_dtype = _get_dtype(w_low) + _w_high_dtype = _get_dtype(w_high) + assert _w_low_dtype == _w_high_dtype, '"w_low" and "w_high" must be same typed.' + assert _w_low_dtype in [jnp.float32, jnp.float64], '"w_low" must be float valued.' + assert _w_high_dtype in [jnp.float32, jnp.float64], '"w_high" must be float valued.' + assert _get_dtype(clen) in [jnp.int32, jnp.int64, jnp.uint32, jnp.uint64] + assert _get_dtype(seed) in [jnp.int32, jnp.int64, jnp.uint32, jnp.uint64] + if events.ndim != 1: raise ValueError('events should be a 1D vector.') if len(shape) != 2: @@ -493,6 +508,15 @@ def _event_matvec_prob_uniform_transpose( def _event_matvec_prob_normal_abstract( events, w_mu, w_sigma, clen, seed, *, shape, transpose, outdim_parallel ): + assert _get_dtype(events) in [jnp.bool_, jnp.float32, jnp.float64] + _w_mu_dtype = _get_dtype(w_mu) + _w_sigma_dtype = _get_dtype(w_sigma) + assert _w_mu_dtype == _w_sigma_dtype, '"w_mu" and "w_sigma" must be same typed.' + assert _w_mu_dtype in [jnp.float32, jnp.float64], '"w_mu" must be float valued.' + assert _w_sigma_dtype in [jnp.float32, jnp.float64], '"w_sigma" must be float valued.' + assert _get_dtype(clen) in [jnp.int32, jnp.int64, jnp.uint32, jnp.uint64] + assert _get_dtype(seed) in [jnp.int32, jnp.int64, jnp.uint32, jnp.uint64] + if w_mu.ndim != 1: raise ValueError('w_mu should be a 1D scalar.') if w_sigma.ndim != 1: diff --git a/brainpy/_src/math/jitconn/_matvec.py b/brainpy/_src/math/jitconn/_matvec.py index 489103202..e0ad0ba91 100644 --- a/brainpy/_src/math/jitconn/_matvec.py +++ b/brainpy/_src/math/jitconn/_matvec.py @@ -13,7 +13,7 @@ from jax.lib import xla_client from brainpy._src.math.interoperability import as_jax -from brainpy._src.math.ndarray import Array +from brainpy._src.math.ndarray import Array, _get_dtype from brainpy._src.math.op_registers import register_general_batching from brainpy.errors import GPUOperatorNotFound, MathError @@ -268,6 +268,11 @@ def mv_prob_normal( def _matvec_prob_homo_abstract( vector, weight, clen, seed, *, shape, transpose, outdim_parallel ): + assert _get_dtype(vector) in [jnp.float32, jnp.float64] + assert _get_dtype(weight) in [jnp.float32, jnp.float64], '"weight" must be float valued.' + assert _get_dtype(clen) in [jnp.int32, jnp.int64, jnp.uint32, jnp.uint64] + assert _get_dtype(seed) in [jnp.int32, jnp.int64, jnp.uint32, jnp.uint64] + if vector.ndim != 1: raise ValueError('vector should be a 1D vector.') if len(shape) != 2: @@ -451,6 +456,15 @@ def _matvec_prob_homo_transpose( def _matvec_prob_uniform_abstract( vector, w_low, w_high, clen, seed, *, shape, transpose, outdim_parallel ): + assert _get_dtype(vector) in [jnp.float32, jnp.float64] + _w_low_dtype = _get_dtype(w_low) + _w_high_dtype = _get_dtype(w_high) + assert _w_low_dtype == _w_high_dtype, '"w_low" and "w_high" must be same typed.' + assert _w_low_dtype in [jnp.float32, jnp.float64], '"w_low" must be float valued.' + assert _w_high_dtype in [jnp.float32, jnp.float64], '"w_high" must be float valued.' + assert _get_dtype(clen) in [jnp.int32, jnp.int64, jnp.uint32, jnp.uint64] + assert _get_dtype(seed) in [jnp.int32, jnp.int64, jnp.uint32, jnp.uint64] + if vector.ndim != 1: raise ValueError('vector should be a 1D vector.') if len(shape) != 2: @@ -623,6 +637,15 @@ def _matvec_prob_uniform_transpose( def _matvec_prob_normal_abstract( vector, w_mu, w_sigma, clen, seed, *, shape, transpose, outdim_parallel ): + assert _get_dtype(vector) in [jnp.float32, jnp.float64] + _w_mu_dtype = _get_dtype(w_mu) + _w_sigma_dtype = _get_dtype(w_sigma) + assert _w_mu_dtype == _w_sigma_dtype, '"w_mu" and "w_sigma" must be same typed.' + assert _w_mu_dtype in [jnp.float32, jnp.float64], '"w_mu" must be float valued.' + assert _w_sigma_dtype in [jnp.float32, jnp.float64], '"w_sigma" must be float valued.'
+ assert _get_dtype(clen) in [jnp.int32, jnp.int64, jnp.uint32, jnp.uint64] + assert _get_dtype(seed) in [jnp.int32, jnp.int64, jnp.uint32, jnp.uint64] + if w_mu.ndim != 1: raise ValueError('w_mu should be a 1D scalar.') if w_sigma.ndim != 1: From 744187b679b9a63aca3c3a0fbc8ea4d7466c433a Mon Sep 17 00:00:00 2001 From: chaoming Date: Sun, 11 Jun 2023 11:27:02 +0800 Subject: [PATCH 07/14] [math] surrogate gradient function repr --- brainpy/_src/math/surrogate/_one_input.py | 56 ++++++++++++++++++++++- brainpy/_src/math/surrogate/base.py | 4 ++ 2 files changed, 58 insertions(+), 2 deletions(-) diff --git a/brainpy/_src/math/surrogate/_one_input.py b/brainpy/_src/math/surrogate/_one_input.py index a4202b310..5ddb94254 100644 --- a/brainpy/_src/math/surrogate/_one_input.py +++ b/brainpy/_src/math/surrogate/_one_input.py @@ -35,8 +35,6 @@ ] - - class Sigmoid(Surrogate): def __init__(self, alpha=4., origin=False): self.alpha = alpha @@ -45,6 +43,9 @@ def __init__(self, alpha=4., origin=False): def __call__(self, x: Union[jax.Array, Array]): return sigmoid(x, alpha=self.alpha, origin=self.origin) + def __repr__(self): + return f'{self.__class__.__name__}(alpha={self.alpha})' + @vjp_custom(['x'], dict(alpha=4., origin=False), dict(origin=[True, False])) def sigmoid( @@ -124,6 +125,9 @@ def __init__(self, alpha=1., origin=False): def __call__(self, x: Union[jax.Array, Array]): return piecewise_quadratic(x, alpha=self.alpha, origin=self.origin) + def __repr__(self): + return f'{self.__class__.__name__}(alpha={self.alpha})' + @vjp_custom(['x'], dict(alpha=1., origin=False), dict(origin=[True, False])) def piecewise_quadratic( @@ -223,6 +227,9 @@ def __init__(self, alpha=1., origin=False): def __call__(self, x: Union[jax.Array, Array]): return piecewise_exp(x, alpha=self.alpha, origin=self.origin) + def __repr__(self): + return f'{self.__class__.__name__}(alpha={self.alpha})' + @vjp_custom(['x'], dict(alpha=1., origin=False), dict(origin=[True, False])) def piecewise_exp( @@ -308,6 +315,9 @@ def __init__(self, alpha=1., origin=False): def __call__(self, x: Union[jax.Array, Array]): return soft_sign(x, alpha=self.alpha, origin=self.origin) + def __repr__(self): + return f'{self.__class__.__name__}(alpha={self.alpha})' + @vjp_custom(['x'], dict(alpha=1., origin=False), dict(origin=[True, False])) def soft_sign( @@ -388,6 +398,9 @@ def __init__(self, alpha=1., origin=False): def __call__(self, x: Union[jax.Array, Array]): return arctan(x, alpha=self.alpha, origin=self.origin) + def __repr__(self): + return f'{self.__class__.__name__}(alpha={self.alpha})' + @vjp_custom(['x'], dict(alpha=1., origin=False), dict(origin=[True, False])) def arctan( @@ -467,6 +480,9 @@ def __init__(self, alpha=1., origin=False): def __call__(self, x: Union[jax.Array, Array]): return nonzero_sign_log(x, alpha=self.alpha, origin=self.origin) + def __repr__(self): + return f'{self.__class__.__name__}(alpha={self.alpha})' + @vjp_custom(['x'], dict(alpha=1., origin=False), statics={'origin': [True, False]}) def nonzero_sign_log( @@ -559,6 +575,9 @@ def __init__(self, alpha=1., origin=False): def __call__(self, x: Union[jax.Array, Array]): return erf(x, alpha=self.alpha, origin=self.origin) + def __repr__(self): + return f'{self.__class__.__name__}(alpha={self.alpha})' + @vjp_custom(['x'], dict(alpha=1., origin=False), statics={'origin': [True, False]}) def erf( @@ -649,6 +668,9 @@ def __init__(self, c=0.01, w=1., origin=False): def __call__(self, x: Union[jax.Array, Array]): return piecewise_leaky_relu(x, c=self.c, w=self.w, 
origin=self.origin) + def __repr__(self): + return f'{self.__class__.__name__}(c={self.c}, w={self.w})' + @vjp_custom(['x'], dict(c=0.01, w=1., origin=False), statics={'origin': [True, False]}) def piecewise_leaky_relu( @@ -757,6 +779,9 @@ def __init__(self, n=2, t_period=8., origin=False): def __call__(self, x: Union[jax.Array, Array]): return squarewave_fourier_series(x, self.n, self.t_period, self.origin) + def __repr__(self): + return f'{self.__class__.__name__}(n={self.n}, t_period={self.t_period})' + @vjp_custom(['x'], dict(n=2, t_period=8., origin=False), statics={'origin': [True, False]}) def squarewave_fourier_series( @@ -847,6 +872,9 @@ def __init__(self, alpha=4., beta=1., epsilon=1e-8, origin=False): def __call__(self, x: Union[jax.Array, Array], ): return s2nn(x, self.alpha, self.beta, self.epsilon, self.origin) + def __repr__(self): + return f'{self.__class__.__name__}(alpha={self.alpha}, beta={self.beta}, epsilon={self.epsilon})' + @vjp_custom(['x'], defaults=dict(alpha=4., beta=1., epsilon=1e-8, origin=False), @@ -948,6 +976,9 @@ def __init__(self, alpha=2., origin=False): def __call__(self, x: Union[jax.Array, Array]): return q_pseudo_spike(x, self.alpha, self.origin) + def __repr__(self): + return f'{self.__class__.__name__}(alpha={self.alpha})' + @vjp_custom(['x'], dict(alpha=2., origin=False), @@ -1039,6 +1070,9 @@ def __init__(self, alpha=0.1, beta=1., origin=False): def __call__(self, x: Union[jax.Array, Array]): return leaky_relu(x, self.alpha, self.beta, self.origin) + def __repr__(self): + return f'{self.__class__.__name__}(alpha={self.alpha}, beta={self.beta})' + @vjp_custom(['x'], dict(alpha=0.1, beta=1., origin=False), @@ -1129,6 +1163,9 @@ def __init__(self, alpha=0., origin=False): def __call__(self, x: Union[jax.Array, Array]): return log_tailed_relu(x, self.alpha, self.origin) + def __repr__(self): + return f'{self.__class__.__name__}(alpha={self.alpha})' + @vjp_custom(['x'], dict(alpha=0., origin=False), @@ -1230,6 +1267,9 @@ def __init__(self, alpha=0.3, width=1.): def __call__(self, x: Union[jax.Array, Array]): return relu_grad(x, self.alpha, self.width) + def __repr__(self): + return f'{self.__class__.__name__}(alpha={self.alpha}, width={self.width})' + @vjp_custom(['x'], dict(alpha=0.3, width=1.)) def relu_grad( @@ -1304,6 +1344,9 @@ def __init__(self, sigma=0.5, alpha=0.5): def __call__(self, x: Union[jax.Array, Array]): return gaussian_grad(x, self.sigma, self.alpha) + def __repr__(self): + return f'{self.__class__.__name__}(alpha={self.alpha}, sigma={self.sigma})' + @vjp_custom(['x'], dict(sigma=0.5, alpha=0.5)) def gaussian_grad( @@ -1379,6 +1422,9 @@ def __init__(self, h=0.15, s=6.0, sigma=0.5, scale=0.5): def __call__(self, x: Union[jax.Array, Array]): return multi_gaussian_grad(x, self.h, self.s, self.sigma, self.scale) + def __repr__(self): + return f'{self.__class__.__name__}(h={self.h}, s={self.s}, sigma={self.sigma}, scale={self.scale})' + @vjp_custom(['x'], dict(h=0.15, s=6.0, sigma=0.5, scale=0.5)) def multi_gaussian_grad( @@ -1463,6 +1509,9 @@ def __init__(self, alpha=100.): def __call__(self, x: Union[jax.Array, Array]): return inv_square_grad(x, self.alpha) + def __repr__(self): + return f'{self.__class__.__name__}(alpha={self.alpha})' + @vjp_custom(['x'], dict(alpha=100.)) def inv_square_grad( @@ -1528,6 +1577,9 @@ def __init__(self, alpha=1.): def __call__(self, x: Union[jax.Array, Array]): return slayer_grad(x, self.alpha) + def __repr__(self): + return f'{self.__class__.__name__}(alpha={self.alpha})' + @vjp_custom(['x'], 
dict(alpha=1.)) def slayer_grad( diff --git a/brainpy/_src/math/surrogate/base.py b/brainpy/_src/math/surrogate/base.py index 556462955..dceb58b5c 100644 --- a/brainpy/_src/math/surrogate/base.py +++ b/brainpy/_src/math/surrogate/base.py @@ -10,6 +10,10 @@ class Surrogate(object): def __call__(self, *args, **kwargs): raise NotImplementedError + def __repr__(self): + return f'{self.__class__.__name__}()' + + From 7c03ee0529de4076daee3f979eea53936b89d2a3 Mon Sep 17 00:00:00 2001 From: chaoming Date: Sun, 11 Jun 2023 11:28:28 +0800 Subject: [PATCH 08/14] [math] add `brainpy.math.node_list` and `brainpy.math.node_dict` --- brainpy/_src/math/object_transform/base.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/brainpy/_src/math/object_transform/base.py b/brainpy/_src/math/object_transform/base.py index 33e497a47..09282e404 100644 --- a/brainpy/_src/math/object_transform/base.py +++ b/brainpy/_src/math/object_transform/base.py @@ -27,7 +27,7 @@ __all__ = [ 'BrainPyObject', 'Base', 'FunAsObject', 'ObjectTransform', - 'NodeDict', 'NodeList', + 'NodeDict', 'node_dict', 'NodeList', 'node_list', ] @@ -655,6 +655,8 @@ def extend(self, iterable) -> 'NodeList': return self +node_list = NodeList + class NodeDict(dict): """A dictionary of :py:class:`~.BrainPyObject`, which is compatible with @@ -686,3 +688,6 @@ def __setitem__(self, key, value) -> 'VarDict': super().__setitem__(key, self._check_elem(value)) return self + +node_dict = NodeDict + From 20b44ba81d507afa5dbd9b5e136d3b00c7452559 Mon Sep 17 00:00:00 2001 From: chaoming Date: Sun, 11 Jun 2023 11:29:28 +0800 Subject: [PATCH 09/14] [math] add `brainpy.math.var_list` and `brainpy.math.var_dict` --- brainpy/_src/math/object_transform/variables.py | 9 +++++++-- brainpy/math/object_base.py | 8 ++++++-- 2 files changed, 13 insertions(+), 4 deletions(-) diff --git a/brainpy/_src/math/object_transform/variables.py b/brainpy/_src/math/object_transform/variables.py index 24fd9572c..0a0a283df 100644 --- a/brainpy/_src/math/object_transform/variables.py +++ b/brainpy/_src/math/object_transform/variables.py @@ -15,8 +15,8 @@ 'Parameter', 'VariableView', - 'VarList', - 'VarDict', + 'VarList', 'var_list', + 'VarDict', 'var_dict', ] @@ -384,6 +384,9 @@ def tree_unflatten(cls, aux_data, children): return cls(children) +var_list = VarList + + @register_pytree_node_class class VarDict(dict): """A dictionary of :py:class:`~.Variable`, which is compatible with @@ -426,3 +429,5 @@ def tree_flatten(self): def tree_unflatten(cls, keys, values): return cls(jax.util.safe_zip(keys, values)) + +var_dict = VarDict diff --git a/brainpy/math/object_base.py b/brainpy/math/object_base.py index 34561e011..1faca0d21 100644 --- a/brainpy/math/object_base.py +++ b/brainpy/math/object_base.py @@ -4,13 +4,17 @@ FunAsObject as FunAsObject) from brainpy._src.math.object_transform.function import (Partial as Partial) from brainpy._src.math.object_transform.base import (NodeList as NodeList, - NodeDict as NodeDict,) + NodeDict as NodeDict, + node_dict as node_dict, + node_list as node_list, ) from brainpy._src.math.object_transform.variables import (Variable as Variable, Parameter as Parameter, TrainVar as TrainVar, VariableView as VariableView, VarList as VarList, - VarDict as VarDict,) + VarDict as VarDict, + var_list as var_list, + var_dict as var_dict, ) From ac7d09d5ef945d415e463a9a338f1d6069508b6c Mon Sep 17 00:00:00 2001 From: chaoming Date: Sun, 11 Jun 2023 16:33:09 +0800 Subject: [PATCH 10/14] [math] add progress bar options in `brainpy.math.for_loop` 
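A minimal usage sketch of the new option (illustrative only; it mirrors the `test_for_loop_progress_bar` test added in this patch):

    import brainpy.math as bm

    xs = bm.arange(100)
    # `progress_bar=True` wraps the scan with a tqdm bar that ticks once per
    # loop step, so long simulations report their progress while running.
    ys = bm.for_loop(lambda a: a, xs, progress_bar=True)

Note that `tqdm` becomes a runtime requirement for this code path, since `controls.py` now imports `tqdm.auto.tqdm` unconditionally.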
--- .../_src/math/object_transform/controls.py | 29 +++++++++++++++++-- .../object_transform/tests/test_controls.py | 10 +++++++ 2 files changed, 37 insertions(+), 2 deletions(-) diff --git a/brainpy/_src/math/object_transform/controls.py b/brainpy/_src/math/object_transform/controls.py index d00d0488f..bccfbdadf 100644 --- a/brainpy/_src/math/object_transform/controls.py +++ b/brainpy/_src/math/object_transform/controls.py @@ -8,6 +8,8 @@ from jax import lax from jax.errors import UnexpectedTracerError from jax.tree_util import tree_flatten, tree_unflatten +from tqdm.auto import tqdm +from jax.experimental.host_callback import id_tap from brainpy import errors, tools from brainpy._src.math.interoperability import as_jax @@ -625,6 +627,7 @@ def for_loop( unroll: int = 1, remat: bool = False, jit: bool = True, + progress_bar: bool = False, # deprecated dyn_vars: Union[Variable, Sequence[Variable], Dict[str, Variable]] = None, @@ -632,6 +635,8 @@ def for_loop( ): """``for-loop`` control flow with :py:class:`~.Variable`. + .. versionadded:: 2.1.11 + .. versionchanged:: 2.3.0 ``dyn_vars`` has been changed into a default argument. Please change your call from ``for_loop(fun, dyn_vars, operands)`` @@ -672,8 +677,6 @@ def for_loop( [16.] [20.]] - .. versionadded:: 2.1.11 - Parameters ---------- body_fun: callable @@ -699,6 +702,8 @@ def for_loop( Optional positive int specifying, in the underlying operation of the scan primitive, how many scan iterations to unroll within a single iteration of a loop. + progress_bar: bool + Whether we use the progress bar to report the running progress. dyn_vars: Variable, sequence of Variable, dict The instances of :py:class:`~.Variable`. @@ -727,6 +732,10 @@ def for_loop( if not isinstance(operands, (list, tuple)): operands = (operands,) + if progress_bar: + num_total = min([op.shape[0] for op in jax.tree_util.tree_flatten(operands)[0]]) + bar = tqdm(total=num_total) + dyn_vars = get_stack_cache(body_fun) if not jit: if dyn_vars is None: @@ -747,6 +756,8 @@ def fun2scan(carry, x): for k in dyn_vars.keys(): dyn_vars[k]._value = carry[k] results = body_fun(*x) + if progress_bar: + id_tap(lambda *arg: bar.update(), ()) return dyn_vars.dict_data(), results if remat: @@ -761,9 +772,23 @@ def fun2scan(carry, x): unroll=unroll) for key in dyn_vars.keys(): dyn_vars[key]._value = dyn_vals[key] + if progress_bar: + bar.close() return out_vals +def scan( + f: Callable, + init: Any, + xs: Any, + length: Optional[int] = None, + reverse: bool = False, + unroll: int = 1 +): + jax.lax.scan + + + def while_loop( body_fun: Callable, cond_fun: Callable, diff --git a/brainpy/_src/math/object_transform/tests/test_controls.py b/brainpy/_src/math/object_transform/tests/test_controls.py index 4dd12d4d7..0dd93eed8 100644 --- a/brainpy/_src/math/object_transform/tests/test_controls.py +++ b/brainpy/_src/math/object_transform/tests/test_controls.py @@ -109,6 +109,11 @@ def f(a): print(ans) print(c) + def test_for_loop_progress_bar(self): + xs = bm.arange(100) + ys = bm.for_loop(lambda a: a, xs, progress_bar=True) + self.assertTrue(bm.allclose(xs, ys)) + class TestIfElse(unittest.TestCase): def test1(self): @@ -225,3 +230,8 @@ def body(x, y): self.assertTrue(bm.array_equal(res2[0], res[0])) self.assertTrue(bm.array_equal(res2[1], res[1])) + + + + + From 2f71a8a3fb28bb82bf516864039b957d74203e24 Mon Sep 17 00:00:00 2001 From: chaoming Date: Sun, 11 Jun 2023 22:17:00 +0800 Subject: [PATCH 11/14] [parallelization] new module `brainpy.pnn` for model parallelization --- brainpy/__init__.py | 35 
+- brainpy/_src/delay.py | 72 +- brainpy/_src/dynsys.py | 15 +- brainpy/_src/initialize/generic.py | 146 +-- brainpy/_src/integrators/ode/exponential.py | 4 +- brainpy/_src/math/object_transform/jit.py | 49 +- brainpy/_src/math/sharding.py | 112 ++ brainpy/_src/pnn/__init__.py | 1 + .../channels/__init__.py} | 0 brainpy/_src/pnn/common.py | 143 +++ brainpy/_src/pnn/delay.py | 538 ++++++++++ brainpy/_src/pnn/mixin.py | 54 + brainpy/_src/pnn/neurons/__init__.py | 5 + brainpy/_src/pnn/neurons/_docs.py | 32 + brainpy/_src/pnn/neurons/base.py | 125 +++ brainpy/_src/pnn/neurons/hh.py | 0 brainpy/_src/pnn/neurons/lif.py | 315 ++++++ brainpy/_src/pnn/synapses/__init__.py | 0 brainpy/_src/pnn/synapses/projections.py | 179 ++++ brainpy/_src/pnn/synapses/syn_comm.py | 322 ++++++ brainpy/_src/pnn/synapses/syn_dynamics.py | 960 ++++++++++++++++++ brainpy/_src/pnn/synapses/syn_output.py | 159 +++ brainpy/_src/pnn/utils/__init__.py | 6 + brainpy/_src/pnn/utils/axis_names.py | 9 + brainpy/_src/pnn/utils/axis_rules.py | 8 + brainpy/_src/pnn/utils/init.py | 30 + brainpy/_src/psnn/__init__.py | 5 - brainpy/_src/rates/populations.py | 9 +- brainpy/_src/runners.py | 1 + brainpy/_src/synapses_v2/abstract_synapses.py | 31 +- brainpy/_src/synapses_v2/syn_outs.py | 2 +- brainpy/_src/synapses_v2/syn_plasticity.py | 2 +- brainpy/_src/tests/test_access_methods.py | 123 +++ brainpy/_src/tests/test_base_classes.py | 20 + brainpy/_src/tests/test_check.py | 51 + brainpy/_src/tests/test_dyn_runner.py | 133 +++ brainpy/_src/tests/test_network.py | 51 + brainpy/_src/tests/test_pickle.py | 22 + brainpy/_src/tests/test_slice_view.py | 51 + brainpy/_src/tools/codes.py | 6 + brainpy/_src/tools/others.py | 5 +- brainpy/check.py | 5 +- brainpy/math/__init__.py | 2 + brainpy/math/sharding.py | 7 + brainpy/pnn/__init__.py | 7 + brainpy/pnn/mixin.py | 4 + brainpy/pnn/pchannels.py | 0 brainpy/pnn/pneurons.py | 11 + brainpy/pnn/pother_models.py | 8 + brainpy/pnn/psynapses.py | 28 + brainpy/pnn/putils.py | 6 + examples/dynamics_simulation/COBA.py | 39 +- 52 files changed, 3770 insertions(+), 178 deletions(-) create mode 100644 brainpy/_src/math/sharding.py create mode 100644 brainpy/_src/pnn/__init__.py rename brainpy/_src/{psnn/neurons_abstract.py => pnn/channels/__init__.py} (100%) create mode 100644 brainpy/_src/pnn/common.py create mode 100644 brainpy/_src/pnn/delay.py create mode 100644 brainpy/_src/pnn/mixin.py create mode 100644 brainpy/_src/pnn/neurons/__init__.py create mode 100644 brainpy/_src/pnn/neurons/_docs.py create mode 100644 brainpy/_src/pnn/neurons/base.py create mode 100644 brainpy/_src/pnn/neurons/hh.py create mode 100644 brainpy/_src/pnn/neurons/lif.py create mode 100644 brainpy/_src/pnn/synapses/__init__.py create mode 100644 brainpy/_src/pnn/synapses/projections.py create mode 100644 brainpy/_src/pnn/synapses/syn_comm.py create mode 100644 brainpy/_src/pnn/synapses/syn_dynamics.py create mode 100644 brainpy/_src/pnn/synapses/syn_output.py create mode 100644 brainpy/_src/pnn/utils/__init__.py create mode 100644 brainpy/_src/pnn/utils/axis_names.py create mode 100644 brainpy/_src/pnn/utils/axis_rules.py create mode 100644 brainpy/_src/pnn/utils/init.py delete mode 100644 brainpy/_src/psnn/__init__.py create mode 100644 brainpy/_src/tests/test_access_methods.py create mode 100644 brainpy/_src/tests/test_base_classes.py create mode 100644 brainpy/_src/tests/test_check.py create mode 100644 brainpy/_src/tests/test_dyn_runner.py create mode 100644 brainpy/_src/tests/test_network.py create mode 100644 
brainpy/_src/tests/test_pickle.py create mode 100644 brainpy/_src/tests/test_slice_view.py create mode 100644 brainpy/math/sharding.py create mode 100644 brainpy/pnn/__init__.py create mode 100644 brainpy/pnn/mixin.py create mode 100644 brainpy/pnn/pchannels.py create mode 100644 brainpy/pnn/pneurons.py create mode 100644 brainpy/pnn/pother_models.py create mode 100644 brainpy/pnn/psynapses.py create mode 100644 brainpy/pnn/putils.py diff --git a/brainpy/__init__.py b/brainpy/__init__.py index 72595c984..1aca9aa94 100644 --- a/brainpy/__init__.py +++ b/brainpy/__init__.py @@ -51,13 +51,15 @@ # Part 3: Models # # ---------------- # -from brainpy import (channels, # channel models - layers, # ANN layers - neurons, # neuron groups - synapses, # synapses - rates, # rate models - experimental, - ) +from brainpy import ( + channels, # channel models + layers, # ANN layers + neurons, # neuron groups + synapses, # synapses + rates, # rate models + experimental, + pnn, # parallel SNN models +) from brainpy.synapses import (synouts, # synaptic output synplast, ) # synaptic plasticity @@ -81,13 +83,14 @@ from brainpy._src.runners import (DSRunner as DSRunner) from brainpy._src.transform import (LoopOverTime as LoopOverTime, ) # DynamicalSystem base classes -from brainpy._src.dynsys import (DynamicalSystemNS as DynamicalSystemNS, - NeuGroupNS as NeuGroupNS, - TwoEndConnNS as TwoEndConnNS, - ) -from brainpy._src.dyn.synapses_v2.base import (SynOutNS as SynOutNS, - SynSTPNS as SynSTPNS, - SynConnNS as SynConnNS, ) +from brainpy._src.dynsys import ( + DynamicalSystemNS as DynamicalSystemNS, + NeuGroupNS as NeuGroupNS, + TwoEndConnNS as TwoEndConnNS, +) +from brainpy._src.synapses_v2.base import (SynOutNS as SynOutNS, + SynSTPNS as SynSTPNS, + SynConnNS as SynConnNS, ) # Part 4: Training # @@ -114,6 +117,8 @@ # ---------------------- # +math.__dict__['sparse_matmul'] = math.sparse.seg_matmul + math.__dict__['event_matvec_prob_conn_homo_weight'] = math.jitconn.event_mv_prob_homo math.__dict__['event_matvec_prob_conn_uniform_weight'] = math.jitconn.event_mv_prob_uniform math.__dict__['event_matvec_prob_conn_normal_weight'] = math.jitconn.event_mv_prob_normal @@ -242,7 +247,7 @@ dyn.__dict__['OUProcess'] = neurons.OUProcess # synapses -from brainpy._src.dyn.synapses import compat +from brainpy._src.synapses import compat dyn.__dict__['DeltaSynapse'] = compat.DeltaSynapse dyn.__dict__['ExpCUBA'] = compat.ExpCUBA dyn.__dict__['ExpCOBA'] = compat.ExpCOBA diff --git a/brainpy/_src/delay.py b/brainpy/_src/delay.py index 3035b0131..29300027d 100644 --- a/brainpy/_src/delay.py +++ b/brainpy/_src/delay.py @@ -24,37 +24,37 @@ class Delay(DynamicalSystemNS): """Delay variable which has a fixed delay length. - The data in this delay variable is arranged as:: - - delay = 0 [ data - delay = 1 data - delay = 2 data - ... .... - ... .... - delay = length-1 data - delay = length data ] - - Parameters - ---------- - latest: Variable - The initial delay data. - length: int - The delay data length. - before_t0: Any - The delay data. It can be a Python number, like float, int, boolean values. - It can also be arrays. Or a callable function or instance of ``Connector``. - Note that ``initial_delay_data`` should be arranged as the following way:: - - delay = 1 [ data - delay = 2 data - ... .... - ... .... - delay = length-1 data - delay = length data ] - method: str - The method used for updating delay. - - """ + The data in this delay variable is arranged as:: + + delay = 0 [ data + delay = 1 data + delay = 2 data + ... .... 
+ ... .... + delay = length-1 data + delay = length data ] + + Parameters + ---------- + latest: Variable + The initial delay data. + length: int + The delay data length. + before_t0: Any + The delay data. It can be a Python number, like float, int, boolean values. + It can also be arrays. Or a callable function or instance of ``Connector``. + Note that ``initial_delay_data`` should be arranged as the following way:: + + delay = 1 [ data + delay = 2 data + ... .... + ... .... + delay = length-1 data + delay = length data ] + method: str + The method used for updating delay. + + """ latest: bm.Variable data: Optional[bm.Variable] @@ -64,9 +64,9 @@ def __init__( self, latest: bm.Variable, length: int = 0, - before_t0: Union[float, int, bool, bm.Array, jax.Array, Callable] = None, + before_t0: Optional[Union[float, int, bool, bm.Array, jax.Array, Callable]] = None, entries: Optional[Dict] = None, - name: str = None, + name: Optional[str] = None, method: str = ROTATE_UPDATE, mode: Optional[bm.Mode] = None, ): @@ -249,7 +249,10 @@ def retrieve(self, delay_step, *indices): # the delay data return self.data[indices] - def update(self, latest_value: Optional[Union[bm.Array, jax.Array]] = None) -> None: + def update( + self, + latest_value: Optional[Union[bm.Array, jax.Array]] = None + ) -> None: """Update delay variable with the new data. """ if self.data is not None: @@ -297,4 +300,3 @@ def _init_data(self, length, batch_size: int = None): self.data[1:] = self._before_t0 elif callable(self._before_t0): self.data[1:] = self._before_t0((length,) + self.latest.shape, dtype=self.latest.dtype) - diff --git a/brainpy/_src/dynsys.py b/brainpy/_src/dynsys.py index 49a2e5b7d..7164f0457 100644 --- a/brainpy/_src/dynsys.py +++ b/brainpy/_src/dynsys.py @@ -32,7 +32,7 @@ 'Channel', # neuron models - 'NeuGroup', 'CondNeuGroup', + 'NeuGroup', 'CondNeuGroup', 'NeuGroupNS', # synapse models 'SynConn', @@ -113,6 +113,9 @@ class DynamicalSystem(BrainPyObject): The model computation mode. It should be instance of :py:class:`~.Mode`. """ + supported_modes: Optional[Sequence[bm.Mode]] = None + '''Supported computing modes.''' + _pass_shared_args: bool = True global_delay_data: Dict[str, Tuple[Union[bm.LengthDelay, None], Variable]] = dict() @@ -132,6 +135,12 @@ def __init__( f'but we got {type(mode)}: {mode}') self._mode = mode + if self.supported_modes is not None: + if not self.mode.is_parent_of(*self.supported_modes): + raise UnsupportedError(f'The mode only supports computing modes ' + f'which are parents of {self.supported_modes}, ' + f'but we got {self.mode}.') + # local delay variables self.local_delay_vars: Dict[str, bm.LengthDelay] = Collector() @@ -647,6 +656,7 @@ def __init__( f'But we got {type(size)}') self.size = size self.keep_size = keep_size + # number of neurons self.num = tools.size2num(size) @@ -667,7 +677,7 @@ def get_batch_shape(self, batch_size=None): else: return (batch_size,) + self.varshape - def update(self, *args): + def update(self, *args, **kwargs): """The function to specify the updating rule. """ raise NotImplementedError(f'Subclass of {self.__class__.__name__} must ' @@ -682,7 +692,6 @@ def __getitem__(self, item): return NeuGroupView(target=self, index=item) - class SynConn(DynamicalSystem): """Base class to model two-end synaptic connections. 
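The `supported_modes` check added to `DynamicalSystem.__init__` above can be exercised as in the following sketch (illustrative; the `Leaky` model introduced later in this patch declares its modes in exactly this way):

    import brainpy as bp
    import brainpy.math as bm

    class MyGroup(bp.DynamicalSystemNS):
      # Declare the computing modes this model can run under; the base
      # DynamicalSystem.__init__ raises UnsupportedError when the current
      # mode is incompatible with this declaration.
      supported_modes = (bm.TrainingMode, bm.NonBatchingMode)

      def __init__(self, size, mode=None):
        super().__init__(mode=mode)
        self.size = size

      def update(self, x=None):
        return x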
diff --git a/brainpy/_src/initialize/generic.py b/brainpy/_src/initialize/generic.py index e1df3d515..6c5b6bc5e 100644 --- a/brainpy/_src/initialize/generic.py +++ b/brainpy/_src/initialize/generic.py @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- -from typing import Union, Callable, Optional +from typing import Union, Callable, Optional, Sequence import jax.numpy as jnp import numpy as np @@ -10,23 +10,30 @@ from brainpy.types import Shape, ArrayType from .base import Initializer + __all__ = [ 'parameter', 'variable', 'variable_', 'noise', 'delay', - - # deprecated - 'init_param', ] +def _check_none(x, allow_none: bool = False): + pass + + +def _is_scalar(x): + return isinstance(x, (float, int, bool, complex)) + + def parameter( param: Union[Callable, Initializer, bm.ndarray, np.ndarray, jnp.ndarray, float, int, bool], - size: Shape, + sizes: Shape, allow_none: bool = True, allow_scalar: bool = True, + axis_names: Optional[Sequence[str]] = None ): """Initialize parameters. @@ -38,12 +45,14 @@ def parameter( - If it is a callable function :math:`f`, the ``f(size)`` will be returned. - If it is an instance of :py:class:`brainpy.init.Initializer``, the ``f(size)`` will be returned. - If it is a tensor, then this function check whether ``tensor.shape`` is equal to the given ``size``. - size: int, sequence of int + sizes: int, sequence of int The shape of the parameter. allow_none: bool Whether allow the parameter is None. allow_scalar: bool Whether allow the parameter is a scalar value. + axis_names: sequence of str + The axes for automatic array sharding. Returns ------- @@ -60,11 +69,14 @@ def parameter( else: raise ValueError(f'Expect a parameter with type of float, ArrayType, Initializer, or ' f'Callable function, but we got None. ') - size = to_size(size) - if allow_scalar and isinstance(param, (float, int, bool)): + sizes = to_size(sizes) + if allow_scalar and _is_scalar(param): return param + if callable(param): - param = param(size) + param = param(sizes) # TODO + # return bm.jit(param, static_argnums=0, out_shardings=bm.sharding.get_sharding(axis_names))(size) + elif isinstance(param, (np.ndarray, jnp.ndarray)): param = bm.asarray(param) elif isinstance(param, bm.Variable): @@ -73,32 +85,22 @@ def parameter( param = param else: raise ValueError(f'Unknown param type {type(param)}: {param}') + if allow_scalar: if param.shape == () or param.shape == (1,): return param - if param.shape != size: - raise ValueError(f'The shape of the parameters should be {size}, but we got {param.shape}') - return param - - -def init_param( - param: Union[Callable, Initializer, bm.ndarray, jnp.ndarray, float, int, bool], - size: Shape, - allow_none: bool = True, -): - """Initialize parameters. Same as ``parameter()``. - - .. deprecated:: 2.2.3.4 - Will be removed since version 2.4.0. - """ - return parameter(param, size, allow_none) + if param.shape != sizes: + raise ValueError(f'The shape of the parameters should be {sizes}, but we got {param.shape}') + return bm.sharding.partition_by_axname(param, axis_names) def variable_( init: Union[Callable, ArrayType], - size: Shape = None, - batch_size_or_mode: Optional[Union[int, bool, bm.Mode]] = None, + sizes: Shape = None, + batch_or_mode: Optional[Union[int, bool, bm.Mode]] = None, batch_axis: int = 0, + axis_names: Optional[Sequence[str]] = None, + batch_axis_name: Optional[str] = None, ): """Initialize a :math:`~.Variable` from a callable function or a data. 
@@ -106,15 +108,19 @@ def variable_( ---------- init: callable, function, ArrayType The data to be initialized as a ``Variable``. - batch_size_or_mode: int, bool, Mode, optional + batch_or_mode: int, bool, Mode, optional The batch size, model ``Mode``, boolean state. This is used to specify the batch size of this variable. If it is a boolean or an instance of ``Mode``, the batch size will be 1. If it is None, the variable has no batch axis. - size: Shape + sizes: Shape The shape of the variable. batch_axis: int The batch axis. + axis_names: sequence of str + The name for each axis. These names should match the given ``axes``. + batch_axis_name: str + The name for the batch axis. The name will be used if ``batch_size_or_mode`` is given. Returns ------- @@ -126,14 +132,21 @@ def variable_( variable, parameter, noise, delay """ - return variable(init, batch_size_or_mode, size, batch_axis) + return variable(init, + batch_or_mode, + sizes=sizes, + batch_axis=batch_axis, + axis_names=axis_names, + batch_axis_name=batch_axis_name) def variable( init: Union[Callable, ArrayType], - batch_size_or_mode: Optional[Union[int, bool, bm.Mode]] = None, - size: Shape = None, + batch_or_mode: Optional[Union[int, bool, bm.Mode]] = None, + sizes: Shape = None, batch_axis: int = 0, + axis_names: Optional[Sequence[str]] = None, + batch_axis_name: Optional[str] = None, ): """Initialize variables. @@ -141,15 +154,19 @@ def variable( ---------- init: callable, function, ArrayType The data to be initialized as a ``Variable``. - batch_size_or_mode: int, bool, Mode, optional + batch_or_mode: int, bool, Mode, optional The batch size, model ``Mode``, boolean state. This is used to specify the batch size of this variable. If it is a boolean or an instance of ``Mode``, the batch size will be 1. If it is None, the variable has no batch axis. - size: Shape + sizes: Shape The shape of the variable. batch_axis: int The batch axis. + axis_names: sequence of str + The name for each axis. These names should match the given ``axes``. + batch_axis_name: str + The name for the batch axis. The name will be used if ``batch_size_or_mode`` is given. 
Returns ------- @@ -161,43 +178,52 @@ def variable( variable_, parameter, noise, delay """ - size = to_size(size) + + sizes = to_size(sizes) + if axis_names is not None: + axis_names = list(axis_names) + assert len(sizes) == len(axis_names) + if batch_or_mode is not None and not isinstance(batch_or_mode, bm.NonBatchingMode): + assert batch_axis_name is not None + axis_names.insert(batch_axis, batch_axis_name) + if callable(init): - if size is None: + if sizes is None: raise ValueError('"varshape" cannot be None when data is a callable function.') - if isinstance(batch_size_or_mode, bm.NonBatchingMode): - return bm.Variable(init(size)) - elif isinstance(batch_size_or_mode, bm.BatchingMode): - new_shape = size[:batch_axis] + (batch_size_or_mode.batch_size,) + size[batch_axis:] - return bm.Variable(init(new_shape), batch_axis=batch_axis) - elif batch_size_or_mode in (None, False): - return bm.Variable(init(size)) - elif isinstance(batch_size_or_mode, int): - new_shape = size[:batch_axis] + (int(batch_size_or_mode),) + size[batch_axis:] - return bm.Variable(init(new_shape), batch_axis=batch_axis) + if isinstance(batch_or_mode, bm.NonBatchingMode): + data = bm.Variable(init(sizes)) + elif isinstance(batch_or_mode, bm.BatchingMode): + new_shape = sizes[:batch_axis] + (batch_or_mode.batch_size,) + sizes[batch_axis:] + data = bm.Variable(init(new_shape), batch_axis=batch_axis) + elif batch_or_mode in (None, False): + data = bm.Variable(init(sizes)) + elif isinstance(batch_or_mode, int): + new_shape = sizes[:batch_axis] + (int(batch_or_mode),) + sizes[batch_axis:] + data = bm.Variable(init(new_shape), batch_axis=batch_axis) else: - raise ValueError(f'Unknown batch_size_or_mode: {batch_size_or_mode}') + raise ValueError(f'Unknown batch_size_or_mode: {batch_or_mode}') else: - if size is not None: - if bm.shape(init) != size: - raise ValueError(f'The shape of "data" {bm.shape(init)} does not match with "var_shape" {size}') - if isinstance(batch_size_or_mode, bm.NonBatchingMode): - return bm.Variable(init) - elif isinstance(batch_size_or_mode, bm.BatchingMode): - return bm.Variable(bm.repeat(bm.expand_dims(init, axis=batch_axis), - batch_size_or_mode.batch_size, + if sizes is not None: + if bm.shape(init) != sizes: + raise ValueError(f'The shape of "data" {bm.shape(init)} does not match with "var_shape" {sizes}') + if isinstance(batch_or_mode, bm.NonBatchingMode): + data = bm.Variable(init) + elif isinstance(batch_or_mode, bm.BatchingMode): + data = bm.Variable(bm.repeat(bm.expand_dims(init, axis=batch_axis), + batch_or_mode.batch_size, axis=batch_axis), batch_axis=batch_axis) - elif batch_size_or_mode in (None, False): - return bm.Variable(init) - elif isinstance(batch_size_or_mode, int): - return bm.Variable(bm.repeat(bm.expand_dims(init, axis=batch_axis), - int(batch_size_or_mode), + elif batch_or_mode in (None, False): + data = bm.Variable(init) + elif isinstance(batch_or_mode, int): + data = bm.Variable(bm.repeat(bm.expand_dims(init, axis=batch_axis), + int(batch_or_mode), axis=batch_axis), batch_axis=batch_axis) else: raise ValueError('Unknown batch_size_or_mode.') + return bm.sharding.partition_by_axname(data, axis_names) def noise( diff --git a/brainpy/_src/integrators/ode/exponential.py b/brainpy/_src/integrators/ode/exponential.py index 74dd01dcc..e4b57ff46 100644 --- a/brainpy/_src/integrators/ode/exponential.py +++ b/brainpy/_src/integrators/ode/exponential.py @@ -138,7 +138,7 @@ class ExponentialEuler(ODEIntegrator): >>> import brainpy as bp >>> import brainpy.math as bm >>> - >>> class 
HH(bp.dyn.NeuGroup): + >>> class HH(bp.NeuGroup): >>> def __init__(self, size, ENa=55., EK=-90., EL=-65, C=1.0, gNa=35., gK=9., >>> gL=0.1, V_th=20., phi=5.0, name=None): >>> super(HH, self).__init__(size=size, name=name) @@ -211,7 +211,7 @@ class ExponentialEuler(ODEIntegrator): >>> import brainpy as bp >>> import brainpy.math as bm >>> - >>> class HH(bp.dyn.NeuGroup): + >>> class HH(bp.NeuGroup): >>> def __init__(self, size, ENa=55., EK=-90., EL=-65, C=1.0, gNa=35., gK=9., >>> gL=0.1, V_th=20., phi=5.0, name=None): >>> super(HH, self).__init__(size=size, name=name) diff --git a/brainpy/_src/math/object_transform/jit.py b/brainpy/_src/math/object_transform/jit.py index 167cdf738..8abbdf78e 100644 --- a/brainpy/_src/math/object_transform/jit.py +++ b/brainpy/_src/math/object_transform/jit.py @@ -17,12 +17,24 @@ from .base import BrainPyObject, ObjectTransform from .naming import get_stack_cache, cache_stack from .variables import Variable, VariableStack +from jax.sharding import Sharding +from jax._src.sharding_impls import UnspecifiedValue, UNSPECIFIED, AUTO __all__ = [ 'jit', ] +def _get_sharding(a): + pass + + +def _get_sharding_of_dyn_vars(dyn_vars: dict): + leaves, tree = jax.tree_util.tree_flatten(dyn_vars) + + + + def _seq_of_int(static_argnums): if static_argnums is None: static_argnums = () @@ -62,13 +74,12 @@ def __init__( abstracted_axes: Optional[Any] = None, name: Optional[str] = None, backend: Optional[str] = None, + in_shardings: Union[Sharding, UnspecifiedValue] = UNSPECIFIED, + out_shardings: Union[Sharding, UnspecifiedValue] = UNSPECIFIED, # deprecated dyn_vars: Dict[str, Variable] = None, child_objs: Dict[str, BrainPyObject] = None, - - # others - **kwargs ): super().__init__(name=name) @@ -92,7 +103,16 @@ def __init__( self._inline = inline self._keep_unused = keep_unused self._abstracted_axes = abstracted_axes - self._kwargs = kwargs + self._in_shardings = in_shardings + self._out_shardings = out_shardings + # if isinstance(in_shardings, UnspecifiedValue): + # pass + # else: + # self._in_shardings = (UNSPECIFIED, in_shardings) + # if isinstance(out_shardings, UnspecifiedValue): + # pass + # else: + # self._out_shardings = (AUTO, out_shardings) # transformation function self._transform = None @@ -103,18 +123,20 @@ def _transform_function(self, variable_data: Dict, *args, **kwargs): v._value = variable_data[key] out = self.fun(*args, **kwargs) changes = self._dyn_vars.dict_data() - return out, changes + return changes, out def __call__(self, *args, **kwargs): if jax.config.jax_disable_jit: return self.fun(*args, **kwargs) if self._transform is None: - self._dyn_vars = evaluate_dyn_vars(self.fun, - *args, - static_argnums=self._static_argnums, - static_argnames=self._static_argnames, - **kwargs) + self._dyn_vars = evaluate_dyn_vars( + self.fun, + *args, + static_argnums=self._static_argnums, + static_argnames=self._static_argnames, + **kwargs + ) self._transform = jax.jit( self._transform_function, static_argnums=jax.tree_util.tree_map(lambda a: a + 1, self._static_argnums), @@ -125,9 +147,10 @@ def __call__(self, *args, **kwargs): keep_unused=self._keep_unused, abstracted_axes=self._abstracted_axes, backend=self._backend, - **self._kwargs + in_shardings=self._in_shardings, + out_shardings=self._out_shardings, ) - out, changes = self._transform(self._dyn_vars.dict_data(), *args, **kwargs) + changes, out = self._transform(self._dyn_vars.dict_data(), *args, **kwargs) for key, v in self._dyn_vars.items(): v._value = changes[key] return out @@ -200,7 +223,6 @@ def jit( 
backend: Optional[str] = None, abstracted_axes: Optional[Any] = None, - # deprecated dyn_vars: Optional[Union[Variable, Sequence[Variable], Dict[str, Variable]]] = None, child_objs: Optional[Union[BrainPyObject, Sequence[BrainPyObject], Dict[str, BrainPyObject]]] = None, @@ -422,7 +444,6 @@ def call_fun(self, *args, **kwargs): def _make_transform(fun, stack): - @wraps(fun) def _transform_function(variable_data: dict, *args, **kwargs): for key, v in stack.items(): diff --git a/brainpy/_src/math/sharding.py b/brainpy/_src/math/sharding.py new file mode 100644 index 000000000..74c39bba0 --- /dev/null +++ b/brainpy/_src/math/sharding.py @@ -0,0 +1,112 @@ +from functools import partial +from typing import Dict, Optional, Any, Union, Sequence + +import jax +import numpy as np +from jax._src.sharding_impls import UnspecifiedValue, UNSPECIFIED +from jax.sharding import PartitionSpec, Mesh, NamedSharding, Sharding + +from .ndarray import Array + +__all__ = [ + 'set', + 'get_sharding', + 'partition_by_axname', + 'partition_by_sharding', +] + +_mesh: Optional[Mesh] = None + + +def set( + mesh: Optional[Mesh] = None, + mesh_shape: Optional[Sequence[int]] = None, + mesh_axes: Optional[Sequence[str]] = None, +): + global _mesh + + if mesh_axes is not None: + assert mesh_shape is not None, 'Provide both "mesh_axes" and "mesh_shape".' + assert mesh is None, 'Provide either "mesh" or "mesh_axes" + "mesh_shape".' + assert len(mesh_axes) == len(mesh_shape) + mesh = Mesh(np.asarray(jax.devices()).reshape(*mesh_shape), axis_names=mesh_axes) + _mesh = mesh + else: + if mesh is not None: + _mesh = mesh + assert mesh_shape is None and mesh_axes is None, 'Provide either "mesh" or "mesh_axes" + "mesh_shape".' + else: + _mesh = None + + +def _device_put(x: Union[Array, jax.Array, np.ndarray], + named_shard: NamedSharding): + if isinstance(x, Array): + x.value = jax.device_put(x, device=named_shard) + return x + + +def get_sharding( + axis_names: Optional[Sequence[str]] = None, + mesh: Optional[Mesh] = None +) -> Union[UnspecifiedValue, NamedSharding]: + """Get sharding according to the given axes information. + + Args: + axis_names: list of str, or tuple of str. The name for each axis in the array. + mesh: Mesh. The given device mesh. + + Returns: + The instance of NamedSharding. + """ + if axis_names is None: + return UNSPECIFIED + if mesh is None: + mesh = _mesh + if mesh is None: + return UNSPECIFIED + else: + axis_names = [(name if name in mesh.axis_names else None) for name in axis_names] + return NamedSharding(mesh, PartitionSpec(*axis_names)) + + +def partition_by_axname( + x: Any, + axis_names: Optional[Sequence[str]] = None, + mesh: Optional[Mesh] = None +): + """Put the given arrays into the mesh devices. + + Args: + x: any. Any array. + axis_names: list of str, or tuple of str. The name for each axis in the array. + mesh: Mesh. The given device mesh. + + Returns: + The re-sharded arrays.
+ """ + if axis_names is None: + return x + if mesh is None: + if _mesh is None: + return x + mesh = _mesh + shard = get_sharding(axis_names, mesh) + if shard is None: + return x + else: + f = partial(_device_put, named_shard=shard) + return jax.tree_util.tree_map(f, x, is_leaf=lambda a: isinstance(a, Array)) + + +def partition_by_sharding( + x: Any, + sharding: Optional[Sharding] = None, +): + if sharding is None: + return x + else: + assert isinstance(sharding, Sharding) + f = partial(_device_put, named_shard=sharding) + return jax.tree_util.tree_map(f, x, is_leaf=lambda a: isinstance(a, Array)) + diff --git a/brainpy/_src/pnn/__init__.py b/brainpy/_src/pnn/__init__.py new file mode 100644 index 000000000..8b1378917 --- /dev/null +++ b/brainpy/_src/pnn/__init__.py @@ -0,0 +1 @@ + diff --git a/brainpy/_src/psnn/neurons_abstract.py b/brainpy/_src/pnn/channels/__init__.py similarity index 100% rename from brainpy/_src/psnn/neurons_abstract.py rename to brainpy/_src/pnn/channels/__init__.py diff --git a/brainpy/_src/pnn/common.py b/brainpy/_src/pnn/common.py new file mode 100644 index 000000000..41feedcf5 --- /dev/null +++ b/brainpy/_src/pnn/common.py @@ -0,0 +1,143 @@ +from typing import Union, Callable, Optional, Sequence + +import brainpy.math as bm +from brainpy._src import initialize as init +from brainpy._src import tools +from brainpy._src.context import share +from brainpy._src.dynsys import DynamicalSystemNS +from brainpy._src.integrators import odeint +from brainpy._src.pnn.utils.axis_names import NEU_AXIS +from brainpy.check import is_initializer +from brainpy.types import Shape, ArrayType + +__all__ = [ + 'Leaky', + 'Integrator', +] + + +class Leaky(DynamicalSystemNS): + r"""Leaky Integrator Model. + + **Model Descriptions** + + This class implements a leaky model, in which its dynamics is + given by: + + .. math:: + + x(t + \Delta t) = \exp{-\Delta t/\tau} x(t) + I + + Args: + size: sequence of int, int. The size of the neuron group. + tau: float, ArrayType, Initializer, callable. Membrane time constant. + method: str. The numerical integration method. Default "exp_auto". + mode: Mode. The computing mode. Default None. + name: str. The group name. + """ + + supported_modes = (bm.TrainingMode, bm.NonBatchingMode) + + def __init__( + self, + size: Shape, + axis_names: Optional[Sequence[str]] = (NEU_AXIS,), + tau: Union[float, ArrayType, Callable] = 10., + method: str = 'exp_auto', + mode: bm.Mode = None, + name: str = None, + init_var: bool = True + ): + super().__init__(mode=mode, name=name) + + # parameters + self.size = tools.to_size(size) + self.axis_names = axis_names + self.tau = init.parameter(tau, self.size, axis_names=axis_names) + + # integral + self.integral = odeint(method=method, f=self.derivative) + + # variables + if init_var: + self.reset_state(self.mode) + + def derivative(self, x, t): + return -x / self.tau + + def reset_state(self, batch_size=None): + self.x = init.variable_(bm.zeros, self.size, batch_size, axis_names=self.axis_names) + + def update(self, inp=None): + t = share.load('t') + dt = share.load('dt') + self.x.value = self.integral(self.x.value, t, dt) + if inp is not None: + self.x += inp + return self.x.value + + +class Integrator(DynamicalSystemNS): + r"""Integrator Model. + + This class implements an integrator model, in which its dynamics is + given by: + + .. math:: + + \tau \frac{dx}{dt} = - x(t) + I(t) + + where :math:`x` is the integrator value, and :math:`\tau` is the time constant. + + Args: + size: sequence of int, int. 
The size of the neuron group. + tau: float, ArrayType, Initializer, callable. Membrane time constant. + method: str. The numerical integration method. Default "exp_auto". + name: str. The group name. + mode: Mode. The computing mode. Default None. + x_initializer: ArrayType, Initializer, callable. The initializer of :math:`x`. + """ + + supported_modes = (bm.TrainingMode, bm.NonBatchingMode) + + def __init__( + self, + size: Shape, + axis_names: Optional[Sequence[str]] = (NEU_AXIS,), + tau: Union[float, ArrayType, Callable] = 10., + x_initializer: Union[Callable, ArrayType] = init.ZeroInit(), + name: str = None, + mode: bm.Mode = None, + method: str = 'exp_auto', + init_var: bool = True, + ): + super().__init__(mode=mode, name=name) + + # parameters + self.size = tools.to_size(size) + self.axis_names = axis_names + self.tau = init.parameter(tau, self.size, axis_names=self.axis_names) + + # initializers + self._x_initializer = is_initializer(x_initializer) + + # integral + self.integral = odeint(method=method, f=self.derivative) + + # variables + if init_var: + self.reset_state(self.mode) + + def derivative(self, V, t, I_ext): + return (-V + I_ext) / self.tau + + def reset_state(self, batch_size=None): + self.x = init.variable_(self._x_initializer, self.size, batch_size, axis_names=self.axis_names) + + def update(self, x=None): + t = share.load('t') + dt = share.load('dt') + x = 0. if x is None else x + self.x.value = self.integral(self.x.value, t, I_ext=x, dt=dt) + return self.x.value + diff --git a/brainpy/_src/pnn/delay.py b/brainpy/_src/pnn/delay.py new file mode 100644 index 000000000..616035186 --- /dev/null +++ b/brainpy/_src/pnn/delay.py @@ -0,0 +1,538 @@ +""" +Delay variable. +""" + +import numbers +from dataclasses import dataclass +from typing import Union, Callable, Optional, Dict, Sequence + +import jax +import jax.numpy as jnp +import numpy as np +from jax.lax import stop_gradient + +from brainpy import check +from brainpy import math as bm, tools +from brainpy._src.context import share +from brainpy._src.dynsys import DynamicalSystemNS +from brainpy._src.math.delayvars import ROTATE_UPDATE, CONCAT_UPDATE +from brainpy._src.pnn.utils.axis_names import NEU_AXIS, TIME_AXIS +from brainpy.check import jit_error +from brainpy.types import ArrayType +from .mixin import ParamDesc + +__all__ = [ + 'TargetDelay', + 'DataDelay', +] + + +@dataclass(frozen=True) +class DelayDesc: + time: Optional[Union[int, float]] = None + dtype: Optional[type] = None + init: Optional[Union[numbers.Number, ArrayType, Callable]] = None + target: Optional[bm.Variable] = None + + +class Delay(DynamicalSystemNS, ParamDesc): + data: Optional[bm.Variable] + + def __init__( + self, + # delay time + time: Optional[Union[int, float]] = None, + + # delay init + init: Optional[Union[numbers.Number, bm.Array, jax.Array, Callable]] = None, + + # delay method + method: Optional[str] = None, + + # others + name: Optional[str] = None, + mode: Optional[bm.Mode] = None, + ): + super().__init__(name=name, mode=mode) + + # delay method + if method is None: + if self.mode.is_parent_of(bm.NonBatchingMode): + method = ROTATE_UPDATE + elif self.mode.is_parent_of(bm.TrainingMode): + method = CONCAT_UPDATE + else: + method = ROTATE_UPDATE + assert method in [ROTATE_UPDATE, CONCAT_UPDATE] + self.method = method + + # delay length + if time is None: + length = 0 + time = 0. 
+ elif isinstance(time, (int, float)): + length = int(time / bm.get_dt()) + else: + raise TypeError('time must be a int or float or None.') + assert isinstance(length, int) + self.max_length = length + self.max_time = time + + # delay data + if init is not None: + assert isinstance(init, (numbers.Number, bm.Array, jax.Array, Callable)) + self._init = init + + # other info + self._registered_entries = dict() + + def register_entry( + self, + entry: str, + delay_time: Optional[Union[float, bm.Array, Callable]], + ) -> 'Delay': + """Register an entry to access the data. + + Args: + entry: str. The entry to access the delay data. + delay_time: The delay time of the entry (can be a float). + + Returns: + Return the self. + """ + raise NotImplementedError + + def at(self, entry: str, *indices) -> bm.Array: + """Get the data at the given entry. + + Args: + entry: str. The entry to access the data. + *indices: The slicing indices. + + Returns: + The data. + """ + raise NotImplementedError + + def retrieve(self, delay_step, *indices): + """Retrieve the delay data according to the delay length. + + Parameters + ---------- + delay_step: int, ArrayType + The delay length used to retrieve the data. + """ + raise NotImplementedError() + + +class TargetDelay(Delay): + """Delay variable which has a fixed delay length. + + The data in this delay variable is arranged as:: + + delay = 0 [ data + delay = 1 data + delay = 2 data + ... .... + ... .... + delay = length-1 data + delay = length data ] + + Args: + target: Variable. The delay target. + axis_names: sequence of str. The name for each axis. + time: int, float. The delay time. + init: Any. The delay data. It can be a Python number, like float, int, boolean values. + It can also be arrays. Or a callable function or instance of ``Connector``. + Note that ``initial_delay_data`` should be arranged as the following way:: + + delay = 1 [ data + delay = 2 data + ... .... + ... .... + delay = length-1 data + delay = length data ] + entries: optional, dict. The delay access entries. + name: str. The delay name. + method: str. The method used for updating delay. Default None. + mode: Mode. The computing mode. Default None. + + """ + + not_desc_params = ('time', 'entries') + + def __init__( + self, + + # delay target + target: bm.Variable, + axis_names: Optional[Sequence[str]] = None, + + # delay time + time: Optional[Union[int, float]] = None, + + # delay init + init: Optional[Union[numbers.Number, bm.Array, jax.Array, Callable]] = None, + + # delay access entry + entries: Optional[Dict] = None, + + # delay method + method: Optional[str] = None, + + # others + name: Optional[str] = None, + mode: Optional[bm.Mode] = None, + ): + super().__init__(time=time, init=init, method=method, name=name, mode=mode) + + # target + if not isinstance(target, bm.Variable): + raise ValueError(f'Must be an instance of brainpy.math.Variable. 
But we got {type(target)}') + + if axis_names is not None: + if len(axis_names) == target.ndim: + axis_names = list(axis_names) + elif len(axis_names) + 1 == target.ndim and isinstance(self.mode, bm.BatchingMode): + axis_names = list(axis_names) + axis_names.insert(0, NEU_AXIS) + else: + raise ValueError + self.target_axis_names = axis_names + if axis_names is not None: + axis_names = list(axis_names) + axis_names.insert(0, TIME_AXIS) + self.data_axis_names = axis_names + + if self.mode.is_child_of(bm.BatchingMode): + assert target.batch_axis is not None + target = bm.sharding.partition_by_axname(target, self.target_axis_names) + self.target = target + + # delay data + self._init = init + if self.max_length > 0: + self._init_data(self.max_length) + else: + self.data = None + + # other info + if entries is not None: + for entry, value in entries.items(): + self.register_entry(entry, value) + + def register_entry( + self, + entry: str, + delay_time: Optional[Union[float, bm.Array, Callable]], + ) -> 'Delay': + """Register an entry to access the data. + + Args: + entry: str. The entry to access the delay data. + delay_time: The delay time of the entry (can be a float). + + Returns: + Return the self. + """ + if entry in self._registered_entries: + raise KeyError(f'Entry {entry} has been registered.') + + if delay_time is None: + delay_step = None + delay_time = 0. + elif callable(delay_time): + delay_time = bm.as_jax(delay_time(self.delay_target_shape)) + delay_step = jnp.asarray(delay_time / bm.get_dt(), dtype=bm.get_int()) + elif isinstance(delay_time, float): + delay_step = int(delay_time / bm.get_dt()) + else: + delay_step = jnp.asarray(bm.as_jax(delay_time) / bm.get_dt(), dtype=bm.get_int()) + + # delay steps + if delay_step is None: + delay_type = 'none' + elif isinstance(delay_step, int): + delay_type = 'homo' + elif isinstance(delay_step, (bm.Array, jax.Array, np.ndarray)): + if delay_step.size == 1 and delay_step.ndim == 0: + delay_type = 'homo' + else: + delay_type = 'heter' + delay_step = bm.Array(delay_step) + elif callable(delay_step): + delay_step = delay_step(self.delay_target_shape) + delay_type = 'heter' + else: + raise ValueError(f'Unknown "delay_steps" type {type(delay_step)}, only support ' + f'integer, array of integers, callable function, brainpy.init.Initializer.') + if delay_type == 'heter': + if delay_step.dtype not in [jnp.int32, jnp.int64]: + raise ValueError('Only support delay steps of int32, int64. If your ' + 'provide delay time length, please divide the "dt" ' + 'then provide us the number of delay steps.') + if self.delay_target_shape[0] != delay_step.shape[0]: + raise ValueError(f'Shape is mismatched: {self.delay_target_shape[0]} != {delay_step.shape[0]}') + if delay_type == 'heter': + max_delay_step = int(max(delay_step)) + elif delay_type == 'homo': + max_delay_step = delay_step + else: + max_delay_step = None + + # delay variable + if max_delay_step is not None: + if self.max_length < max_delay_step: + self._init_data(max_delay_step) + self.max_length = max_delay_step + self.max_time = delay_time + self._registered_entries[entry] = delay_step + return self + + def at(self, entry: str, *indices) -> bm.Array: + """Get the data at the given entry. + + Args: + entry: str. The entry to access the data. + *indices: The slicing indices. + + Returns: + The data. 
+ """ + assert isinstance(entry, str), 'entry should be a string for describing the ' + if entry not in self._registered_entries: + raise KeyError(f'Does not find delay entry "{entry}".') + delay_step = self._registered_entries[entry] + if delay_step is None: + return self.target.value + else: + if self.data is None: + return self.target.value + else: + if isinstance(delay_step, slice): + return self.retrieve(delay_step, *indices) + elif np.ndim(delay_step) == 0: + return self.retrieve(delay_step, *indices) + else: + if len(indices) == 0 and len(delay_step) == self.target.shape[0]: + indices = (jnp.arange(delay_step.size),) + return self.retrieve(delay_step, *indices) + + @property + def delay_target_shape(self): + """The data shape of the delay target.""" + return self.target.shape + + def __repr__(self): + name = self.__class__.__name__ + return (f'{name}(num_delay_step={self.max_length}, ' + f'delay_target_shape={self.delay_target_shape}, ' + f'update_method={self.method})') + + def _check_delay(self, delay_len): + raise ValueError(f'The request delay length should be less than the ' + f'maximum delay {self.max_length}. ' + f'But we got {delay_len}') + + def retrieve(self, delay_step, *indices): + """Retrieve the delay data according to the delay length. + + Parameters + ---------- + delay_step: int, ArrayType + The delay length used to retrieve the data. + """ + assert delay_step is not None + if check.is_checking(): + jit_error(bm.any(delay_step > self.max_length), self._check_delay, delay_step) + + if self.method == ROTATE_UPDATE: + i = share.load('i') + delay_idx = (i + delay_step) % (self.max_length + 1) + delay_idx = stop_gradient(delay_idx) + + elif self.method == CONCAT_UPDATE: + delay_idx = delay_step + + else: + raise ValueError(f'Unknown updating method "{self.method}"') + + # the delay index + if hasattr(delay_idx, 'dtype') and not jnp.issubdtype(delay_idx.dtype, jnp.integer): + raise ValueError(f'"delay_len" must be integer, but we got {delay_idx}') + indices = (delay_idx,) + tuple(indices) + + # the delay data + return self.data[indices] + + def update( + self, + latest_value: Optional[Union[bm.Array, jax.Array]] = None + ) -> None: + """Update delay variable with the new data. + """ + if self.data is not None: + # get the latest target value + if latest_value is None: + latest_value = self.target.value + + # update the delay data at the rotation index + if self.method == ROTATE_UPDATE: + i = share.load('i') + idx = bm.as_jax((i - 1) % (self.max_length + 1)) + self.data[idx] = latest_value + + # update the delay data at the first position + elif self.method == CONCAT_UPDATE: + if self.max_length >= 2: + self.data.value = bm.vstack([latest_value, self.data[1:]]) + else: + self.data[0] = latest_value + + def reset_state(self, batch_size: int = None): + """Reset the delay data. + """ + # initialize delay data + if self.data is not None: + self._init_data(self.max_length, batch_size) + + def _init_data(self, length: int, batch_size: int = None): + if batch_size is not None: + if self.target.batch_size != batch_size: + raise ValueError(f'The batch sizes of delay variable and target variable differ ' + f'({self.target.batch_size} != {batch_size}). ' + 'Please reset the target variable first, because delay data ' + 'depends on the target variable. 
') + + if self.target.batch_axis is None: + batch_axis = None + else: + batch_axis = self.target.batch_axis + 1 + + f = jax.jit(jnp.zeros, static_argnums=0, static_argnames='dtype', + out_shardings=bm.sharding.get_sharding(self.data_axis_names)) + data = f((length + 1,) + self.target.shape, dtype=self.target.dtype) + self.data = bm.Variable(data, batch_axis=batch_axis) + # update delay data + self.data[0] = self.target.value + if isinstance(self._init, (bm.Array, jax.Array, numbers.Number)): + self.data[1:] = self._init + elif callable(self._init): + self.data[1:] = self._init((length,) + self.target.shape, + dtype=self.target.dtype) + + +class DataDelay(TargetDelay): + """Delay variable which has a fixed delay length. + + The data in this delay variable is arranged as:: + + delay = 0 [ data + delay = 1 data + delay = 2 data + ... .... + ... .... + delay = length-1 data + delay = length data ] + + Args: + size: int, sequence of int. The delay target size. + axis_names: sequence of str. The name for each axis. + time: optional, int, float. The delay time. Default is None. + dtype: type. The data type. + init: Any. The delay data. It can be a Python number, like float, int, boolean values. + It can also be arrays. Or a callable function or instance of ``Connector``. + Note that ``initial_delay_data`` should be arranged as the following way:: + + delay = 1 [ data + delay = 2 data + ... .... + ... .... + delay = length-1 data + delay = length data ] + entries: optional, dict. The delay access entries. + name: str. The delay name. + method: str. The method used for updating delay. Default None. + mode: Mode. The computing mode. Default None. + + """ + + not_desc_params = ('time', 'entries') + + def __init__( + self, + + # delay info + size: Union[int, Sequence[int]], + axis_names: Optional[Sequence[str]] = None, + dtype: Optional[type] = None, + + # delay time + time: Optional[Union[int, float]] = None, + + # delay init + init: Optional[Union[numbers.Number, bm.Array, jax.Array, Callable]] = None, + + # delay access entry + entries: Optional[Dict] = None, + + # delay method + method: Optional[str] = None, + + # others + name: Optional[str] = None, + mode: Optional[bm.Mode] = None, + ): + size = tools.to_size(size) + mode = mode if mode is not None else bm.get_mode() + if axis_names is not None: + assert len(size) == len(axis_names) + if isinstance(mode, bm.BatchingMode): + batch_axis = 0 + size = (mode.batch_size,) + size + else: + batch_axis = None + + target = bm.Variable(bm.zeros(size, dtype=dtype), batch_axis=batch_axis) + if init is None: + pass + elif isinstance(init, (bm.Array, jax.Array, numbers.Number)): + target[:] = self._init + elif callable(self._init): + target[:] = self._init(size, dtype=dtype) + else: + raise ValueError + + super().__init__(target=target, + axis_names=axis_names, + time=time, + init=init, + entries=entries, + method=method, + name=name, + mode=mode) + + def update( + self, + latest_value: Union[bm.Array, jax.Array] + ) -> None: + """Update delay variable with the new data. 
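+
+        A minimal sketch of the typical call sequence (illustrative only; the
+        shared step counter is assumed to be managed by the surrounding runner):
+
+        >>> # delay = DataDelay(size=100, time=5.)    # keep 5 ms of history
+        >>> # delay.register_entry('syn', 2.)         # a 2 ms readout entry
+        >>> # delay.update(latest_spikes)             # called once per time step
+        >>> # delayed = delay.at('syn')               # read the 2 ms old values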
+ """ + latest_value = bm.sharding.partition_by_axname(latest_value, self.target_axis_names) + + # get the latest target value + self.target.value = latest_value + + if self.data is not None: + # update the delay data at the rotation index + if self.method == ROTATE_UPDATE: + i = share.load('i') + idx = bm.as_jax((i - 1) % (self.max_length + 1)) + self.data[idx] = latest_value + + # update the delay data at the first position + elif self.method == CONCAT_UPDATE: + if self.max_length >= 2: + self.data.value = bm.vstack([latest_value, self.data[1:]]) + else: + self.data[0] = latest_value diff --git a/brainpy/_src/pnn/mixin.py b/brainpy/_src/pnn/mixin.py new file mode 100644 index 000000000..6fbc651d4 --- /dev/null +++ b/brainpy/_src/pnn/mixin.py @@ -0,0 +1,54 @@ +from typing import Optional, Sequence + +from .utils import DelayedInit +from brainpy._src import tools, math as bm + + +__all__ = [ + 'MixIn', + 'ParamDesc', + 'AlignPost', +] + + +class MixIn(object): + pass + + +class ParamDesc(MixIn): + """Parameter description MixIn. + + This mixin enables the subclass has a classmethod ``desc``, which + produces an instance of :py:class:`~.DelayedInit`. + """ + + not_desc_params: Optional[Sequence[str]] = None + + @classmethod + def desc(cls, *args, **kwargs) -> DelayedInit: + # cls_args = list(inspect.signature(cls.__init__).parameters.values())[1:] + # names = [arg.name for arg in cls_args] + # defaults = [arg.default for arg in cls_args] + if cls.not_desc_params is not None: + repr_kwargs = {k: v for k, v in kwargs.items() if k not in cls.not_desc_params} + else: + repr_kwargs = {k: v for k, v in kwargs.items()} + for k in tuple(repr_kwargs.keys()): + if isinstance(repr_kwargs[k], bm.Variable): + repr_kwargs[k] = id(repr_kwargs[k]) + repr_args = tools.repr_dict(repr_kwargs) + if len(args): + repr_args = f"{', '.join([repr(arg) for arg in args])}, {repr_args}" + return DelayedInit(cls, f'{cls.__name__}({repr_args})', *args, **kwargs) + + +class AlignPost(MixIn): + """Align post MixIn. + + This class provides a ``add_current()`` function for + add external currents. + """ + def add_current(self, *args, **kwargs): + raise NotImplementedError + + diff --git a/brainpy/_src/pnn/neurons/__init__.py b/brainpy/_src/pnn/neurons/__init__.py new file mode 100644 index 000000000..92ee19555 --- /dev/null +++ b/brainpy/_src/pnn/neurons/__init__.py @@ -0,0 +1,5 @@ + +from .base import * +from .lif import * + + diff --git a/brainpy/_src/pnn/neurons/_docs.py b/brainpy/_src/pnn/neurons/_docs.py new file mode 100644 index 000000000..6bf82195d --- /dev/null +++ b/brainpy/_src/pnn/neurons/_docs.py @@ -0,0 +1,32 @@ +pneu_doc = ''' + size: int, or sequence of int. The neuronal population size. + axis_names: sequence of str. The + keep_size: bool. Keep the neuron group size. + mode: Mode. The computing mode. + name: str. The group name. +'''.strip() + +dpneu_doc = ''' + spk_fun: callable. The spike activation function. + detach_spk: bool. + method: str. The numerical integration method. + spk_type: The spike data type. +'''.strip() + +ref_doc = ''' + tau_ref: float, ArrayType, callable. Refractory period length (ms). + has_ref_var: bool. Whether has the refractory variable. Default is ``False``. +'''.strip() + +lif_doc = ''' + V_rest: float, ArrayType, callable. Resting membrane potential. + V_reset: float, ArrayType, callable. Reset potential after spike. + V_th: float, ArrayType, callable. Threshold potential of spike. + R: float, ArrayType, callable. Membrane resistance. + tau: float, ArrayType, callable. 
Membrane time constant. + V_initializer: ArrayType, callable. The initializer of membrane potential. +'''.strip() + + +ltc_doc = 'with liquid time-constant' + diff --git a/brainpy/_src/pnn/neurons/base.py b/brainpy/_src/pnn/neurons/base.py new file mode 100644 index 000000000..f749c7a45 --- /dev/null +++ b/brainpy/_src/pnn/neurons/base.py @@ -0,0 +1,125 @@ +from typing import Sequence, Union, Callable, Any, Optional, Dict + +import brainpy.math as bm +from brainpy._src.dynsys import NeuGroupNS +from brainpy._src.initialize import (parameter, + variable_) +from brainpy.check import is_callable + +from brainpy._src.pnn.utils import NEU_AXIS +from brainpy._src.pnn.synapses.syn_output import PSynOut +from ._docs import pneu_doc, dpneu_doc + + +__all__ = [ + 'PNeuGroup', + 'DPNeuGroup', +] + + +class PNeuGroup(NeuGroupNS): + """Parallelizable Neuron Group. + + Args: + {pneu} + """ + + def __init__( + self, + size: Union[int, Sequence[int]], + axis_names: Optional[Sequence[str]] = (NEU_AXIS,), + keep_size: bool = False, + mode: bm.Mode = None, + name: str = None, + ): + super().__init__(size=size, + mode=mode, + keep_size=keep_size, + name=name) + + # axis names for parallelization + self.axis_names = axis_names + if axis_names is not None: + if len(axis_names) != len(self.varshape): + raise ValueError(f'Except len(varshape) == len(axis_names), ' + f'but got {len(self.varshape)} != {len(axis_names)}.') + + # the post updates used for computing + self.pre_updates: Dict[str, Callable] = bm.node_dict() + self.post_updates: Dict[str, Callable] = bm.node_dict() + + # outputs + self.cur_outputs: Dict[str, PSynOut] = bm.node_dict() + + def sharding_param(self, param, shape=None, axis_names=None): + """Sharding parameters across the default given devices. """ + if shape is None: + shape = self.varshape + if axis_names is None: + axis_names = self.axis_names + return parameter(param, sizes=shape, allow_none=False, axis_names=axis_names) + + def sharding_variable(self, var, batch_or_mode, shape=None, axis_names=None): + """Sharding variables across the given devices.""" + if shape is None: + shape = self.varshape + if axis_names is None: + axis_names = self.axis_names + return variable_(var, sizes=shape, batch_or_mode=batch_or_mode, + axis_names=axis_names, batch_axis_name='batch') + + def __call__(self, *args, **kwargs): + for model in tuple(self.pre_updates.values()): + model() + ret = super().__call__(*args, **kwargs) + for model in tuple(self.post_updates.values()): + model(ret) + return ret + + +PNeuGroup.__doc__ = PNeuGroup.__doc__.format(pneu=pneu_doc) + + +class DPNeuGroup(PNeuGroup): + """Differentiable and Parallelizable Neuron Group. 
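+
+  Calling an instance first runs every module registered in ``pre_updates``,
+  then the group's own ``update``, and finally every module in ``post_updates``
+  (see ``PNeuGroup.__call__`` above). The spike data type defaults to
+  ``bm.float_`` under a training mode, so surrogate gradients can flow, and to
+  ``bm.bool_`` otherwise; pass ``spk_type`` explicitly to override this.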
+ + Args: + {pneu} + {dpneu} + """ + + supported_modes = (bm.TrainingMode, bm.NonBatchingMode) + + def __init__( + self, + size: Union[int, Sequence[int]], + axis_names: Optional[Sequence[str]] = (NEU_AXIS,), + keep_size: bool = False, + mode: bm.Mode = None, + name: str = None, + + spk_fun: Callable = bm.surrogate.InvSquareGrad(), + detach_spk: bool = False, + method: str = 'exp_auto', + spk_type: Any = None, + ): + super().__init__(size=size, + mode=mode, + keep_size=keep_size, + name=name, + axis_names=axis_names) + + self.spk_fun = is_callable(spk_fun) + self.detach_spk = detach_spk + self.method = method + self._spk_type = spk_type + + @property + def spk_type(self): + if self._spk_type is None: + return bm.float_ if isinstance(self.mode, bm.TrainingMode) else bm.bool_ + else: + return self._spk_type + + +DPNeuGroup.__doc__ = DPNeuGroup.__doc__.format(pneu=pneu_doc, dpneu=dpneu_doc) diff --git a/brainpy/_src/pnn/neurons/hh.py b/brainpy/_src/pnn/neurons/hh.py new file mode 100644 index 000000000..e69de29bb diff --git a/brainpy/_src/pnn/neurons/lif.py b/brainpy/_src/pnn/neurons/lif.py new file mode 100644 index 000000000..e6a1c36ee --- /dev/null +++ b/brainpy/_src/pnn/neurons/lif.py @@ -0,0 +1,315 @@ +from functools import partial +from typing import Union, Callable, Optional, Sequence, Any + +from jax.lax import stop_gradient + +import brainpy.math as bm +from brainpy._src.context import share +from brainpy._src.initialize import (ZeroInit, Initializer) +from brainpy._src.integrators import odeint +from brainpy._src.pnn.utils.axis_names import NEU_AXIS +from brainpy.check import is_initializer +from brainpy.types import Shape, ArrayType +from ._docs import ref_doc, lif_doc, pneu_doc, dpneu_doc, ltc_doc +from .base import DPNeuGroup + +__all__ = [ + 'LIF', + 'LIFLtc', + 'LIFRef', + 'LIFRefLtc', +] + + +class IF(DPNeuGroup): + pass + + +class LIFLtc(DPNeuGroup): + r"""Leaky integrate-and-fire neuron model %s. + + The formal equations of a LIF model [1]_ is given by: + + .. math:: + + \tau \frac{dV}{dt} = - (V(t) - V_{rest}) + RI(t) \\ + \text{after} \quad V(t) \gt V_{th}, V(t) = V_{reset} + + where :math:`V` is the membrane potential, :math:`V_{rest}` is the resting + membrane potential, :math:`V_{reset}` is the reset membrane potential, + :math:`V_{th}` is the spike threshold, :math:`\tau` is the time constant, + and :math:`I` is the time-variant synaptic inputs. + + .. [1] Abbott, Larry F. "Lapicque’s introduction of the integrate-and-fire model + neuron (1907)." Brain research bulletin 50, no. 5-6 (1999): 303-304. 
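+
+  A minimal construction sketch (illustrative only; driving the group step by
+  step additionally requires the shared clock set up by a runner):
+
+  >>> import brainpy.math as bm
+  >>> neu = LIFLtc(10, V_rest=-65., V_reset=-68., V_th=-50., tau=10.)
+  >>> # spikes = neu(5. * bm.ones(10))   # one update step with a constant input current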
+ + Args: + %s + %s + %s + + """ + + def __init__( + self, + size: Shape, + axis_names: Optional[Sequence[str]] = (NEU_AXIS,), + keep_size: bool = False, + mode: Optional[bm.Mode] = None, + name: Optional[str] = None, + spk_fun: Callable = bm.surrogate.InvSquareGrad(), + spk_type: Any = None, + detach_spk: bool = False, + method: str = 'exp_auto', + init_var: bool = True, + + # neuron parameters + V_rest: Union[float, ArrayType, Initializer, Callable] = 0., + V_reset: Union[float, ArrayType, Initializer, Callable] = -5., + V_th: Union[float, ArrayType, Initializer, Callable] = 20., + R: Union[float, ArrayType, Initializer, Callable] = 1., + tau: Union[float, ArrayType, Initializer, Callable] = 10., + V_initializer: Union[Initializer, Callable, ArrayType] = ZeroInit(), + ): + # initialization + super().__init__(size=size, + name=name, + keep_size=keep_size, + mode=mode, + axis_names=axis_names, + spk_fun=spk_fun, + detach_spk=detach_spk, + method=method, + spk_type=spk_type) + + # parameters + self.V_rest = self.sharding_param(V_rest) + self.V_reset = self.sharding_param(V_reset) + self.V_th = self.sharding_param(V_th) + self.tau = self.sharding_param(tau) + self.R = self.sharding_param(R) + + # initializers + self._V_initializer = is_initializer(V_initializer) + + # integral + self.integral = odeint(method=method, f=self.derivative) + + # variables + if init_var: + self.reset_state(self.mode) + + def derivative(self, V, t, I): + for out in self.cur_outputs.values(): + I += out(V) + return (-V + self.V_rest + self.R * I) / self.tau + + def reset_state(self, batch_size=None): + self.V = self.sharding_variable(self._V_initializer, batch_size) + self.spike = self.sharding_variable(partial(bm.zeros, dtype=self.spk_type), batch_size) + + def update(self, x=None): + t = share.load('t') + dt = share.load('dt') + x = 0. if x is None else x + + # integrate membrane potential + V = self.integral(self.V.value, t, x, dt) + + # spike, spiking time, and membrane potential reset + if isinstance(self.mode, bm.TrainingMode): + spike = self.spk_fun(V - self.V_th) + spike = stop_gradient(spike) if self.detach_spk else spike + V += (self.V_reset - V) * spike + + else: + spike = V >= self.V_th + V = bm.where(spike, self.V_reset, V) + + self.V.value = V + self.spike.value = spike + return spike + + +class LIF(LIFLtc): + def derivative(self, V, t, I): + return (-V + self.V_rest + self.R * I) / self.tau + + def update(self, x=None): + x = 0. if x is None else x + for out in self.cur_outputs.values(): + x += out(self.V.value) + super().update(x) + + +LIF.__doc__ = LIFLtc.__doc__ % ('', lif_doc, pneu_doc, dpneu_doc) +LIFLtc.__doc__ = LIFLtc.__doc__ % (ltc_doc, lif_doc, pneu_doc, dpneu_doc) + + +class LIFRefLtc(LIFLtc): + r"""Leaky integrate-and-fire neuron model %s which has refractory periods. + + The formal equations of a LIF model [1]_ is given by: + + .. math:: + + \tau \frac{dV}{dt} = - (V(t) - V_{rest}) + RI(t) \\ + \text{after} \quad V(t) \gt V_{th}, V(t) = V_{reset} \quad + \text{last} \quad \tau_{ref} \quad \text{ms} + + where :math:`V` is the membrane potential, :math:`V_{rest}` is the resting + membrane potential, :math:`V_{reset}` is the reset membrane potential, + :math:`V_{th}` is the spike threshold, :math:`\tau` is the time constant, + :math:`\tau_{ref}` is the refractory time period, + and :math:`I` is the time-variant synaptic inputs. + + .. [1] Abbott, Larry F. "Lapicque’s introduction of the integrate-and-fire model + neuron (1907)." Brain research bulletin 50, no. 5-6 (1999): 303-304. 
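+
+  A minimal construction sketch (illustrative only):
+
+  >>> neu = LIFRefLtc(10, tau_ref=2., has_ref_var=True)
+  >>> # neu.refractory is a boolean variable marking which cells are refractory
+  >>> # neu.t_last_spike stores the last spike time of every cell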
+ + Args: + %s + %s + %s + %s + + """ + + def __init__( + self, + size: Shape, + axis_names: Optional[Sequence[str]] = (NEU_AXIS,), + keep_size: bool = False, + mode: Optional[bm.Mode] = None, + spk_fun: Callable = bm.surrogate.InvSquareGrad(), + spk_type: Any = None, + detach_spk: bool = False, + method: str = 'exp_auto', + name: Optional[str] = None, + init_var: bool = True, + + # old neuron parameter + V_rest: Union[float, ArrayType, Initializer, Callable] = 0., + V_reset: Union[float, ArrayType, Initializer, Callable] = -5., + V_th: Union[float, ArrayType, Initializer, Callable] = 20., + R: Union[float, ArrayType, Initializer, Callable] = 1., + tau: Union[float, ArrayType, Initializer, Callable] = 10., + V_initializer: Union[Initializer, Callable, ArrayType] = ZeroInit(), + + # new neuron parameter + tau_ref: Optional[Union[float, ArrayType, Initializer, Callable]] = None, + has_ref_var: bool = False, + ): + # initialization + super().__init__( + size=size, + name=name, + keep_size=keep_size, + mode=mode, + method=method, + axis_names=axis_names, + spk_fun=spk_fun, + detach_spk=detach_spk, + spk_type=spk_type, + + init_var=False, + + V_rest=V_rest, + V_reset=V_reset, + V_th=V_th, + R=R, + tau=tau, + V_initializer=V_initializer, + ) + + # parameters + self.has_ref_var = has_ref_var + self.tau_ref = self.sharding_param(tau_ref) + + # initializers + self._V_initializer = is_initializer(V_initializer) + + # integral + self.integral = odeint(method=method, f=self.derivative) + + # variables + if init_var: + self.reset_state(self.mode) + + def reset_state(self, batch_size=None): + super().reset_state(batch_size) + self.t_last_spike = self.sharding_variable(bm.ones, batch_size) + self.t_last_spike.fill_(-1e7) + if self.has_ref_var: + self.refractory = self.sharding_variable(partial(bm.zeros, dtype=bool), batch_size) + + def update(self, x=None): + t = share.load('t') + dt = share.load('dt') + x = 0. if x is None else x + + # integrate membrane potential + V = self.integral(self.V.value, t, x, dt) + + # refractory + refractory = (t - self.t_last_spike) <= self.tau_ref + if isinstance(self.mode, bm.TrainingMode): + refractory = stop_gradient(refractory) + V = bm.where(refractory, self.V.value, V) + + # spike, refractory, spiking time, and membrane potential reset + if isinstance(self.mode, bm.TrainingMode): + spike = self.spk_fun(V - self.V_th) + spike_no_grad = stop_gradient(spike) if self.detach_spk else spike + V += (self.V_reset - V) * spike_no_grad + spike_ = spike_no_grad > 0. + # will be used in other place, like Delta Synapse, so stop its gradient + if self.has_ref_var: + self.refractory.value = stop_gradient(bm.logical_or(refractory, spike_).value) + t_last_spike = stop_gradient(bm.where(spike_, t, self.t_last_spike.value)) + + else: + spike = V >= self.V_th + V = bm.where(spike, self.V_reset, V) + if self.has_ref_var: + self.refractory.value = bm.logical_or(refractory, spike) + t_last_spike = bm.where(spike, t, self.t_last_spike.value) + self.V.value = V + self.spike.value = spike + self.t_last_spike.value = t_last_spike + return spike + + +class LIFRef(LIFRefLtc): + def derivative(self, V, t, I): + return (-V + self.V_rest + self.R * I) / self.tau + + def update(self, x=None): + x = 0. 
if x is None else x + for out in self.cur_outputs.values(): + x += out(self.V.value) + super().update(x) + + +LIFRef.__doc__ = LIFRefLtc.__doc__ % ('', lif_doc, pneu_doc, dpneu_doc, ref_doc) +LIFRefLtc.__doc__ = LIFRefLtc.__doc__ % (ltc_doc, lif_doc, pneu_doc, dpneu_doc, ref_doc) + + +class ExpIF(DPNeuGroup): + pass + + +class AdExIF(DPNeuGroup): + pass + + +class QuaIF(DPNeuGroup): + pass + + +class AdQuaIF(DPNeuGroup): + pass + + +class GIF(DPNeuGroup): + pass + diff --git a/brainpy/_src/pnn/synapses/__init__.py b/brainpy/_src/pnn/synapses/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/brainpy/_src/pnn/synapses/projections.py b/brainpy/_src/pnn/synapses/projections.py new file mode 100644 index 000000000..fe8990af3 --- /dev/null +++ b/brainpy/_src/pnn/synapses/projections.py @@ -0,0 +1,179 @@ +from typing import Optional, Callable, Union + +from brainpy import math as bm +from brainpy._src.dynsys import DynamicalSystemNS, DynamicalSystem +from brainpy._src.pnn.delay import DataDelay, Delay +from brainpy._src.pnn.neurons.base import PNeuGroup +from brainpy._src.pnn.utils import DelayedInit +from .syn_output import PSynOut + +__all__ = [ + 'ProjectionAlignPre', + 'ProjectionAlignPost', +] + + +class _AlignPre(DynamicalSystemNS): + def __init__(self, syn, delay=None): + super().__init__() + self.syn = syn + self.delay = delay + + def update(self, x): + if self.delay is None: + return x >> self.syn + else: + return x >> self.syn >> self.delay + + +class ProjectionAlignPre(DynamicalSystemNS): + """Synaptic projection which defines the synaptic computation with the dimension of presynaptic neuron group. + + Args: + pre: The pre-synaptic neuron group. + syn: The synaptic dynamics. + delay: The synaptic delay. + comm: The synaptic communication. + out: The synaptic output. + post: The post-synaptic neuron group. + name: str. The projection name. + mode: Mode. The computing mode. 
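+
+    A minimal wiring sketch (illustrative only; ``Exponential``, ``DenseMM`` and
+    the output object are stand-ins for any synapse dynamics, communication and
+    ``PSynOut`` models accepted by this projection):
+
+    >>> # pre, post : two PNeuGroup instances built elsewhere
+    >>> # conn      : a connector between them
+    >>> # proj = ProjectionAlignPre(
+    >>> #     pre=pre,
+    >>> #     syn=Exponential.desc(pre.varshape, tau=5.),  # DelayedInit descriptor
+    >>> #     delay=2.,                                    # 2 ms spike delay
+    >>> #     comm=DenseMM(conn, weight=0.1),
+    >>> #     out=some_psyn_out,                           # a PSynOut instance
+    >>> #     post=post)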
+ """ + def __init__( + self, + pre: PNeuGroup, + syn: DelayedInit[DynamicalSystem], + delay: Union[None, int, float, DelayedInit[Delay]], + comm: Callable, + out: PSynOut, + post: PNeuGroup, + name: Optional[str] = None, + mode: Optional[bm.Mode] = None, + ): + super().__init__(name=name, mode=mode) + + # synaptic models + assert isinstance(pre, PNeuGroup) + assert isinstance(post, PNeuGroup) + assert callable(comm) + assert isinstance(out, PSynOut) + assert isinstance(syn, DelayedInit) + self.pre = pre + self.post = post + self.comm = comm + + # synapse and delay initialization + self._syn_id = syn._identifier + delay_time = None + if self._syn_id not in pre.post_updates: + syn_cls = syn() + if delay is None: + delay_cls = DataDelay(pre.varshape, + axis_names=pre.axis_names, + mode=pre.mode) + elif isinstance(delay, (int, float)): + delay_time = delay + delay_cls = DataDelay(pre.varshape, + axis_names=pre.axis_names, + mode=pre.mode) + elif isinstance(delay, DelayedInit): + delay_time = delay.kwargs.get('time', None) + delay_cls = delay() + else: + raise TypeError + pre.post_updates[self._syn_id] = _AlignPre(syn_cls, delay_cls) + delay_cls = pre.post_updates[self._syn_id].delay + if delay_cls is not None: + delay_cls.register_entry(self.name, delay_time) + + # output initialization + post.cur_outputs[self.name] = out + + def update(self): + current = self.comm(self.pre.post_updates[self._syn_id].delay.at(self.name)) + self.post.cur_outputs[self.name].bind_cond(current) + return current + + +class _AlignPost(DynamicalSystemNS): + def __init__(self, syn, out): + super().__init__() + self.syn = syn + self.out = out + + def update(self, *args, **kwargs): + self.out.bind_cond(self.syn(*args, **kwargs)) + + +class ProjectionAlignPost(DynamicalSystemNS): + """Synaptic projection which defines the synaptic computation with the dimension of postsynaptic neuron group. + + Args: + pre: The pre-synaptic neuron group. + delay: The synaptic delay. + comm: The synaptic communication. + syn: The synaptic dynamics. + out: The synaptic output. + post: The post-synaptic neuron group. + name: str. The projection name. + mode: Mode. The computing mode. 
+ """ + def __init__( + self, + pre: PNeuGroup, + delay: Union[None, int, float, DelayedInit[Delay]], + comm: Callable, + syn: DelayedInit[DynamicalSystem], + out: DelayedInit[PSynOut], + post: PNeuGroup, + name: Optional[str] = None, + mode: Optional[bm.Mode] = None, + ): + super().__init__(name=name, mode=mode) + + # synaptic models + assert isinstance(pre, PNeuGroup) + assert isinstance(post, PNeuGroup) + assert isinstance(syn, DelayedInit) + assert isinstance(out, DelayedInit) + assert callable(comm) + self.pre = pre + self.post = post + self.comm = comm + + # delay initialization + self._delay_repr = '_*_align_pre_spk_delay_*_' + delay_time = None + if self._delay_repr not in self.pre.post_updates: + if delay is None: + delay_cls = DataDelay(pre.varshape, + axis_names=pre.axis_names, + mode=pre.mode) + elif isinstance(delay, (int, float)): + delay_time = delay + delay_cls = DataDelay(pre.varshape, + axis_names=pre.axis_names, + mode=pre.mode) + elif isinstance(delay, DelayedInit): + delay_time = delay.kwargs.get('time', None) + delay_cls = delay() + else: + raise TypeError + self.pre.post_updates[self._delay_repr] = delay_cls + delay_cls = pre.post_updates[self._delay_repr] + if delay_cls is not None: + delay_cls.register_entry(self.name, delay_time) + + # synapse and output initialization + self._post_repr = f'{syn._identifier} // {out._identifier}' + if self._post_repr not in self.post.pre_updates: + syn_cls = syn() + out_cls = out() + self.post.cur_outputs[self.name] = out_cls + self.post.pre_updates[self._post_repr] = _AlignPost(syn_cls, out_cls) + + def update(self): + current = self.comm(self.pre.post_updates[self._delay_repr].at(self.name)) + self.post.pre_updates[self._post_repr].syn.add_current(current) # synapse post current + return current + diff --git a/brainpy/_src/pnn/synapses/syn_comm.py b/brainpy/_src/pnn/synapses/syn_comm.py new file mode 100644 index 000000000..96319ca5d --- /dev/null +++ b/brainpy/_src/pnn/synapses/syn_comm.py @@ -0,0 +1,322 @@ +from typing import Optional, Union, Callable, Sequence + +from brainpy import math as bm +from brainpy._src import connect, initialize as init +from brainpy._src.dynsys import DynamicalSystemNS +from brainpy._src.pnn.utils import POST_AXIS, PRE_AXIS +from brainpy.types import ArrayType + +__all__ = [ + 'All2allMM', + 'One2oneMM', + 'DenseMM', + 'CsrMM', +] + + +class SynComm(DynamicalSystemNS): + pass + + +class All2allMM(SynComm): + """Synaptic matrix multiplication with All2All connections. + + Args: + num_pre: int. The number of neurons in the presynaptic neuron group. + num_post: int. The number of neurons in the postsynaptic neuron group. + weight: The synaptic weights. + axis_names: sequence of str. The name for each axis. + include_self: bool. Whether connect the neuron with at the same position. + mode: Mode. The computing mode. + name: str. The object name. 
+ """ + + def __init__( + self, + num_pre: int, + num_post: int, + weight: Union[float, ArrayType, Callable], + axis_names: Optional[Sequence[str]] = (PRE_AXIS, POST_AXIS), + include_self: bool = True, + + mode: Optional[bm.Mode] = None, + name: Optional[str] = None, + ): + super().__init__(mode=mode, name=name) + + self.num_pre = num_pre + self.num_post = num_post + self.include_self = include_self + self.axis_names = axis_names + + self.weight = init.parameter(weight, (self.num_pre, self.num_post), axis_names=axis_names) + if isinstance(self.mode, bm.TrainingMode): + self.weight = bm.TrainVar(self.weight) + + def update(self, pre_val): + if bm.ndim(self.weight) == 0: # weight is a scalar + if isinstance(self.mode, bm.BatchingMode): + assert pre_val.ndim == 2 + post_val = bm.sum(pre_val, keepdims=True, axis=1) + else: + assert pre_val.ndim == 1 + post_val = bm.sum(pre_val) + if not self.include_self: + if self.num_pre == self.num_post: + post_val = post_val - pre_val + elif self.num_pre > self.num_post: + val = pre_val[:self.num_post] + post_val = post_val - val + else: + val = bm.concatenate([pre_val, bm.zeros(self.num_post - self.num_pre)]) + post_val = post_val - val + post_val = self.weight * post_val + + else: # weight is a matrix + if not self.include_self: + post_val = pre_val @ bm.fill_diagonal(self.weight, 0., inplace=False) + else: + post_val = pre_val @ self.weight + return post_val + + +class One2oneMM(SynComm): + """Synaptic matrix multiplication with One2One connection. + + Args: + num: int. The number of neurons. + weight: The synaptic weight. + axis_names: The axis names. + mode: The computing mode. + name: The object name. + + """ + + def __init__( + self, + num: int, + weight: Union[float, ArrayType, Callable], + axis_names: Optional[Sequence[str]] = (POST_AXIS,), + mode: Optional[bm.Mode] = None, + name: Optional[str] = None, + ): + super().__init__(mode=mode, name=name) + + self.num = num + self.axis_names = axis_names + + self.weight = init.parameter(weight, (self.num,), axis_names=axis_names) + if isinstance(self.mode, bm.TrainingMode): + self.weight = bm.TrainVar(self.weight) + + def update(self, pre_val): + return pre_val * self.weight + + +class _SynMatMul(SynComm): + def __init__( + self, + conn: connect.TwoEndConnector, + axis_names: Optional[Sequence[str]] = (PRE_AXIS, POST_AXIS), + mode: Optional[bm.Mode] = None, + name: Optional[str] = None, + ): + super().__init__(name=name, mode=mode) + + assert isinstance(conn, connect.TwoEndConnector) + self.conn = conn + self.axis_names = axis_names + + +class DenseMM(_SynMatMul): + r"""Synaptic matrix multiplication with dense computation. + + It performs the computation of: + + .. math:: + + y = x @ M + + where :math:`y` is the postsynaptic value, :math:`x` the presynaptic value, + :math:`M` the synaptic weight using a dense matrix. + + Args: + conn: TwoEndConnector. The connection. + weight: Synaptic weights. Can be a scalar, array, or callable function. + axis_names: sequence of str. The synaptic weight axis. + mode: The synaptic computing mode. + name: The synapse model name. 
+ """ + + def __init__( + self, + conn: connect.TwoEndConnector, + weight: Union[float, ArrayType, Callable], + axis_names: Optional[Sequence[str]] = (PRE_AXIS, POST_AXIS), + mode: Optional[bm.Mode] = None, + name: Optional[str] = None, + ): + super().__init__(name=name, mode=mode, conn=conn) + + # weight + self.weight = init.parameter(weight, (conn.pre_num, conn.post_num), axis_names=axis_names) + if isinstance(self.mode, bm.TrainingMode): + self.weight = bm.TrainVar(self.weight) + + # connection + self.mask = bm.sharding.partition_by_axname(self.conn.require('conn_mat'), + axis_names=axis_names) + + def update(self, x): + return x @ (self.weight * self.mask) + + +class CsrMM(_SynMatMul): + r"""Synaptic matrix multiplication with CSR sparse computation. + + It performs the computation of: + + .. math:: + + y = x @ M + + where :math:`y` is the postsynaptic value, :math:`x` the presynaptic value, + :math:`M` the synaptic weight using a CSR sparse matrix. + + Args: + conn: TwoEndConnector. The connection. + weight: Synaptic weights. Can be a scalar, array, or callable function. + axis_names: sequence of str. The synaptic weight axis. + mode: The synaptic computing mode. + name: The synapse model name. + """ + + def __init__( + self, + conn: connect.TwoEndConnector, + weight: Union[float, ArrayType, Callable], + axis_names: Optional[Sequence[str]] = (PRE_AXIS, POST_AXIS), + mode: Optional[bm.Mode] = None, + name: Optional[str] = None, + ): + super().__init__(name=name, mode=mode, conn=conn) + + # connection + self.indices, self.indptr = self.conn.require('csr') + + # weight + self.weight = init.parameter(weight, (conn.pre_num, conn.post_num), axis_names=axis_names) + if isinstance(self.mode, bm.TrainingMode): + self.weight = bm.TrainVar(self.weight) + + def update(self, x): + raise NotImplementedError + + +class CscMM(_SynMatMul): + r"""Synaptic matrix multiplication with CSC sparse computation. + + It performs the computation of: + + .. math:: + + y = x @ M + + where :math:`y` is the postsynaptic value, :math:`x` the presynaptic value, + :math:`M` the synaptic weight using a CSC sparse matrix. + + Args: + conn: TwoEndConnector. The connection. + weight: Synaptic weights. Can be a scalar, array, or callable function. + axis_names: sequence of str. The synaptic weight axis. + mode: The synaptic computing mode. + name: The synapse model name. + """ + + def __init__( + self, + conn: connect.TwoEndConnector, + weight: Union[float, ArrayType, Callable], + axis_names: Optional[Sequence[str]] = (PRE_AXIS, POST_AXIS), + mode: Optional[bm.Mode] = None, + name: Optional[str] = None, + ): + super().__init__(name=name, mode=mode, conn=conn) + + +class EventCsrMM(_SynMatMul): + pass + + +class BcsrMM(_SynMatMul): + r"""Synaptic matrix multiplication with BCSR sparse computation. + + It performs the computation of: + + .. math:: + + y = x @ M + + where :math:`y` is the postsynaptic value, :math:`x` the presynaptic value, + :math:`M` the synaptic weight using a BCSR sparse matrix. + + Args: + conn: TwoEndConnector. The connection. + weight: Synaptic weights. Can be a scalar, array, or callable function. + axis_names: sequence of str. The synaptic weight axis. + mode: The synaptic computing mode. + name: The synapse model name. 
+ """ + + def __init__( + self, + conn: connect.TwoEndConnector, + weight: Union[float, ArrayType, Callable], + axis_names: Optional[Sequence[str]] = (PRE_AXIS, POST_AXIS), + mode: Optional[bm.Mode] = None, + name: Optional[str] = None, + ): + super().__init__(name=name, mode=mode, conn=conn) + + +class BcscMM(_SynMatMul): + r"""Synaptic matrix multiplication with BCSC sparse computation. + + It performs the computation of: + + .. math:: + + y = x @ M + + where :math:`y` is the postsynaptic value, :math:`x` the presynaptic value, + :math:`M` the synaptic weight using a BCSC sparse matrix. + + Args: + conn: TwoEndConnector. The connection. + weight: Synaptic weights. Can be a scalar, array, or callable function. + axis_names: sequence of str. The synaptic weight axis. + mode: The synaptic computing mode. + name: The synapse model name. + """ + + def __init__( + self, + conn: connect.TwoEndConnector, + weight: Union[float, ArrayType, Callable], + axis_names: Optional[Sequence[str]] = (PRE_AXIS, POST_AXIS), + mode: Optional[bm.Mode] = None, + name: Optional[str] = None, + ): + super().__init__(name=name, mode=mode, conn=conn) + + +class JitProbHomoMM(_SynMatMul): + pass + + +class JitProbUniformMM(_SynMatMul): + pass + + +class JitProbNormalMM(_SynMatMul): + pass diff --git a/brainpy/_src/pnn/synapses/syn_dynamics.py b/brainpy/_src/pnn/synapses/syn_dynamics.py new file mode 100644 index 000000000..d6ea1808d --- /dev/null +++ b/brainpy/_src/pnn/synapses/syn_dynamics.py @@ -0,0 +1,960 @@ +from typing import Union, Sequence, Callable, Optional + +from brainpy import math as bm +from brainpy._src.context import share +from brainpy._src.integrators.joint_eq import JointEq +from brainpy._src.integrators.ode.generic import odeint +from brainpy.types import ArrayType + +from brainpy._src.pnn.mixin import ParamDesc, AlignPost +from brainpy._src.pnn.neurons.base import PNeuGroup, pneu_doc +from brainpy._src.pnn.utils.axis_names import NEU_AXIS + +__all__ = [ + 'PSynDyn', + 'Exponential', + 'DualExponential', + 'Alpha', + 'NMDA', + 'STD', + 'STP', + 'AMPA', + 'GABAa', + 'BioNMDA', +] + + +class PSynDyn(PNeuGroup, ParamDesc): + """Parallelizable synaptic dynamics.""" + pass + + +class Exponential(PSynDyn, AlignPost): + r"""Exponential decay synapse model. + + **Model Descriptions** + + The single exponential decay synapse model assumes the release of neurotransmitter, + its diffusion across the cleft, the receptor binding, and channel opening all happen + very quickly, so that the channels instantaneously jump from the closed to the open state. + Therefore, its expression is given by + + .. math:: + + g_{\mathrm{syn}}(t)=g_{\mathrm{max}} e^{-\left(t-t_{0}\right) / \tau} + + where :math:`\tau_{delay}` is the time constant of the synaptic state decay, + :math:`t_0` is the time of the pre-synaptic spike, + :math:`g_{\mathrm{max}}` is the maximal conductance. + + Accordingly, the differential form of the exponential synapse is given by + + .. math:: + + \begin{aligned} + & \frac{d g}{d t} = -\frac{g}{\tau_{decay}}+\sum_{k} \delta(t-t_{j}^{k}). + \end{aligned} + + .. [1] Sterratt, David, Bruce Graham, Andrew Gillies, and David Willshaw. + "The Synapse." Principles of Computational Modelling in Neuroscience. + Cambridge: Cambridge UP, 2011. 172-95. Print. + + Args: + tau: float, ArrayType, Callable. The time constant of decay. 
[ms] + %s + """ + def __init__( + self, + size: Union[int, Sequence[int]], + keep_size: bool = False, + axis_names: Optional[Sequence[str]] = (NEU_AXIS,), + method: str = 'exp_auto', + name: Optional[str] = None, + mode: Optional[bm.Mode] = None, + + # synapse parameters + tau: Union[float, ArrayType, Callable] = 8.0, + ): + super().__init__(name=name, + mode=mode, + size=size, + keep_size=keep_size, + axis_names=axis_names) + + # parameters + self.tau = self.sharding_param(tau) + + # function + self.integral = odeint(self.derivative, method=method) + + self.reset_state(self.mode) + + def derivative(self, g, t): + return -g / self.tau + + def reset_state(self, batch_size=None): + self.g = self.sharding_variable(bm.zeros, batch_size) + + def update(self, x=None): + t = share.load('t') + dt = share.load('dt') + self.g.value = self.integral(self.g.value, t, dt) + if x is not None: + self.g.value += x + return self.g.value + + def add_current(self, x): + self.g.value += x + + +Exponential.__doc__ = Exponential.__doc__ % (pneu_doc,) + + +class DualExponential(PSynDyn): + r"""Dual exponential synapse model. + + **Model Descriptions** + + The dual exponential synapse model [1]_, also named as *difference of two exponentials* model, + is given by: + + .. math:: + + g_{\mathrm{syn}}(t)=g_{\mathrm{max}} \frac{\tau_{1} \tau_{2}}{ + \tau_{1}-\tau_{2}}\left(\exp \left(-\frac{t-t_{0}}{\tau_{1}}\right) + -\exp \left(-\frac{t-t_{0}}{\tau_{2}}\right)\right) + + where :math:`\tau_1` is the time constant of the decay phase, :math:`\tau_2` + is the time constant of the rise phase, :math:`t_0` is the time of the pre-synaptic + spike, :math:`g_{\mathrm{max}}` is the maximal conductance. + + However, in practice, this formula is hard to implement. The equivalent solution is + two coupled linear differential equations [2]_: + + .. math:: + + \begin{aligned} + &\frac{d g}{d t}=-\frac{g}{\tau_{\mathrm{decay}}}+h \\ + &\frac{d h}{d t}=-\frac{h}{\tau_{\text {rise }}}+ \delta\left(t_{0}-t\right), + \end{aligned} + + .. [1] Sterratt, David, Bruce Graham, Andrew Gillies, and David Willshaw. + "The Synapse." Principles of Computational Modelling in Neuroscience. + Cambridge: Cambridge UP, 2011. 172-95. Print. + .. [2] Roth, A., & Van Rossum, M. C. W. (2009). Modeling Synapses. Computational + Modeling Methods for Neuroscientists. + + Args: + tau_decay: float, ArrayArray, Callable. The time constant of the synaptic decay phase. [ms] + tau_rise: float, ArrayArray, Callable. The time constant of the synaptic rise phase. 
[ms] + %s + """ + def __init__( + self, + size: Union[int, Sequence[int]], + keep_size: bool = False, + axis_names: Optional[Sequence[str]] = (NEU_AXIS,), + method: str = 'exp_auto', + name: Optional[str] = None, + mode: Optional[bm.Mode] = None, + + # synapse parameters + tau_decay: Union[float, ArrayType, Callable] = 10.0, + tau_rise: Union[float, ArrayType, Callable] = 1., + ): + super(DualExponential, self).__init__(name=name, + mode=mode, + size=size, + keep_size=keep_size, + axis_names=axis_names) + + # parameters + self.tau_rise = self.sharding_param(tau_rise) + self.tau_decay = self.sharding_param(tau_decay) + + # integrator + self.integral = odeint(JointEq(self.dg, self.dh), method=method) + + self.reset_state(self.mode) + + def reset_state(self, batch_size=None): + self.h = self.sharding_variable(bm.zeros, batch_size) + self.g = self.sharding_variable(bm.zeros, batch_size) + + def dh(self, h, t): + return -h / self.tau_rise + + def dg(self, g, t, h): + return -g / self.tau_decay + h + + def update(self, x): + t = share.load('t') + dt = share.load('dt') + + # update synaptic variables + self.g.value, self.h.value = self.integral(self.g.value, self.h.value, t, dt=dt) + self.h += x + return self.g.value + + +DualExponential.__doc__ = DualExponential.__doc__ % (pneu_doc,) + + +class Alpha(DualExponential): + r"""Alpha synapse model. + + **Model Descriptions** + + The analytical expression of alpha synapse is given by: + + .. math:: + + g_{syn}(t)= g_{max} \frac{t-t_{s}}{\tau} \exp \left(-\frac{t-t_{s}}{\tau}\right). + + While, this equation is hard to implement. So, let's try to convert it into the + differential forms: + + .. math:: + + \begin{aligned} + &\frac{d g}{d t}=-\frac{g}{\tau}+h \\ + &\frac{d h}{d t}=-\frac{h}{\tau}+\delta\left(t_{0}-t\right) + \end{aligned} + + **Model Examples** + + .. plot:: + :include-source: True + + >>> import brainpy as bp + >>> from brainpy import neurons, synapses, synouts + >>> import matplotlib.pyplot as plt + >>> + >>> neu1 = neurons.LIF(1) + >>> neu2 = neurons.LIF(1) + >>> syn1 = synapses.Alpha(neu1, neu2, bp.connect.All2All(), output=synouts.CUBA()) + >>> net = bp.Network(pre=neu1, syn=syn1, post=neu2) + >>> + >>> runner = bp.DSRunner(net, inputs=[('pre.input', 25.)], monitors=['pre.V', 'post.V', 'syn.g', 'syn.h']) + >>> runner.run(150.) + >>> + >>> fig, gs = bp.visualize.get_figure(2, 1, 3, 8) + >>> fig.add_subplot(gs[0, 0]) + >>> plt.plot(runner.mon.ts, runner.mon['pre.V'], label='pre-V') + >>> plt.plot(runner.mon.ts, runner.mon['post.V'], label='post-V') + >>> plt.legend() + >>> fig.add_subplot(gs[1, 0]) + >>> plt.plot(runner.mon.ts, runner.mon['syn.g'], label='g') + >>> plt.plot(runner.mon.ts, runner.mon['syn.h'], label='h') + >>> plt.legend() + >>> plt.show() + + .. [1] Sterratt, David, Bruce Graham, Andrew Gillies, and David Willshaw. + "The Synapse." Principles of Computational Modelling in Neuroscience. + Cambridge: Cambridge UP, 2011. 172-95. Print. + + Args: + tau_decay: float, ArrayType, Callable. The time constant [ms] of the synaptic decay phase. + The name of this synaptic projection. 
+ %s + """ + + def __init__( + self, + size: Union[int, Sequence[int]], + keep_size: bool = False, + axis_names: Optional[Sequence[str]] = (NEU_AXIS,), + method: str = 'exp_auto', + name: Optional[str] = None, + mode: Optional[bm.Mode] = None, + + # synapse parameters + tau_decay: Union[float, ArrayType, Callable] = 10.0, + ): + super().__init__( + tau_decay=tau_decay, + tau_rise=tau_decay, + method=method, + name=name, + mode=mode, + size=size, + keep_size=keep_size, + axis_names=axis_names + ) + + +Alpha.__doc__ = Alpha.__doc__ % (pneu_doc,) + + +class NMDA(PSynDyn): + r"""NMDA synapse model. + + **Model Descriptions** + + The NMDA receptor is a glutamate receptor and ion channel found in neurons. + The NMDA receptor is one of three types of ionotropic glutamate receptors, + the other two being AMPA and kainate receptors. + + The NMDA receptor mediated conductance depends on the postsynaptic voltage. + The voltage dependence is due to the blocking of the pore of the NMDA receptor + from the outside by a positively charged magnesium ion. The channel is + nearly completely blocked at resting potential, but the magnesium block is + relieved if the cell is depolarized. The fraction of channels :math:`g_{\infty}` + that are not blocked by magnesium can be fitted to + + .. math:: + + g_{\infty}(V,[{Mg}^{2+}]_{o}) = (1+{e}^{-\alpha V} + \frac{[{Mg}^{2+}]_{o}} {\beta})^{-1} + + Here :math:`[{Mg}^{2+}]_{o}` is the extracellular magnesium concentration, + usually 1 mM. Thus, the channel acts as a + "coincidence detector" and only once both of these conditions are met, the + channel opens and it allows positively charged ions (cations) to flow through + the cell membrane [2]_. + + If we make the approximation that the magnesium block changes + instantaneously with voltage and is independent of the gating of the channel, + the net NMDA receptor-mediated synaptic current is given by + + .. math:: + + I_{syn} = g_\mathrm{NMDA}(t) (V(t)-E) \cdot g_{\infty} + + where :math:`V(t)` is the post-synaptic neuron potential, :math:`E` is the + reversal potential. + + Simultaneously, the kinetics of synaptic state :math:`g` is given by + + .. math:: + + & g_\mathrm{NMDA} (t) = g_{max} g \\ + & \frac{d g}{dt} = -\frac{g} {\tau_{decay}}+a x(1-g) \\ + & \frac{d x}{dt} = -\frac{x}{\tau_{rise}}+ \sum_{k} \delta(t-t_{j}^{k}) + + where the decay time of NMDA currents is usually taken to be + :math:`\tau_{decay}` =100 ms, :math:`a= 0.5 ms^{-1}`, and :math:`\tau_{rise}` =2 ms. + + The NMDA receptor has been thought to be very important for controlling + synaptic plasticity and mediating learning and memory functions [3]_. + + + **Model Examples** + + - `(Wang, 2002) Decision making spiking model `_ + + + .. plot:: + :include-source: True + + >>> import brainpy as bp + >>> from brainpy import synapses, neurons + >>> import matplotlib.pyplot as plt + >>> + >>> neu1 = neurons.HH(1) + >>> neu2 = neurons.HH(1) + >>> syn1 = synapses.NMDA(neu1, neu2, bp.connect.All2All(), E=0.) + >>> net = bp.Network(pre=neu1, syn=syn1, post=neu2) + >>> + >>> runner = bp.DSRunner(net, inputs=[('pre.input', 5.)], monitors=['pre.V', 'post.V', 'syn.g', 'syn.x']) + >>> runner.run(150.) 
+ >>> + >>> fig, gs = bp.visualize.get_figure(2, 1, 3, 8) + >>> fig.add_subplot(gs[0, 0]) + >>> plt.plot(runner.mon.ts, runner.mon['pre.V'], label='pre-V') + >>> plt.plot(runner.mon.ts, runner.mon['post.V'], label='post-V') + >>> plt.legend() + >>> + >>> fig.add_subplot(gs[1, 0]) + >>> plt.plot(runner.mon.ts, runner.mon['syn.g'], label='g') + >>> plt.plot(runner.mon.ts, runner.mon['syn.x'], label='x') + >>> plt.legend() + >>> plt.show() + + .. [1] Brunel N, Wang X J. Effects of neuromodulation in a + cortical network model of object working memory dominated + by recurrent inhibition[J]. + Journal of computational neuroscience, 2001, 11(1): 63-85. + .. [2] Furukawa, Hiroyasu, Satinder K. Singh, Romina Mancusso, and + Eric Gouaux. "Subunit arrangement and function in NMDA receptors." + Nature 438, no. 7065 (2005): 185-192. + .. [3] Li, F. and Tsien, J.Z., 2009. Memory and the NMDA receptors. The New + England journal of medicine, 361(3), p.302. + .. [4] https://en.wikipedia.org/wiki/NMDA_receptor + + Args: + tau_decay: float, ArrayType, Callable. The time constant of the synaptic decay phase. Default 100 [ms] + tau_rise: float, ArrayType, Callable. The time constant of the synaptic rise phase. Default 2 [ms] + a: float, ArrayType, Callable. Default 0.5 ms^-1. + %s + """ + + def __init__( + self, + size: Union[int, Sequence[int]], + keep_size: bool = False, + axis_names: Optional[Sequence[str]] = (NEU_AXIS,), + method: str = 'exp_auto', + name: Optional[str] = None, + mode: Optional[bm.Mode] = None, + + # synapse parameters + a: Union[float, ArrayType, Callable] = 0.5, + tau_decay: Union[float, ArrayType, Callable] = 100., + tau_rise: Union[float, ArrayType, Callable] = 2., + ): + super(NMDA, self).__init__(name=name, + mode=mode, + size=size, + keep_size=keep_size, + axis_names=axis_names) + + # parameters + self.tau_decay = self.sharding_param(tau_decay) + self.tau_rise = self.sharding_param(tau_rise) + self.a = self.sharding_param(a) + + # integral + self.integral = odeint(method=method, f=JointEq(self.dg, self.dx)) + + self.reset_state(self.mode) + + def dg(self, g, t, x): + return -g / self.tau_decay + self.a * x * (1 - g) + + def dx(self, x, t): + return -x / self.tau_rise + + def reset_state(self, batch_size=None): + self.g = self.sharding_variable(bm.zeros, batch_size) + self.x = self.sharding_variable(bm.zeros, batch_size) + + def update(self, pre_spike): + t = share.load('t') + dt = share.load('dt') + self.g.value, self.x.value = self.integral(self.g, self.x, t, dt=dt) + self.x += pre_spike + return self.g.value + + +NMDA.__doc__ = NMDA.__doc__ % (pneu_doc,) + + +class STD(PSynDyn): + r"""Synaptic output with short-term depression. + + This model filters the synaptic current by the following equation: + + .. math:: + + I_{syn}^+(t) = I_{syn}^-(t) * x + + where :math:`x` is the normalized variable between 0 and 1, and + :math:`I_{syn}^-(t)` and :math:`I_{syn}^+(t)` are the synaptic currents before + and after STD filtering. + + Moreover, :math:`x` is updated according to the dynamics of: + + .. math:: + + \frac{dx}{dt} = \frac{1-x}{\tau} - U * x * \delta(t-t_{spike}) + + where :math:`U` is the fraction of resources used per action potential, + :math:`\tau` is the time constant of recovery of the synaptic vesicles. + + Args: + tau: float, ArrayType, Callable. The time constant of recovery of the synaptic vesicles. + U: float, ArrayType, Callable. The fraction of resources used per action potential. 
+ %s + """ + def __init__( + self, + size: Union[int, Sequence[int]], + keep_size: bool = False, + axis_names: Optional[Sequence[str]] = (NEU_AXIS,), + method: str = 'exp_auto', + name: Optional[str] = None, + mode: Optional[bm.Mode] = None, + + # synapse parameters + tau: Union[float, ArrayType, Callable] = 200., + U: Union[float, ArrayType, Callable] = 0.07, + ): + super(STD, self).__init__(name=name, + mode=mode, + size=size, + keep_size=keep_size, + axis_names=axis_names) + + # parameters + self.tau = self.sharding_param(tau) + self.U = self.sharding_param(U) + + # integral function + self.integral = odeint(lambda x, t: (1 - x) / self.tau, method=method) + + self.reset_state(self.mode) + + def reset_state(self, batch_size=None): + self.x = self.sharding_variable(bm.ones, batch_size) + + def update(self, pre_spike): + t = share.load('t') + dt = share.load('dt') + x = self.integral(self.x.value, t, dt) + self.x.value = bm.where(pre_spike, x - self.U * self.x, x) + return self.x.value + + +STD.__doc__ = STD.__doc__ % (pneu_doc,) + + +class STP(PSynDyn): + r"""Synaptic output with short-term plasticity. + + This model filters the synaptic currents according to two variables: :math:`u` and :math:`x`. + + .. math:: + + I_{syn}^+(t) = I_{syn}^-(t) * x * u + + where :math:`I_{syn}^-(t)` and :math:`I_{syn}^+(t)` are the synaptic currents before + and after STP filtering, :math:`x` denotes the fraction of resources that remain available + after neurotransmitter depletion, and :math:`u` represents the fraction of available + resources ready for use (release probability). + + The dynamics of :math:`u` and :math:`x` are governed by + + .. math:: + + \begin{aligned} + \frac{du}{dt} & = & -\frac{u}{\tau_f}+U(1-u^-)\delta(t-t_{sp}), \\ + \frac{dx}{dt} & = & \frac{1-x}{\tau_d}-u^+x^-\delta(t-t_{sp}), \\ + \tag{1}\end{aligned} + + where :math:`t_{sp}` denotes the spike time and :math:`U` is the increment + of :math:`u` produced by a spike. :math:`u^-, x^-` are the corresponding + variables just before the arrival of the spike, and :math:`u^+` + refers to the moment just after the spike. + + Args: + tau_f: float, ArrayType, Callable. The time constant of short-term facilitation. + tau_d: float, ArrayType, Callable. The time constant of short-term depression. + U: float, ArrayType, Callable. The fraction of resources used per action potential. 
+ %s + """ + def __init__( + self, + size: Union[int, Sequence[int]], + keep_size: bool = False, + axis_names: Optional[Sequence[str]] = (NEU_AXIS,), + method: str = 'exp_auto', + name: Optional[str] = None, + mode: Optional[bm.Mode] = None, + + # synapse parameters + U: Union[float, ArrayType, Callable] = 0.15, + tau_f: Union[float, ArrayType, Callable] = 1500., + tau_d: Union[float, ArrayType, Callable] = 200., + ): + super(STP, self).__init__(name=name, + mode=mode, + size=size, + keep_size=keep_size, + axis_names=axis_names) + + # parameters + self.tau_f = self.sharding_param(tau_f) + self.tau_d = self.sharding_param(tau_d) + self.U = self.sharding_param(U) + self.method = method + + # integral function + self.integral = odeint(self.derivative, method=self.method) + + self.reset_state(self.mode) + + def reset_state(self, batch_size=None): + self.x = self.sharding_variable(bm.ones, batch_size) + self.u = self.sharding_variable(bm.ones, batch_size) + self.u.fill_(self.U) + + @property + def derivative(self): + du = lambda u, t: self.U - u / self.tau_f + dx = lambda x, t: (1 - x) / self.tau_d + return JointEq([du, dx]) + + def update(self, pre_spike): + t = share.load('x') + dt = share.load('dt') + u, x = self.integral(self.u.value, self.x.value, t, dt) + u = bm.where(pre_spike, u + self.U * (1 - self.u), u) + x = bm.where(pre_spike, x - u * self.x, x) + self.x.value = x + self.u.value = u + return u * x + + +STP.__doc__ = STP.__doc__ % (pneu_doc,) + + +class AMPA(PSynDyn): + r"""AMPA synapse model. + + **Model Descriptions** + + AMPA receptor is an ionotropic receptor, which is an ion channel. + When it is bound by neurotransmitters, it will immediately open the + ion channel, causing the change of membrane potential of postsynaptic neurons. + + A classical model is to use the Markov process to model ion channel switch. + Here :math:`g` represents the probability of channel opening, :math:`1-g` + represents the probability of ion channel closing, and :math:`\alpha` and + :math:`\beta` are the transition probability. Because neurotransmitters can + open ion channels, the transfer probability from :math:`1-g` to :math:`g` + is affected by the concentration of neurotransmitters. We denote the concentration + of neurotransmitters as :math:`[T]` and get the following Markov process. + + .. image:: ../../../_static/synapse_markov.png + :align: center + + We obtained the following formula when describing the process by a differential equation. + + .. math:: + + \frac{ds}{dt} =\alpha[T](1-g)-\beta g + + where :math:`\alpha [T]` denotes the transition probability from state :math:`(1-g)` + to state :math:`(g)`; and :math:`\beta` represents the transition probability of + the other direction. :math:`\alpha` is the binding constant. :math:`\beta` is the + unbinding constant. :math:`[T]` is the neurotransmitter concentration, and + has the duration of 0.5 ms. + + Moreover, the post-synaptic current on the post-synaptic neuron is formulated as + + .. math:: + + I_{syn} = g_{max} g (V-E) + + where :math:`g_{max}` is the maximum conductance, and `E` is the reverse potential. + + **Model Examples** + + + .. 
plot:: + :include-source: True + + >>> import brainpy as bp + >>> from brainpy import neurons, synapses + >>> import matplotlib.pyplot as plt + >>> + >>> neu1 = neurons.HH(1) + >>> neu2 = neurons.HH(1) + >>> syn1 = synapses.AMPA(neu1, neu2, bp.connect.All2All()) + >>> net = bp.Network(pre=neu1, syn=syn1, post=neu2) + >>> + >>> runner = bp.DSRunner(net, inputs=[('pre.input', 5.)], monitors=['pre.V', 'post.V', 'syn.g']) + >>> runner.run(150.) + >>> + >>> fig, gs = bp.visualize.get_figure(2, 1, 3, 8) + >>> fig.add_subplot(gs[0, 0]) + >>> plt.plot(runner.mon.ts, runner.mon['pre.V'], label='pre-V') + >>> plt.plot(runner.mon.ts, runner.mon['post.V'], label='post-V') + >>> plt.legend() + >>> + >>> fig.add_subplot(gs[1, 0]) + >>> plt.plot(runner.mon.ts, runner.mon['syn.g'], label='g') + >>> plt.legend() + >>> plt.show() + + .. [1] Vijayan S, Kopell N J. Thalamic model of awake alpha oscillations + and implications for stimulus processing[J]. Proceedings of the + National Academy of Sciences, 2012, 109(45): 18553-18558. + + Args: + alpha: float, ArrayType, Callable. Binding constant. + beta: float, ArrayType, Callable. Unbinding constant. + T: float, ArrayType, Callable. Transmitter concentration when synapse is triggered by + a pre-synaptic spike.. Default 1 [mM]. + T_dur: float, ArrayType, Callable. Transmitter concentration duration time after being triggered. Default 1 [ms] + %s + """ + + supported_modes = (bm.NonBatchingMode,) + + def __init__( + self, + size: Union[int, Sequence[int]], + keep_size: bool = False, + axis_names: Optional[Sequence[str]] = (NEU_AXIS,), + method: str = 'exp_auto', + name: Optional[str] = None, + mode: Optional[bm.Mode] = None, + + # synapse parameters + alpha: Union[float, ArrayType, Callable] = 0.98, + beta: Union[float, ArrayType, Callable] = 0.18, + T: Union[float, ArrayType, Callable] = 0.5, + T_dur: Union[float, ArrayType, Callable] = 0.5, + ): + super(AMPA, self).__init__(name=name, + mode=mode, + size=size, + keep_size=keep_size, + axis_names=axis_names) + + # parameters + self.alpha = self.sharding_param(alpha) + self.beta = self.sharding_param(beta) + self.T = self.sharding_param(T) + self.T_duration = self.sharding_param(T_dur) + + # functions + self.integral = odeint(method=method, f=self.dg) + + + self.reset_state(self.mode) + + def reset_state(self, batch_size=None): + self.g = self.sharding_variable(bm.zeros, batch_size) + self.spike_arrival_time = self.sharding_variable(bm.ones, batch_size) + self.spike_arrival_time.fill(-1e7) + + def dg(self, g, t, TT): + return self.alpha * TT * (1 - g) - self.beta * g + + def update(self, pre_spike): + t = share.load('t') + dt = share.load('dt') + self.spike_arrival_time.value = bm.where(pre_spike, t, self.spike_arrival_time) + TT = ((t - self.spike_arrival_time) < self.T_duration) * self.T + self.g.value = self.integral(self.g, t, TT, dt) + return self.g.value + + +AMPA.__doc__ = AMPA.__doc__ % (pneu_doc,) + + +class GABAa(AMPA): + r"""GABAa synapse model. + + **Model Descriptions** + + GABAa synapse model has the same equation with the `AMPA synapse <./brainmodels.synapses.AMPA.rst>`_, + + .. math:: + + \frac{d g}{d t}&=\alpha[T](1-g) - \beta g \\ + I_{syn}&= - g_{max} g (V - E) + + but with the difference of: + + - Reversal potential of synapse :math:`E` is usually low, typically -80. 
mV + - Activating rate constant :math:`\alpha=0.53` + - De-activating rate constant :math:`\beta=0.18` + - Transmitter concentration :math:`[T]=1\,\mu ho(\mu S)` when synapse is + triggered by a pre-synaptic spike, with the duration of 1. ms. + + + .. [1] Destexhe, Alain, and Denis Paré. "Impact of network activity + on the integrative properties of neocortical pyramidal neurons + in vivo." Journal of neurophysiology 81.4 (1999): 1531-1547. + + Args: + alpha: float, ArrayType, Callable. Binding constant. Default 0.062 + beta: float, ArrayType, Callable. Unbinding constant. Default 3.57 + T: float, ArrayType, Callable. Transmitter concentration when synapse is triggered by + a pre-synaptic spike.. Default 1 [mM]. + T_dur: float, ArrayType, Callable. Transmitter concentration duration time + after being triggered. Default 1 [ms] + %s + """ + def __init__( + self, + size: Union[int, Sequence[int]], + keep_size: bool = False, + axis_names: Optional[Sequence[str]] = (NEU_AXIS,), + method: str = 'exp_auto', + name: Optional[str] = None, + mode: Optional[bm.Mode] = None, + + # synapse parameters + alpha: Union[float, ArrayType, Callable] = 0.53, + beta: Union[float, ArrayType, Callable] = 0.18, + T: Union[float, ArrayType, Callable] = 1., + T_dur: Union[float, ArrayType, Callable] = 1., + ): + super(GABAa, self).__init__(alpha=alpha, + beta=beta, + T=T, + T_dur=T_dur, + method=method, + name=name, + mode=mode, + size=size, + keep_size=keep_size, + axis_names=axis_names) + + +GABAa.__doc__ = GABAa.__doc__ % (pneu_doc,) + + +class BioNMDA(PSynDyn): + r"""Biological NMDA synapse model. + + **Model Descriptions** + + The NMDA receptor is a glutamate receptor and ion channel found in neurons. + The NMDA receptor is one of three types of ionotropic glutamate receptors, + the other two being AMPA and kainate receptors. + + The NMDA receptor mediated conductance depends on the postsynaptic voltage. + The voltage dependence is due to the blocking of the pore of the NMDA receptor + from the outside by a positively charged magnesium ion. The channel is + nearly completely blocked at resting potential, but the magnesium block is + relieved if the cell is depolarized. The fraction of channels :math:`g_{\infty}` + that are not blocked by magnesium can be fitted to + + .. math:: + + g_{\infty}(V,[{Mg}^{2+}]_{o}) = (1+{e}^{-a V} + \frac{[{Mg}^{2+}]_{o}} {b})^{-1} + + Here :math:`[{Mg}^{2+}]_{o}` is the extracellular magnesium concentration, + usually 1 mM. Thus, the channel acts as a + "coincidence detector" and only once both of these conditions are met, the + channel opens and it allows positively charged ions (cations) to flow through + the cell membrane [2]_. + + If we make the approximation that the magnesium block changes + instantaneously with voltage and is independent of the gating of the channel, + the net NMDA receptor-mediated synaptic current is given by + + .. math:: + + I_{syn} = g_\mathrm{NMDA}(t) (V(t)-E) \cdot g_{\infty} + + where :math:`V(t)` is the post-synaptic neuron potential, :math:`E` is the + reversal potential. + + Simultaneously, the kinetics of synaptic state :math:`g` is determined by a 2nd-order kinetics [1]_: + + .. math:: + + & \frac{d g}{dt} = \alpha_1 x (1 - g) - \beta_1 g \\ + & \frac{d x}{dt} = \alpha_2 [T] (1 - x) - \beta_2 x + + where :math:`\alpha_1, \beta_1` refers to the conversion rate of variable g and + :math:`\alpha_2, \beta_2` refers to the conversion rate of variable x. 
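+
+  As a quick numerical sketch (illustrative only, using the
+  :math:`\alpha=0.062`, :math:`\beta=3.57` and :math:`[{Mg}^{2+}]_{o}=1.2` mM
+  defaults of the ``MgBlock`` synaptic output defined in ``syn_output.py`` in
+  this patch), the magnesium block is nearly complete at the resting potential
+  and is largely relieved once the membrane depolarizes:
+
+  >>> import numpy as np
+  >>> g_inf = lambda V, Mg=1.2, a=0.062, b=3.57: 1 / (1 + np.exp(-a * V) * Mg / b)
+  >>> round(g_inf(-65.), 3)  # near rest: only ~5% of channels are unblocked
+  0.05
+  >>> round(g_inf(-20.), 3)  # depolarized: roughly half of the block is relieved
+  0.463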
+ + The NMDA receptor has been thought to be very important for controlling + synaptic plasticity and mediating learning and memory functions [3]_. + + .. plot:: + :include-source: True + + >>> import brainpy as bp + >>> from brainpy import neurons, synapses + >>> import matplotlib.pyplot as plt + >>> + >>> neu1 = neurons.HH(1) + >>> neu2 = neurons.HH(1) + >>> syn1 = synapses.BioNMDA(neu1, neu2, bp.connect.All2All()) + >>> net = bp.Network(pre=neu1, syn=syn1, post=neu2) + >>> + >>> runner = bp.DSRunner(net, inputs=[('pre.input', 5.)], monitors=['pre.V', 'post.V', 'syn.g', 'syn.x']) + >>> runner.run(150.) + >>> + >>> fig, gs = bp.visualize.get_figure(2, 1, 3, 8) + >>> fig.add_subplot(gs[0, 0]) + >>> plt.plot(runner.mon.ts, runner.mon['pre.V'], label='pre-V') + >>> plt.plot(runner.mon.ts, runner.mon['post.V'], label='post-V') + >>> plt.legend() + >>> + >>> fig.add_subplot(gs[1, 0]) + >>> plt.plot(runner.mon.ts, runner.mon['syn.g'], label='g') + >>> plt.plot(runner.mon.ts, runner.mon['syn.x'], label='x') + >>> plt.legend() + >>> plt.show() + + .. [1] Devaney A J . Mathematical Foundations of Neuroscience[M]. + Springer New York, 2010: 162. + .. [2] Furukawa, Hiroyasu, Satinder K. Singh, Romina Mancusso, and + Eric Gouaux. "Subunit arrangement and function in NMDA receptors." + Nature 438, no. 7065 (2005): 185-192. + .. [3] Li, F. and Tsien, J.Z., 2009. Memory and the NMDA receptors. The New + England journal of medicine, 361(3), p.302. + .. [4] https://en.wikipedia.org/wiki/NMDA_receptor + + + Args: + alpha1: float, ArrayType, Callable. The conversion rate of g from inactive to active. Default 2 ms^-1. + beta1: float, ArrayType, Callable. The conversion rate of g from active to inactive. Default 0.01 ms^-1. + alpha2: float, ArrayType, Callable. The conversion rate of x from inactive to active. Default 1 ms^-1. + beta2: float, ArrayType, Callable. The conversion rate of x from active to inactive. Default 0.5 ms^-1. + T: float, ArrayType, Callable. Transmitter concentration when synapse is triggered by + a pre-synaptic spike.. Default 1 [mM]. + T_dur: float, ArrayType, Callable. Transmitter concentration duration time after being triggered. 
Default 1 [ms] + %s + """ + supported_modes = (bm.NonBatchingMode,) + + def __init__( + self, + size: Union[int, Sequence[int]], + keep_size: bool = False, + axis_names: Optional[Sequence[str]] = (NEU_AXIS,), + method: str = 'exp_auto', + name: Optional[str] = None, + mode: Optional[bm.Mode] = None, + + # synapse parameters + alpha1: Union[float, ArrayType, Callable] = 2., + beta1: Union[float, ArrayType, Callable] = 0.01, + alpha2: Union[float, ArrayType, Callable] = 1., + beta2: Union[float, ArrayType, Callable] = 0.5, + T: Union[float, ArrayType, Callable] = 1., + T_dur: Union[float, ArrayType, Callable] = 0.5, + ): + super(BioNMDA, self).__init__(name=name, + mode=mode, + size=size, + keep_size=keep_size, + axis_names=axis_names) + + # parameters + self.beta1 = self.sharding_param(beta1) + self.beta2 = self.sharding_param(beta2) + self.alpha1 = self.sharding_param(alpha1) + self.alpha2 = self.sharding_param(alpha2) + self.T = self.sharding_param(T) + self.T_dur = self.sharding_param(T_dur) + + # integral + self.integral = odeint(method=method, f=JointEq([self.dg, self.dx])) + + self.reset_state(self.mode) + + def reset_state(self, batch_size=None): + self.g = self.sharding_variable(bm.zeros, batch_size) + self.x = self.sharding_variable(bm.zeros, batch_size) + self.spike_arrival_time = self.sharding_variable(bm.ones, batch_size) + self.spike_arrival_time.fill(-1e7) + + def dg(self, g, t, x): + return self.alpha1 * x * (1 - g) - self.beta1 * g + + def dx(self, x, t, T): + return self.alpha2 * T * (1 - x) - self.beta2 * x + + def update(self, pre_spike): + t = share.load('t') + dt = share.load('dt') + self.spike_arrival_time.value = bm.where(pre_spike, t, self.spike_arrival_time) + T = ((t - self.spike_arrival_time) < self.T_dur) * self.T + self.g.value, self.x.value = self.integral(self.g, self.x, t, T, dt) + return self.g.value + + +BioNMDA.__doc__ = BioNMDA.__doc__ % (pneu_doc,) diff --git a/brainpy/_src/pnn/synapses/syn_output.py b/brainpy/_src/pnn/synapses/syn_output.py new file mode 100644 index 000000000..f381b64b2 --- /dev/null +++ b/brainpy/_src/pnn/synapses/syn_output.py @@ -0,0 +1,159 @@ + +from typing import Union, Optional, Sequence + +import numpy as np +from brainpy import math as bm, initialize as init +from brainpy._src.dynsys import DynamicalSystemNS +from brainpy.types import ArrayType + +from brainpy._src.pnn.mixin import ParamDesc +from brainpy._src.pnn.utils import NEU_AXIS + +__all__ = [ + 'PSynOut', + 'COBA', + 'CUBA', + 'MgBlock' +] + + +class PSynOut(DynamicalSystemNS, ParamDesc): + def __init__( + self, + name: Optional[str] = None, + ): + super().__init__(name=name) + self._conductance = None + + def bind_cond(self, conductance): + self._conductance = conductance + + def unbind_cond(self): + self._conductance = None + + def __call__(self, *args, **kwargs): + if self._conductance is None: + raise ValueError(f'Please first pack data at the current step using ' + f'".bind_cond(data)". {self}') + ret = super().__call__(self._conductance, *args, **kwargs) + return ret + + +class COBA(PSynOut): + r"""Conductance-based synaptic output. + + Given the synaptic conductance, the model output the post-synaptic current with + + .. math:: + + I_{syn}(t) = g_{\mathrm{syn}}(t) (E - V(t)) + + Parameters + ---------- + E: float, ArrayType, ndarray + The reversal potential. + name: str + The model name. 
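+
+  A minimal numerical sketch of the sign convention (illustrative only;
+  ``update`` is called directly here, whereas in a full model the conductance
+  and the membrane potential are supplied by the synaptic dynamics and the
+  post-synaptic group):
+
+  >>> exc = COBA(E=0.)               # excitatory reversal potential
+  >>> I_exc = exc.update(1., -60.)   # 1 * (0 - (-60))   = +60  -> depolarizing
+  >>> inh = COBA(E=-80.)             # inhibitory reversal potential
+  >>> I_inh = inh.update(1., -60.)   # 1 * (-80 - (-60)) = -20  -> hyperpolarizing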
+ + See Also + -------- + CUBA + """ + + def __init__( + self, + E: Union[float, ArrayType] = 0., + axis_names: Optional[Sequence[str]] = (NEU_AXIS, ), + name: Optional[str] = None, + ): + super().__init__(name=name) + + self.axis_names = axis_names + self.E = init.parameter(E, np.shape(E), axis_names=axis_names) + + def update(self, conductance, potential): + return conductance * (self.E - potential) + + +class CUBA(PSynOut): + r"""Current-based synaptic output. + + Given the conductance, this model outputs the post-synaptic current with a identity function: + + .. math:: + + I_{\mathrm{syn}}(t) = g_{\mathrm{syn}}(t) + + Parameters + ---------- + name: str + The model name. + + + See Also + -------- + COBA + """ + + def __init__( + self, + name: Optional[str] = None, + ): + super().__init__(name=name) + + def update(self, conductance, potential=None): + return conductance + + +class MgBlock(PSynOut): + r"""Synaptic output based on Magnesium blocking. + + Given the synaptic conductance, the model output the post-synaptic current with + + .. math:: + + I_{syn}(t) = g_{\mathrm{syn}}(t) (E - V(t)) g_{\infty}(V,[{Mg}^{2+}]_{o}) + + where The fraction of channels :math:`g_{\infty}` that are not blocked by magnesium can be fitted to + + .. math:: + + g_{\infty}(V,[{Mg}^{2+}]_{o}) = (1+{e}^{-\alpha V} \frac{[{Mg}^{2+}]_{o}} {\beta})^{-1} + + Here :math:`[{Mg}^{2+}]_{o}` is the extracellular magnesium concentration. + + Parameters + ---------- + E: float, ArrayType + The reversal potential for the synaptic current. [mV] + alpha: float, ArrayType + Binding constant. Default 0.062 + beta: float, ArrayType + Unbinding constant. Default 3.57 + cc_Mg: float, ArrayType + Concentration of Magnesium ion. Default 1.2 [mM]. + name: str + The model name. + """ + def __init__( + self, + E: Union[float, ArrayType] = 0., + cc_Mg: Union[float, ArrayType] = 1.2, + alpha: Union[float, ArrayType] = 0.062, + beta: Union[float, ArrayType] = 3.57, + axis_names: Optional[Sequence[str]] = (NEU_AXIS,), + name: Optional[str] = None, + ): + super().__init__(name=name) + + self.axis_names = axis_names + self.E = init.parameter(E, np.shape(E), axis_names=axis_names) + self.cc_Mg = init.parameter(cc_Mg, np.shape(cc_Mg), axis_names=axis_names) + self.alpha = init.parameter(alpha, np.shape(alpha), axis_names=axis_names) + self.beta = init.parameter(alpha, np.shape(beta), axis_names=axis_names) + + def update(self, conductance, potential): + return conductance * (self.E - potential) / (1 + self.cc_Mg / self.beta * bm.exp(-self.alpha * potential)) + + + diff --git a/brainpy/_src/pnn/utils/__init__.py b/brainpy/_src/pnn/utils/__init__.py new file mode 100644 index 000000000..3362213b5 --- /dev/null +++ b/brainpy/_src/pnn/utils/__init__.py @@ -0,0 +1,6 @@ + + +from .axis_names import * +from .init import * + + diff --git a/brainpy/_src/pnn/utils/axis_names.py b/brainpy/_src/pnn/utils/axis_names.py new file mode 100644 index 000000000..460d7b712 --- /dev/null +++ b/brainpy/_src/pnn/utils/axis_names.py @@ -0,0 +1,9 @@ + +NEU_AXIS = 'neuron' +PRE_AXIS = 'pre' +POST_AXIS = 'post' +SYN_AXIS = 'synapse' +BATCH_AXIS = 'batch' +TIME_AXIS = 'time' + + diff --git a/brainpy/_src/pnn/utils/axis_rules.py b/brainpy/_src/pnn/utils/axis_rules.py new file mode 100644 index 000000000..5d5a6f78a --- /dev/null +++ b/brainpy/_src/pnn/utils/axis_rules.py @@ -0,0 +1,8 @@ +from .axis_names import NEU_AXIS, POST_AXIS, BATCH_AXIS + +AXIS_RULE = { + NEU_AXIS: 'N', + POST_AXIS: 'N', + BATCH_AXIS: 'B', +} + diff --git a/brainpy/_src/pnn/utils/init.py 
b/brainpy/_src/pnn/utils/init.py new file mode 100644 index 000000000..f1001805a --- /dev/null +++ b/brainpy/_src/pnn/utils/init.py @@ -0,0 +1,30 @@ +__all__ = [ + 'DelayedInit', +] + + +class DelayedInit(object): + """Delayed initialization. + """ + + def __init__( + self, + cls: type, + identifier, + *args, + **kwargs + ): + self.cls = cls + self.args = args + self.kwargs = kwargs + self._identifier = identifier + + def __call__(self, *args, **kwargs): + return self.cls(*self.args, *args, **self.kwargs, **kwargs) + + def init(self, *args, **kwargs): + return self.__call__(*args, **kwargs) + + @classmethod + def __class_getitem__(cls, item): + return cls diff --git a/brainpy/_src/psnn/__init__.py b/brainpy/_src/psnn/__init__.py deleted file mode 100644 index a22576256..000000000 --- a/brainpy/_src/psnn/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -# -# The module for ``Large-scale spiking neural networks`` -# - - diff --git a/brainpy/_src/rates/populations.py b/brainpy/_src/rates/populations.py index 85e927707..6b05cf212 100644 --- a/brainpy/_src/rates/populations.py +++ b/brainpy/_src/rates/populations.py @@ -5,7 +5,7 @@ from brainpy import math as bm from brainpy._src.context import share from brainpy._src.dynsys import NeuGroupNS -from brainpy._src.dyn.neurons.noise_groups import OUProcess +from brainpy._src.neurons.noise_groups import OUProcess from brainpy._src.initialize import (Initializer, Uniform, parameter, @@ -1026,11 +1026,8 @@ def __init__( if self.input_var: self.Ie = variable(bm.zeros, self.mode, self.varshape) # Input of excitaory population self.Ii = variable(bm.zeros, self.mode, self.varshape) # Input of inhibitory population - if bm.any(self.noise_e != 0) or bm.any(self.noise_i != 0): - self.rng = bm.random.default_rng(seed) def reset(self, batch_size=None): - self.rng.seed(self.seed) self.reset_state(batch_size) def reset_state(self, batch_size=None): @@ -1057,13 +1054,13 @@ def update(self, x1=None, x2=None): de = -self.e + self.beta_e * bm.maximum(input_e, 0.) if bm.any(self.noise_e != 0.): - de += self.rng.randn(self.varshape) * self.noise_e + de += bm.random.randn(self.varshape) * self.noise_e de = de / self.tau_e self.e.value = bm.maximum(self.e + de * dt, 0.) di = -self.i + self.beta_i * bm.maximum(input_i, 0.) if bm.any(self.noise_i != 0.): - di += self.rng.randn(self.varshape) * self.noise_i + di += bm.random.randn(self.varshape) * self.noise_i di = di / self.tau_i self.i.value = bm.maximum(self.i + di * dt, 0.) 
return self.e.value diff --git a/brainpy/_src/runners.py b/brainpy/_src/runners.py index ff56729f2..91c898701 100644 --- a/brainpy/_src/runners.py +++ b/brainpy/_src/runners.py @@ -21,6 +21,7 @@ from brainpy.errors import RunningError from brainpy.types import ArrayType, Output, Monitor + __all__ = [ 'DSRunner', ] diff --git a/brainpy/_src/synapses_v2/abstract_synapses.py b/brainpy/_src/synapses_v2/abstract_synapses.py index 249e091e0..16783f18e 100644 --- a/brainpy/_src/synapses_v2/abstract_synapses.py +++ b/brainpy/_src/synapses_v2/abstract_synapses.py @@ -120,21 +120,22 @@ def update(self, pre_spike, post_v=None): else: if self.comp_method == 'sparse': if self.stp is None: - f = lambda s: bm.event_csr_matvec(self.g_max, - self.conn_mask[0], - self.conn_mask[1], - s, - shape=(self.pre_num, self.post_num), - transpose=True) + f = lambda s: bm.event.csrmv(self.g_max, + self.conn_mask[0], + self.conn_mask[1], + s, + shape=(self.pre_num, self.post_num), + transpose=True) if isinstance(self.mode, bm.BatchingMode): f = vmap(f) else: - f = lambda s: bm.cusparse_csr_matvec(self.g_max, - self.conn_mask[0], - self.conn_mask[1], - s, - shape=(self.pre_num, self.post_num), - transpose=True) + f = lambda s: bm.sparse.csrmv(self.g_max, + self.conn_mask[0], + self.conn_mask[1], + s, + shape=(self.pre_num, self.post_num), + transpose=True, + method='cusparse') if isinstance(self.mode, bm.BatchingMode): f = vmap(f) post_vs = f(pre_spike) @@ -275,13 +276,14 @@ def update(self, pre_spike, post_v=None): post_vs = self._syn2post_with_one2one(syn_value, self.g_max) else: if self.comp_method == 'sparse': - f = lambda s: bm.cusparse_csr_matvec( + f = lambda s: bm.sparse.csrmv( self.g_max, self.conn_mask[0], self.conn_mask[1], s, shape=(self.conn.pre_num, self.conn.post_num), - transpose=True + transpose=True, + method='cusparse' ) if isinstance(self.mode, bm.BatchingMode): f = vmap(f) @@ -396,4 +398,3 @@ def __init__( stp=stp, name=name, mode=mode) - diff --git a/brainpy/_src/synapses_v2/syn_outs.py b/brainpy/_src/synapses_v2/syn_outs.py index 727e878a1..5492513da 100644 --- a/brainpy/_src/synapses_v2/syn_outs.py +++ b/brainpy/_src/synapses_v2/syn_outs.py @@ -2,7 +2,7 @@ from typing import Union -from brainpy._src.dyn.synapses_v2.base import SynOutNS +from brainpy._src.synapses_v2.base import SynOutNS from brainpy.math import exp from brainpy.types import ArrayType diff --git a/brainpy/_src/synapses_v2/syn_plasticity.py b/brainpy/_src/synapses_v2/syn_plasticity.py index 490b41286..384dbafef 100644 --- a/brainpy/_src/synapses_v2/syn_plasticity.py +++ b/brainpy/_src/synapses_v2/syn_plasticity.py @@ -6,7 +6,7 @@ from brainpy._src.context import share from brainpy import math as bm, tools -from brainpy._src.dyn.synapses_v2.base import SynSTPNS +from brainpy._src.synapses_v2.base import SynSTPNS from brainpy._src.initialize import variable_, OneInit, parameter from brainpy._src.integrators import odeint, JointEq from brainpy.types import ArrayType, Shape diff --git a/brainpy/_src/tests/test_access_methods.py b/brainpy/_src/tests/test_access_methods.py new file mode 100644 index 000000000..1e361ffbd --- /dev/null +++ b/brainpy/_src/tests/test_access_methods.py @@ -0,0 +1,123 @@ +# -*- coding: utf-8 -*- + +import jax.numpy as jnp +import brainpy as bp + +bp.ode.set_default_odeint('rk4') + + +class GABAa(bp.TwoEndConn): + def __init__(self, pre, post, conn, delay=0., g_max=0.1, E=-75., + alpha=12., beta=0.1, T=1.0, T_duration=1.0, **kwargs): + super(GABAa, self).__init__(pre=pre, post=post, conn=conn, **kwargs) + + # 
parameters + self.g_max = g_max + self.E = E + self.alpha = alpha + self.beta = beta + self.T = T + self.T_duration = T_duration + self.delay = delay + + # connections + self.conn_mat = self.conn.requires('conn_mat') + self.size = jnp.shape(self.conn_mat) + + # variables + self.t_last_pre_spike = jnp.ones(self.size) * -1e7 + self.s = jnp.zeros(self.size) + + self.int_s = bp.odeint(self.dev) + + def dev(self, s, t, TT, alpha, beta): + return alpha * TT * (1 - s) - beta * s + + def update(self, t, dt, **kwargs): + spike = jnp.reshape(self.pre.spikes, (self.pre.num, 1)) * self.conn_mat + self.t_last_pre_spike = jnp.where(spike, t, self.t_last_pre_spike) + TT = ((t - self.t_last_pre_spike) < self.T_duration) * self.T + self.s = self.int_s(self.s, t, TT, self.alpha, self.beta) + self.post.inputs -= jnp.sum(self.g_max * self.s, axis=0) * (self.post.V - self.E) + + +class HH(bp.dyn.NeuGroup): + def __init__(self, size, ENa=55., EK=-90., EL=-65, + C=1.0, gNa=35., gK=9., gL=0.1, V_th=20., + phi=5.0, **kwargs): + super(HH, self).__init__(size=size, **kwargs) + + # parameters + self.ENa = ENa + self.EK = EK + self.EL = EL + self.C = C + self.gNa = gNa + self.gK = gK + self.gL = gL + self.V_th = V_th + self.phi = phi + + # variables + self.V = jnp.ones(self.num) * -65. + self.h = jnp.ones(self.num) * 0.6 + self.n = jnp.ones(self.num) * 0.32 + self.spikes = jnp.zeros(self.num) + self.inputs = jnp.zeros(self.num) + + self.integral = bp.odeint(self.dev) + + def dev(self, V, h, n, t, Iext): + alpha = 0.07 * jnp.exp(-(V + 58) / 20) + beta = 1 / (jnp.exp(-0.1 * (V + 28)) + 1) + dhdt = alpha * (1 - h) - beta * h + + alpha = -0.01 * (V + 34) / (jnp.exp(-0.1 * (V + 34)) - 1) + beta = 0.125 * jnp.exp(-(V + 44) / 80) + dndt = alpha * (1 - n) - beta * n + + m_alpha = -0.1 * (V + 35) / (jnp.exp(-0.1 * (V + 35)) - 1) + m_beta = 4 * jnp.exp(-(V + 60) / 18) + m = m_alpha / (m_alpha + m_beta) + INa = self.gNa * m ** 3 * h * (V - self.ENa) + IK = self.gK * n ** 4 * (V - self.EK) + IL = self.gL * (V - self.EL) + dVdt = (- INa - IK - IL + Iext) / self.C + + return dVdt, self.phi * dhdt, self.phi * dndt + + def update(self, t, _i, **kwargs): + V, h, n = self.integral(self.V, self.h, self.n, t, self.inputs) + self.spikes[:] = (self.V < self.V_th) * (V >= self.V_th) + self.V[:] = V + self.h[:] = h + self.n[:] = n + self.inputs[:] = 0 + + +def test1(): + bp.math.random.seed(123) + num = 10 + neu = HH(num) + neu.V = -70. 
+ bp.math.random.normal(size=num) * 20 + + syn = GABAa(pre=neu, post=neu, conn=bp.connect.All2All(include_self=False)) + syn.g_max = 0.1 / num + + net = bp.Network(neu=neu, syn=syn) + + for method in ['relative', 'absolute']: + print(f'Method: {method}\n') + print('vars:') + print('-----') + print('neu.vars()', list(neu.vars(method).keys())) + print('syn.vars()', list(syn.vars(method).keys())) + print('net.vars()', list(net.vars(method).keys())) + print() + + print('nodes:') + print('------') + print('neu.nodes()', list(neu.nodes(method).keys())) + print('syn.nodes()', list(syn.nodes(method).keys())) + print('net.nodes()', list(net.nodes(method).keys())) + print() diff --git a/brainpy/_src/tests/test_base_classes.py b/brainpy/_src/tests/test_base_classes.py new file mode 100644 index 000000000..9c095a30e --- /dev/null +++ b/brainpy/_src/tests/test_base_classes.py @@ -0,0 +1,20 @@ +# -*- coding: utf-8 -*- + +import unittest + +import brainpy as bp + + +class TestDynamicalSystem(unittest.TestCase): + def test_delay(self): + A = bp.neurons.LIF(1) + B = bp.neurons.LIF(1) + C = bp.neurons.LIF(1) + A2B = bp.synapses.Exponential(A, B, bp.conn.All2All(), delay_step=1) + A2C = bp.synapses.Exponential(A, C, bp.conn.All2All(), delay_step=None) + net = bp.Network(A, B, C, A2B, A2C) + + runner = bp.DSRunner(net,) + runner.run(10.) + + diff --git a/brainpy/_src/tests/test_check.py b/brainpy/_src/tests/test_check.py new file mode 100644 index 000000000..a04105486 --- /dev/null +++ b/brainpy/_src/tests/test_check.py @@ -0,0 +1,51 @@ +# -*- coding: utf-8 -*- + + +import unittest + +from brainpy import check as checking + + +class TestUtils(unittest.TestCase): + def test_check_shape(self): + all_shapes = [ + (1, 2, 3), + (1, 4), + (10, 2, 4) + ] + free_shape, fixed_shapes = checking.check_shape(all_shapes, free_axes=-1) + self.assertEqual(free_shape, [3, 4, 4]) + self.assertEqual(fixed_shapes, [10, 2]) + + def test_check_shape2(self): + all_shapes = [ + (1, 2, 3, 8,), + (10, 1, 4, 10), + (10, 2, 4, 100) + ] + free_shape, fixed_shapes = checking.check_shape(all_shapes, free_axes=[2, -1]) + print(free_shape) + print(fixed_shapes) + self.assertEqual(free_shape, [[3, 8], [4, 10], [4, 100]]) + self.assertEqual(fixed_shapes, [10, 2]) + + def test_check_shape3(self): + all_shapes = [ + (1, 2, 3, 8,), + (10, 1, 4, 10), + (10, 2, 4, 100) + ] + free_shape, fixed_shapes = checking.check_shape(all_shapes, free_axes=[0, 2, -1]) + print(free_shape) + print(fixed_shapes) + self.assertEqual(free_shape, [[1, 3, 8], [10, 4, 10], [10, 4, 100]]) + self.assertEqual(fixed_shapes, [2]) + + def test_check_shape4(self): + all_shapes = [ + (1, 2, 3, 8,), + (10, 1, 4, 10), + (10, 2, 4, 100) + ] + with self.assertRaises(ValueError): + free_shape, fixed_shapes = checking.check_shape(all_shapes, free_axes=[0, -1]) diff --git a/brainpy/_src/tests/test_dyn_runner.py b/brainpy/_src/tests/test_dyn_runner.py new file mode 100644 index 000000000..e311a664e --- /dev/null +++ b/brainpy/_src/tests/test_dyn_runner.py @@ -0,0 +1,133 @@ +# -*- coding: utf-8 -*- + + +import unittest +import brainpy as bp +import brainpy.math as bm + + +class TestDSRunner(unittest.TestCase): + def test1(self): + class ExampleDS(bp.DynamicalSystem): + def __init__(self): + super(ExampleDS, self).__init__() + self.i = bm.Variable(bm.zeros(1)) + + def update(self, tdi): + self.i += 1 + + ds = ExampleDS() + runner = bp.DSRunner(ds, dt=1., monitors=['i'], progress_bar=False) + runner.run(100.) 
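+
+  def test_monitor_access_sketch(self):
+    # Illustrative sketch (not part of the original suite): after ``run``,
+    # DSRunner keeps the monitored history on ``runner.mon``, with one record
+    # per integration step and matching time points in ``runner.mon.ts``.
+    # ``CounterDS`` is a hypothetical minimal model mirroring ``test1`` above.
+    class CounterDS(bp.DynamicalSystem):
+      def __init__(self):
+        super().__init__()
+        self.i = bm.Variable(bm.zeros(1))
+
+      def update(self, tdi):
+        self.i += 1
+
+    runner = bp.DSRunner(CounterDS(), dt=1., monitors=['i'], progress_bar=False)
+    runner.run(100.)
+    self.assertEqual(len(runner.mon['i']), len(runner.mon.ts))  # one record per step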
+ + def test_t_and_dt(self): + class ExampleDS(bp.DynamicalSystem): + def __init__(self): + super(ExampleDS, self).__init__() + self.i = bm.Variable(bm.zeros(1)) + + def update(self, tdi): + self.i += 1 * tdi.dt + + runner = bp.DSRunner(ExampleDS(), dt=1., monitors=['i'], progress_bar=False) + runner.run(100.) + + def test_DSView(self): + class EINet(bp.Network): + def __init__(self, scale=1.0, method='exp_auto'): + super(EINet, self).__init__() + + # network size + num_exc = int(800 * scale) + num_inh = int(200 * scale) + + # neurons + pars = dict(V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5.) + self.E = bp.neurons.LIF(num_exc, **pars, method=method) + self.I = bp.neurons.LIF(num_inh, **pars, method=method) + self.E.V[:] = bm.random.randn(num_exc) * 2 - 55. + self.I.V[:] = bm.random.randn(num_inh) * 2 - 55. + + # synapses + we = 0.6 / scale # excitatory synaptic weight (voltage) + wi = 6.7 / scale # inhibitory synaptic weight + self.E2E = bp.synapses.Exponential(self.E, self.E[:100], bp.conn.FixedProb(0.02), + output=bp.synouts.COBA(E=0.), g_max=we, + tau=5., method=method) + self.E2I = bp.synapses.Exponential(self.E, self.I[:100], bp.conn.FixedProb(0.02), + output=bp.synouts.COBA(E=0.), g_max=we, + tau=5., method=method) + self.I2E = bp.synapses.Exponential(self.I, self.E[:100], bp.conn.FixedProb(0.02), + output=bp.synouts.COBA(E=-80.), g_max=wi, + tau=10., method=method) + self.I2I = bp.synapses.Exponential(self.I, self.I[:100], bp.conn.FixedProb(0.02), + output=bp.synouts.COBA(E=-80.), g_max=wi, + tau=10., method=method) + + net = EINet(scale=1., method='exp_auto') + # with JIT + runner = bp.DSRunner(net, monitors={'E.spike': net.E.spike}, + inputs=[(net.E.input, 20.), (net.I.input, 20.)]).run(1.) + + # without JIT + runner = bp.DSRunner(net, monitors={'E.spike': net.E.spike}, + inputs=[(net.E.input, 20.), (net.I.input, 20.)], + jit=False).run(0.2) + + + +class TestMemoryEfficient(unittest.TestCase): + pass + + + + + + +# class TestMonitor(TestCase): +# def test_1d_array(self): +# try1 = TryGroup(monitors=['a']) +# try1.a = np.ones(1) +# try1.run(100.) +# +# assert np.ndim(try1.mon.a) == 2 and np.shape(try1.mon.a)[1] == 1 +# assert np.allclose(np.arange(2, 1002).reshape((-1, 1)), try1.mon.a) +# +# def test_2d_array(): +# set(dt=0.1) +# try1 = TryGroup(monitors=['a']) +# try1.a = np.ones((2, 2)) +# try1.run(100.) +# +# assert np.ndim(try1.mon.a) == 2 and np.shape(try1.mon.a)[1] == 4 +# series = np.arange(2, 1002).reshape((-1, 1)) +# series = np.repeat(series, 4, axis=1) +# assert np.allclose(series, try1.mon.a) +# +# def test_monitor_with_every(): +# set(dt=0.1) +# +# # try1: 2d array +# try1 = TryGroup(monitors=Monitor(variables=['a'], every=[1.])) +# try1.run(100.) +# assert np.ndim(try1.mon.a) == 2 and np.shape(try1.mon.a)[1] == 4 +# series = np.arange(2, 1002, 1. / 0.1).reshape((-1, 1)) +# series = np.repeat(series, 4, axis=1) +# assert np.allclose(series, try1.mon.a) +# +# # try2: 1d array +# try2 = TryGroup(monitors=Monitor(variables=['a'], every=[1.])) +# try2.a = np.array([1., 1.]) +# try2.run(100.) +# assert np.ndim(try2.mon.a) == 2 and np.shape(try2.mon.a)[1] == 2 +# series = np.arange(2, 1002, 1. / 0.1).reshape((-1, 1)) +# series = np.repeat(series, 2, axis=1) +# assert np.allclose(series, try2.mon.a) +# +# # try2: scalar +# try3 = TryGroup(monitors=Monitor(variables=['a'], every=[1.])) +# try3.a = 1. +# try3.run(100.) +# assert np.ndim(try3.mon.a) == 2 and np.shape(try3.mon.a)[1] == 1 +# series = np.arange(2, 1002, 1. 
/ 0.1).reshape((-1, 1)) +# assert np.allclose(series, try3.mon.a) diff --git a/brainpy/_src/tests/test_network.py b/brainpy/_src/tests/test_network.py new file mode 100644 index 000000000..3c3afe310 --- /dev/null +++ b/brainpy/_src/tests/test_network.py @@ -0,0 +1,51 @@ +import brainpy as bp +import unittest + + +class TestNetDefinition(unittest.TestCase): + def test_define_net1(self): + E = bp.neurons.LIF(3200, V_rest=-60., V_th=-50., V_reset=-60., + tau=20., tau_ref=5., method='exp_auto', + V_initializer=bp.init.Normal(-60., 2.)) + + I = bp.neurons.LIF(800, V_rest=-60., V_th=-50., V_reset=-60., + tau=20., tau_ref=5., method='exp_auto', + V_initializer=bp.init.Normal(-60., 2.)) + + E2E = bp.synapses.Exponential(E, E, bp.conn.FixedProb(prob=0.02), g_max=0.6, + tau=5., output=bp.synouts.COBA(E=0.), + method='exp_auto') + + E2I = bp.synapses.Exponential(E, I, bp.conn.FixedProb(prob=0.02), g_max=0.6, + tau=5., output=bp.synouts.COBA(E=0.), + method='exp_auto') + + I2E = bp.synapses.Exponential(I, E, bp.conn.FixedProb(prob=0.02), g_max=6.7, + tau=10., output=bp.synouts.COBA(E=-80.), + method='exp_auto') + + I2I = bp.synapses.Exponential(I, I, bp.conn.FixedProb(prob=0.02), g_max=6.7, + tau=10., output=bp.synouts.COBA(E=-80.), + method='exp_auto') + + net = bp.Network(E2E, E2I, I2E, I2I, E=E, I=I) + + runner1 = bp.DSRunner(net, + monitors=['E.spike', 'I.spike'], + inputs=[('E.input', 20.), ('I.input', 20.)]) + + runner2 = bp.DSRunner(net, + monitors=[('E.spike', E.spike), ('I.spike', I.spike)], + inputs=[(E.input, 20.), (I.input, 20.)]) + + runner3 = bp.DSRunner(net, + monitors=[('E.spike', E.spike), 'I.spike'], + inputs=[(E.input, 20.), (I.input, 20.)]) + + runner4 = bp.DSRunner(net, + monitors={'E.spike': E.spike, 'I.spike': I.spike}, + inputs=[(E.input, 20.), (I.input, 20.)]) + + bp.math.clear_buffer_memory() + + diff --git a/brainpy/_src/tests/test_pickle.py b/brainpy/_src/tests/test_pickle.py new file mode 100644 index 000000000..2ae6a1345 --- /dev/null +++ b/brainpy/_src/tests/test_pickle.py @@ -0,0 +1,22 @@ +# -*- coding: utf-8 -*- + +import brainpy as bp + +import unittest + +import pickle + + +class TestPickle(unittest.TestCase): + def __init__(self, *args, **kwargs): + super(TestPickle, self).__init__(*args, **kwargs) + + self.pre = bp.neurons.LIF(10) + self.post = bp.neurons.LIF(20) + self.syn = bp.TwoEndConn(self.pre, self.post, bp.conn.FixedProb(0.2)) + self.net = bp.Network(self.pre, self.post, self.syn) + + def test_net(self): + self.skipTest('Currently do not support') + with open('data/net.pickle', 'wb') as f: + pickle.dump(self.net, f) diff --git a/brainpy/_src/tests/test_slice_view.py b/brainpy/_src/tests/test_slice_view.py new file mode 100644 index 000000000..a952528fb --- /dev/null +++ b/brainpy/_src/tests/test_slice_view.py @@ -0,0 +1,51 @@ +# -*- coding: utf-8 -*- + + +import brainpy as bp +import brainpy.math as bm +import unittest + + +class TestSliceView(unittest.TestCase): + def test_lif(self): + lif = bp.neurons.LIF(10) + lif_tile = lif[5:] + print(lif_tile.V.shape) + print(lif_tile.varshape) + + print('Before modification: ') + print(lif.V) + lif_tile.V += 10. + + self.assertTrue(bm.allclose(lif.V, bm.concatenate([bm.zeros(5), bm.ones(5) * 10.]))) + print('After modification 1: ') + print(lif.V) + + lif_tile.V.value = bm.ones(5) * 40. 
+ self.assertTrue(bm.allclose(lif.V, bm.concatenate([bm.zeros(5), bm.ones(5) * 40.]))) + print('After modification 2: ') + print(lif.V) + + def test_lif_train_mode(self): + lif = bp.neurons.LIF(10, mode=bm.training_mode) + lif_tile = lif[5:] + print(lif_tile.V.shape) + print(lif_tile.varshape) + + print('Before modification: ') + print(lif.V) + lif_tile.V += 10. + + self.assertTrue(bm.allclose(lif.V, bm.hstack([bm.zeros((1, 5)), bm.ones((1, 5)) * 10.]))) + print('After modification 1: ') + print(lif.V) + + lif_tile.V.value = bm.ones((1, 5)) * 40. + self.assertTrue(bm.allclose(lif.V, bm.hstack([bm.zeros((1, 5)), bm.ones((1, 5)) * 40.]))) + print('After modification 2: ') + print(lif.V) + + + + + diff --git a/brainpy/_src/tools/codes.py b/brainpy/_src/tools/codes.py index 01debfb20..4b809f80f 100644 --- a/brainpy/_src/tools/codes.py +++ b/brainpy/_src/tools/codes.py @@ -8,6 +8,7 @@ __all__ = [ + 'repr_dict', 'repr_object', 'repr_context', 'copy_doc', @@ -27,6 +28,11 @@ ] +def repr_dict(dict_obj: dict): + ret = [f'{k}={v}' for k, v in dict_obj.items()] + return ', '.join(ret) + + def repr_object(x): global BrainPyObject if BrainPyObject is None: diff --git a/brainpy/_src/tools/others.py b/brainpy/_src/tools/others.py index d945d890a..1c0fa3995 100644 --- a/brainpy/_src/tools/others.py +++ b/brainpy/_src/tools/others.py @@ -21,7 +21,7 @@ ] -def one_of(default: Any, *choices, names: Sequence[str] =None): +def one_of(default: Any, *choices, names: Sequence[str] = None): names = [f'arg{i}' for i in range(len(choices))] if names is None else names res = default has_chosen = False @@ -90,7 +90,7 @@ def to_size(x) -> Optional[Tuple[int]]: if isinstance(x, (tuple, list)): return tuple(x) if isinstance(x, (int, np.integer)): - return (x, ) + return (x,) if x is None: return x raise ValueError(f'Cannot make a size for {x}') @@ -183,3 +183,4 @@ def _progress_bar(iter_num): close_tqdm(iter_num) return _progress_bar + diff --git a/brainpy/check.py b/brainpy/check.py index ef3e35ae9..65756d1c9 100644 --- a/brainpy/check.py +++ b/brainpy/check.py @@ -3,6 +3,7 @@ from functools import wraps, partial from typing import Union, Sequence, Dict, Callable, Tuple, Type, Optional, Any +import jax import numpy as np import numpy as onp from jax import numpy as jnp @@ -251,7 +252,7 @@ def is_initializer( raise ValueError(f'{name} must be an initializer, but we got None.') if isinstance(initializer, init.Initializer): return initializer - elif isinstance(initializer, (Array, jnp.ndarray)): + elif isinstance(initializer, (Array, jax.Array)): return initializer elif callable(initializer): return initializer @@ -281,7 +282,7 @@ def is_connector( raise ValueError(f'{name} must be an initializer, but we got None.') if isinstance(connector, conn.Connector): return connector - elif isinstance(connector, (Array, jnp.ndarray)): + elif isinstance(connector, (Array, jax.Array)): return connector elif callable(connector): return connector diff --git a/brainpy/math/__init__.py b/brainpy/math/__init__.py index d026daacf..4909b00ea 100644 --- a/brainpy/math/__init__.py +++ b/brainpy/math/__init__.py @@ -34,6 +34,8 @@ from . import random # others +from . 
import sharding + import jax.numpy as jnp from jax import config diff --git a/brainpy/math/sharding.py b/brainpy/math/sharding.py new file mode 100644 index 000000000..9a61ca6c7 --- /dev/null +++ b/brainpy/math/sharding.py @@ -0,0 +1,7 @@ + +from brainpy._src.math.sharding import ( + set, + get_sharding, + partition_by_axname, + partition_by_sharding, +) diff --git a/brainpy/pnn/__init__.py b/brainpy/pnn/__init__.py new file mode 100644 index 000000000..19c8056f5 --- /dev/null +++ b/brainpy/pnn/__init__.py @@ -0,0 +1,7 @@ +from .pneurons import * +from .psynapses import * +from .pchannels import * +from .pother_models import * +from .putils import * +from . import mixin + diff --git a/brainpy/pnn/mixin.py b/brainpy/pnn/mixin.py new file mode 100644 index 000000000..7cd17c1bc --- /dev/null +++ b/brainpy/pnn/mixin.py @@ -0,0 +1,4 @@ +from brainpy._src.pnn.mixin import ( + ParamDesc, + AlignPost, +) diff --git a/brainpy/pnn/pchannels.py b/brainpy/pnn/pchannels.py new file mode 100644 index 000000000..e69de29bb diff --git a/brainpy/pnn/pneurons.py b/brainpy/pnn/pneurons.py new file mode 100644 index 000000000..88c5f449e --- /dev/null +++ b/brainpy/pnn/pneurons.py @@ -0,0 +1,11 @@ +from brainpy._src.pnn.neurons.base import ( + PNeuGroup, + DPNeuGroup, +) +from brainpy._src.pnn.neurons.lif import ( + IF, + LIF, + LIFLtc, + LIFRef, + LIFRefLtc, +) diff --git a/brainpy/pnn/pother_models.py b/brainpy/pnn/pother_models.py new file mode 100644 index 000000000..75ab179f4 --- /dev/null +++ b/brainpy/pnn/pother_models.py @@ -0,0 +1,8 @@ +from brainpy._src.pnn.common import ( + Leaky, + Integrator, +) +from brainpy._src.pnn.delay import ( + TargetDelay, + DataDelay, +) diff --git a/brainpy/pnn/psynapses.py b/brainpy/pnn/psynapses.py new file mode 100644 index 000000000..f0f661e3d --- /dev/null +++ b/brainpy/pnn/psynapses.py @@ -0,0 +1,28 @@ +from brainpy._src.pnn.synapses.projections import ( + ProjectionAlignPost, + ProjectionAlignPre, +) + +from brainpy._src.pnn.synapses.syn_dynamics import ( + Exponential, + DualExponential, + Alpha, + STD, + STP, + AMPA, + GABAa, + BioNMDA, +) + +from brainpy._src.pnn.synapses.syn_output import ( + COBA, + CUBA, + MgBlock +) + +from brainpy._src.pnn.synapses.syn_comm import ( + All2allMM, + One2oneMM, + DenseMM, +) + diff --git a/brainpy/pnn/putils.py b/brainpy/pnn/putils.py new file mode 100644 index 000000000..98ce931cb --- /dev/null +++ b/brainpy/pnn/putils.py @@ -0,0 +1,6 @@ +from brainpy._src.pnn.utils.axis_names import (NEU_AXIS, + POST_AXIS, + PRE_AXIS, + SYN_AXIS, + BATCH_AXIS, ) +from brainpy._src.pnn.utils.init import (DelayedInit, ) diff --git a/examples/dynamics_simulation/COBA.py b/examples/dynamics_simulation/COBA.py index 12cb9caa6..60cff2bb1 100644 --- a/examples/dynamics_simulation/COBA.py +++ b/examples/dynamics_simulation/COBA.py @@ -27,19 +27,19 @@ def __init__(self, scale=1.0, e_input=20., i_input=20., delay=None): wi = 6.7 / scale # inhibitory synaptic weight self.E2E = bp.experimental.Exponential( bp.conn.FixedProb(0.02, pre=self.E.size, post=self.E.size), - g_max=we, tau=5., out=bp.experimental.COBA(E=0.) + g_max=we, tau=5., out=bp.experimental.COBA(E=0.), comp_method='dense' ) self.E2I = bp.experimental.Exponential( bp.conn.FixedProb(0.02, pre=self.E.size, post=self.I.size, ), - g_max=we, tau=5., out=bp.experimental.COBA(E=0.) 
+ g_max=we, tau=5., out=bp.experimental.COBA(E=0.), comp_method='dense' ) self.I2E = bp.experimental.Exponential( bp.conn.FixedProb(0.02, pre=self.I.size, post=self.E.size), - g_max=wi, tau=10., out=bp.experimental.COBA(E=-80.) + g_max=wi, tau=10., out=bp.experimental.COBA(E=-80.), comp_method='dense' ) self.I2I = bp.experimental.Exponential( bp.conn.FixedProb(0.02, pre=self.I.size, post=self.I.size), - g_max=wi, tau=10., out=bp.experimental.COBA(E=-80.) + g_max=wi, tau=10., out=bp.experimental.COBA(E=-80.), comp_method='dense' ) self.delayE = bp.Delay(self.E.spike, entries={'E': delay}) self.delayI = bp.Delay(self.I.spike, entries={'I': delay}) @@ -102,19 +102,24 @@ def update(self): # simulation - -@pmap -def f2(I): - net = EINet(delay=0., scale=5., e_input=I, i_input=I) - # net = EINetv2(delay=0., scale=2.) - runner = bp.DSRunner(net, monitors={'E.spike': net.E.spike}, numpy_mon_after_run=False) - runner.run(10000.) - return runner.mon - # print(r) - # bp.visualize.raster_plot(runner.mon.ts, runner.mon['E.spike'], show=True) - - -print(f2(bm.ones(20) * 20.)) +net = EINet(delay=0., scale=1.) +runner = bp.DSRunner(net, monitors={'E.spike': net.E.spike}) +runner.run(100.) +# print(r) +bp.visualize.raster_plot(runner.mon.ts, runner.mon['E.spike'], show=True) + +# @pmap +# def f2(I): +# net = EINet(delay=0., scale=5., e_input=I, i_input=I) +# # net = EINetv2(delay=0., scale=2.) +# runner = bp.DSRunner(net, monitors={'E.spike': net.E.spike}, numpy_mon_after_run=False) +# runner.run(10000.) +# return runner.mon +# # print(r) +# # bp.visualize.raster_plot(runner.mon.ts, runner.mon['E.spike'], show=True) +# +# +# print(f2(bm.ones(20) * 20.)) From f303277baaccfe0af3a39953d35aa0ae5ba0a98a Mon Sep 17 00:00:00 2001 From: chaoming Date: Sun, 11 Jun 2023 22:59:17 +0800 Subject: [PATCH 12/14] [test] update tests --- .gitignore | 1 + brainpy/_src/layers/tests/__init__.py | 0 .../{test_conv.py => test_conv_layers.py} | 0 ...test_pooling.py => test_pooling_layers.py} | 49 ++++++++++--------- brainpy/_src/math/delayvars.py | 3 +- .../math/object_transform/tests/test_base.py | 9 ++-- .../op_registers/numba_approach/__init__.py | 1 - .../{numba_approach => tests}/test_ei_net.py | 20 ++++---- brainpy/_src/math/remove_vmap.py | 3 +- .../_src/neurons/tests/test_reduced_models.py | 2 +- 10 files changed, 46 insertions(+), 42 deletions(-) delete mode 100644 brainpy/_src/layers/tests/__init__.py rename brainpy/_src/layers/tests/{test_conv.py => test_conv_layers.py} (100%) rename brainpy/_src/layers/tests/{test_pooling.py => test_pooling_layers.py} (68%) rename brainpy/_src/math/op_registers/{numba_approach => tests}/test_ei_net.py (86%) diff --git a/.gitignore b/.gitignore index 3c7f48499..dec4fa91d 100644 --- a/.gitignore +++ b/.gitignore @@ -224,3 +224,4 @@ cython_debug/ /examples/training_snn_models/data/ /docs/tutorial_advanced/data/ /my_tests/ +/examples/dynamics_simulation/Joglekar_2018_data/ diff --git a/brainpy/_src/layers/tests/__init__.py b/brainpy/_src/layers/tests/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/brainpy/_src/layers/tests/test_conv.py b/brainpy/_src/layers/tests/test_conv_layers.py similarity index 100% rename from brainpy/_src/layers/tests/test_conv.py rename to brainpy/_src/layers/tests/test_conv_layers.py diff --git a/brainpy/_src/layers/tests/test_pooling.py b/brainpy/_src/layers/tests/test_pooling_layers.py similarity index 68% rename from brainpy/_src/layers/tests/test_pooling.py rename to brainpy/_src/layers/tests/test_pooling_layers.py index 
56db78f60..347d49184 100644 --- a/brainpy/_src/layers/tests/test_pooling.py +++ b/brainpy/_src/layers/tests/test_pooling_layers.py @@ -7,7 +7,6 @@ import brainpy as bp import brainpy.math as bm -from brainpy._src.layers import pooling class TestPool(parameterized.TestCase): @@ -61,43 +60,43 @@ def test_avgpool(self): def test_MaxPool2d_v1(self): arr = self.rng.rand(16, 32, 32, 8) - out = pooling.MaxPool2d(2, 2, channel_axis=-1)(arr) + out = bp.layers.MaxPool2d(2, 2, channel_axis=-1)(arr) self.assertTrue(out.shape == (16, 16, 16, 8)) - out = pooling.MaxPool2d(2, 2, channel_axis=None)(arr) + out = bp.layers.MaxPool2d(2, 2, channel_axis=None)(arr) self.assertTrue(out.shape == (16, 32, 16, 4)) - out = pooling.MaxPool2d(2, 2, channel_axis=None, padding=1)(arr) + out = bp.layers.MaxPool2d(2, 2, channel_axis=None, padding=1)(arr) self.assertTrue(out.shape == (16, 32, 17, 5)) - out = pooling.MaxPool2d(2, 2, channel_axis=None, padding=(2, 1))(arr) + out = bp.layers.MaxPool2d(2, 2, channel_axis=None, padding=(2, 1))(arr) self.assertTrue(out.shape == (16, 32, 18, 5)) - out = pooling.MaxPool2d(2, 2, channel_axis=-1, padding=(1, 1))(arr) + out = bp.layers.MaxPool2d(2, 2, channel_axis=-1, padding=(1, 1))(arr) self.assertTrue(out.shape == (16, 17, 17, 8)) - out = pooling.MaxPool2d(2, 2, channel_axis=2, padding=(1, 1))(arr) + out = bp.layers.MaxPool2d(2, 2, channel_axis=2, padding=(1, 1))(arr) self.assertTrue(out.shape == (16, 17, 32, 5)) def test_AvgPool2d_v1(self): arr = self.rng.rand(16, 32, 32, 8) - out = pooling.AvgPool2d(2, 2, channel_axis=-1)(arr) + out = bp.layers.AvgPool2d(2, 2, channel_axis=-1)(arr) self.assertTrue(out.shape == (16, 16, 16, 8)) - out = pooling.AvgPool2d(2, 2, channel_axis=None)(arr) + out = bp.layers.AvgPool2d(2, 2, channel_axis=None)(arr) self.assertTrue(out.shape == (16, 32, 16, 4)) - out = pooling.AvgPool2d(2, 2, channel_axis=None, padding=1)(arr) + out = bp.layers.AvgPool2d(2, 2, channel_axis=None, padding=1)(arr) self.assertTrue(out.shape == (16, 32, 17, 5)) - out = pooling.AvgPool2d(2, 2, channel_axis=None, padding=(2, 1))(arr) + out = bp.layers.AvgPool2d(2, 2, channel_axis=None, padding=(2, 1))(arr) self.assertTrue(out.shape == (16, 32, 18, 5)) - out = pooling.AvgPool2d(2, 2, channel_axis=-1, padding=(1, 1))(arr) + out = bp.layers.AvgPool2d(2, 2, channel_axis=-1, padding=(1, 1))(arr) self.assertTrue(out.shape == (16, 17, 17, 8)) - out = pooling.AvgPool2d(2, 2, channel_axis=2, padding=(1, 1))(arr) + out = bp.layers.AvgPool2d(2, 2, channel_axis=2, padding=(1, 1))(arr) self.assertTrue(out.shape == (16, 17, 32, 5)) @parameterized.named_parameters( @@ -106,46 +105,48 @@ def test_AvgPool2d_v1(self): for target_size in [10, 9, 8, 7, 6] ) def test_adaptive_pool1d(self, target_size): + from brainpy._src.layers.pooling import _adaptive_pool1d + arr = self.rng.rand(100) op = jax.numpy.mean - out = pooling._adaptive_pool1d(arr, target_size, op) + out = _adaptive_pool1d(arr, target_size, op) print(out.shape) self.assertTrue(out.shape == (target_size,)) - out = pooling._adaptive_pool1d(arr, target_size, op) + out = _adaptive_pool1d(arr, target_size, op) print(out.shape) self.assertTrue(out.shape == (target_size,)) def test_AdaptiveAvgPool2d_v1(self): input = self.rng.randn(64, 8, 9) - output = pooling.AdaptiveAvgPool2d((5, 7), channel_axis=0)(input) + output = bp.layers.AdaptiveAvgPool2d((5, 7), channel_axis=0)(input) self.assertTrue(output.shape == (64, 5, 7)) - output = pooling.AdaptiveAvgPool2d((2, 3), channel_axis=0)(input) + output = bp.layers.AdaptiveAvgPool2d((2, 3), 
channel_axis=0)(input) self.assertTrue(output.shape == (64, 2, 3)) - output = pooling.AdaptiveAvgPool2d((2, 3), channel_axis=-1)(input) + output = bp.layers.AdaptiveAvgPool2d((2, 3), channel_axis=-1)(input) self.assertTrue(output.shape == (2, 3, 9)) - output = pooling.AdaptiveAvgPool2d((2, 3), channel_axis=1)(input) + output = bp.layers.AdaptiveAvgPool2d((2, 3), channel_axis=1)(input) self.assertTrue(output.shape == (2, 8, 3)) - output = pooling.AdaptiveAvgPool2d((2, 3), channel_axis=None)(input) + output = bp.layers.AdaptiveAvgPool2d((2, 3), channel_axis=None)(input) self.assertTrue(output.shape == (64, 2, 3)) def test_AdaptiveAvgPool2d_v2(self): input = self.rng.randn(128, 64, 32, 16) - output = pooling.AdaptiveAvgPool2d((5, 7), channel_axis=0)(input) + output = bp.layers.AdaptiveAvgPool2d((5, 7), channel_axis=0)(input) self.assertTrue(output.shape == (128, 64, 5, 7)) - output = pooling.AdaptiveAvgPool2d((2, 3), channel_axis=0)(input) + output = bp.layers.AdaptiveAvgPool2d((2, 3), channel_axis=0)(input) self.assertTrue(output.shape == (128, 64, 2, 3)) - output = pooling.AdaptiveAvgPool2d((2, 3), channel_axis=-1)(input) + output = bp.layers.AdaptiveAvgPool2d((2, 3), channel_axis=-1)(input) self.assertTrue(output.shape == (128, 2, 3, 16)) - output = pooling.AdaptiveAvgPool2d((2, 3), channel_axis=1)(input) + output = bp.layers.AdaptiveAvgPool2d((2, 3), channel_axis=1)(input) self.assertTrue(output.shape == (128, 64, 2, 3)) diff --git a/brainpy/_src/math/delayvars.py b/brainpy/_src/math/delayvars.py index cff6d09c4..b5b0c7f08 100644 --- a/brainpy/_src/math/delayvars.py +++ b/brainpy/_src/math/delayvars.py @@ -1,6 +1,7 @@ # -*- coding: utf-8 -*- from typing import Union, Callable +import numbers import jax import jax.numpy as jnp @@ -447,7 +448,7 @@ def retrieve(self, delay_len, *indices): # the delay data return self.data[indices] - def update(self, value: Union[float, int, bool, Array, jnp.DeviceArray]): + def update(self, value: Union[numbers.Number, Array, jax.Array]): """Update delay variable with the new data. 
Parameters diff --git a/brainpy/_src/math/object_transform/tests/test_base.py b/brainpy/_src/math/object_transform/tests/test_base.py index 6165d248a..b3762865d 100644 --- a/brainpy/_src/math/object_transform/tests/test_base.py +++ b/brainpy/_src/math/object_transform/tests/test_base.py @@ -100,7 +100,7 @@ def update(self, x): with bm.environment(mode=bm.NonBatchingMode()): obj = Object() - self.assertTrue(len(obj.vars()) == 1) + self.assertTrue(len(obj.vars()) == 0) self.assertTrue(len(obj.nodes()) == 7) print(obj.nodes().keys()) @@ -110,7 +110,7 @@ def update(self, x): with bm.environment(mode=bm.TrainingMode()): obj = Object() - self.assertTrue(len(obj.vars()) == 7) + self.assertTrue(len(obj.vars()) == 6) self.assertTrue(len(obj.nodes()) == 7) print(obj.nodes().keys()) @@ -144,7 +144,8 @@ def update(self, x): with bm.environment(mode=bm.NonBatchingMode()): obj = Object() - self.assertTrue(len(obj.vars()) == 1) + + self.assertTrue(len(obj.vars()) == 0) self.assertTrue(len(obj.nodes()) == 7) self.assertTrue(len(jax.tree_util.tree_leaves(obj)) == 1) @@ -155,7 +156,7 @@ def update(self, x): with bm.environment(mode=bm.TrainingMode()): obj = Object() - self.assertTrue(len(obj.vars()) == 7) + self.assertTrue(len(obj.vars()) == 6) self.assertTrue(len(obj.nodes()) == 7) print(obj.nodes().keys()) diff --git a/brainpy/_src/math/op_registers/numba_approach/__init__.py b/brainpy/_src/math/op_registers/numba_approach/__init__.py index 37485b550..3856b9873 100644 --- a/brainpy/_src/math/op_registers/numba_approach/__init__.py +++ b/brainpy/_src/math/op_registers/numba_approach/__init__.py @@ -109,7 +109,6 @@ def __call__(self, *args, **kwargs): return res - def register_op_with_numba( op_name: str, cpu_func: Callable, diff --git a/brainpy/_src/math/op_registers/numba_approach/test_ei_net.py b/brainpy/_src/math/op_registers/tests/test_ei_net.py similarity index 86% rename from brainpy/_src/math/op_registers/numba_approach/test_ei_net.py rename to brainpy/_src/math/op_registers/tests/test_ei_net.py index d883c3965..4f3da1596 100644 --- a/brainpy/_src/math/op_registers/numba_approach/test_ei_net.py +++ b/brainpy/_src/math/op_registers/tests/test_ei_net.py @@ -1,13 +1,13 @@ import brainpy.math as bm import brainpy as bp -from jax.abstract_arrays import ShapedArray +from jax.core import ShapedArray bm.set_platform('cpu') def abs_eval(events, indices, indptr, *, weight, post_num): - return ShapedArray((post_num,), bm.float32) + return [ShapedArray((post_num,), bm.float32), ] def con_compute(outs, ins): @@ -22,7 +22,7 @@ def con_compute(outs, ins): post_val[index] += weight -event_sum = bm.XLACustomOp(eval_shape=abs_eval, cpu_func=con_compute, apply_cpu_func_to_gpu=True) +event_sum = bm.XLACustomOp(eval_shape=abs_eval, cpu_func=con_compute) class ExponentialV2(bp.TwoEndConn): @@ -52,13 +52,14 @@ def update(self, tdi): self.pre2post[0], self.pre2post[1], weight=self.g_max, - post_num=self.post.num) + post_num=self.post.num)[0] self.post.input += self.g * (self.E - self.post.V) class EINet(bp.Network): def __init__(self, scale): # neurons + bm.random.seed() pars = dict(V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., V_initializer=bp.init.Normal(-55., 2.)) E = bp.neurons.LIF(int(3200 * scale), **pars, method='exp_auto') @@ -73,10 +74,11 @@ def __init__(self, scale): super(EINet, self).__init__(E2E, E2I, I2E, I2I, E=E, I=I) -# def test1(): -# net2 = EINet(scale=0.1) -# runner2 = bp.DSRunner(net2, inputs=[('E.input', 20.), ('I.input', 20.)]) -# r = runner2.predict(100., eval_time=True) -# print(r) 
+def test1(): + net2 = EINet(scale=0.1) + runner2 = bp.DSRunner(net2, inputs=[('E.input', 20.), ('I.input', 20.)]) + r = runner2.predict(100., eval_time=True) + print(r) + diff --git a/brainpy/_src/math/remove_vmap.py b/brainpy/_src/math/remove_vmap.py index 6075b9452..ee81c0c17 100644 --- a/brainpy/_src/math/remove_vmap.py +++ b/brainpy/_src/math/remove_vmap.py @@ -2,8 +2,7 @@ import jax.numpy as jnp -from jax.abstract_arrays import ShapedArray -from jax.core import Primitive +from jax.core import Primitive, ShapedArray from jax.interpreters import batching, mlir, xla from .ndarray import Array diff --git a/brainpy/_src/neurons/tests/test_reduced_models.py b/brainpy/_src/neurons/tests/test_reduced_models.py index 2d88efd92..29b0b1247 100644 --- a/brainpy/_src/neurons/tests/test_reduced_models.py +++ b/brainpy/_src/neurons/tests/test_reduced_models.py @@ -3,7 +3,7 @@ import brainpy as bp from absl.testing import parameterized -from brainpy._src.dyn.neurons import reduced_models +from brainpy._src.neurons import reduced_models class TestNoise(parameterized.TestCase): From 0ad84ba4d25991f56b2ff3f62cdb632ed2e88941 Mon Sep 17 00:00:00 2001 From: chaoming Date: Sun, 11 Jun 2023 23:04:24 +0800 Subject: [PATCH 13/14] [CI] update CI tests --- .github/workflows/CI.yml | 209 +++++++++++++++++++-------------------- 1 file changed, 104 insertions(+), 105 deletions(-) diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index a358808f3..2324a5d05 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -16,7 +16,7 @@ jobs: strategy: fail-fast: false matrix: - python-version: [ "3.8", "3.9", "3.10" ] + python-version: [ "3.8", "3.9", "3.10", "3.11"] steps: - uses: actions/checkout@v2 @@ -42,45 +42,45 @@ jobs: cd examples pytest ../brainpy/ - test_linux_py37: - runs-on: ubuntu-latest - strategy: - fail-fast: false - matrix: - python-version: ["3.7"] - - steps: - - uses: actions/checkout@v2 - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v2 - with: - python-version: ${{ matrix.python-version }} - - name: Install dependencies - run: | - python -m pip install --upgrade pip - python -m pip install flake8 pytest - if [ -f requirements-dev.txt ]; then pip install -r requirements-dev.txt; fi - pip install jax==0.3.25 - pip install jaxlib==0.3.25 - pip uninstall brainpy -y - python setup.py install - - name: Lint with flake8 - run: | - # stop the build if there are Python syntax errors or undefined names - flake8 brainpy/ --count --select=E9,F63,F7,F82 --show-source --statistics - # exit-zero treats all errors as warnings. 
The GitHub editor is 127 chars wide - flake8 brainpy/ --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics - - name: Test with pytest - run: | - cd examples - pytest ../brainpy/ - +# test_linux_py37: +# runs-on: ubuntu-latest +# strategy: +# fail-fast: false +# matrix: +# python-version: ["3.7"] +# +# steps: +# - uses: actions/checkout@v2 +# - name: Set up Python ${{ matrix.python-version }} +# uses: actions/setup-python@v2 +# with: +# python-version: ${{ matrix.python-version }} +# - name: Install dependencies +# run: | +# python -m pip install --upgrade pip +# python -m pip install flake8 pytest +# if [ -f requirements-dev.txt ]; then pip install -r requirements-dev.txt; fi +# pip install jax==0.3.25 +# pip install jaxlib==0.3.25 +# pip uninstall brainpy -y +# python setup.py install +# - name: Lint with flake8 +# run: | +# # stop the build if there are Python syntax errors or undefined names +# flake8 brainpy/ --count --select=E9,F63,F7,F82 --show-source --statistics +# # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide +# flake8 brainpy/ --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics +# - name: Test with pytest +# run: | +# cd examples +# pytest ../brainpy/ +# test_macos: runs-on: macos-latest strategy: fail-fast: false matrix: - python-version: ["3.8", "3.9", "3.10"] + python-version: ["3.8", "3.9", "3.10", "3.11"] steps: - uses: actions/checkout@v2 @@ -106,46 +106,46 @@ jobs: cd examples pytest ../brainpy/ - test_macos_py37: - runs-on: macos-latest - strategy: - fail-fast: false - matrix: - python-version: [ "3.7" ] - - steps: - - uses: actions/checkout@v2 - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v2 - with: - python-version: ${{ matrix.python-version }} - - name: Install dependencies - run: | - python -m pip install --upgrade pip - python -m pip install flake8 pytest - if [ -f requirements-dev.txt ]; then pip install -r requirements-dev.txt; fi - pip install jax==0.3.25 - pip install jaxlib==0.3.25 - pip uninstall brainpy -y - python setup.py install - - name: Lint with flake8 - run: | - # stop the build if there are Python syntax errors or undefined names - flake8 brainpy/ --count --select=E9,F63,F7,F82 --show-source --statistics - # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide - flake8 brainpy/ --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics - - name: Test with pytest - run: | - cd examples - pytest ../brainpy/ - +# test_macos_py37: +# runs-on: macos-latest +# strategy: +# fail-fast: false +# matrix: +# python-version: [ "3.7" ] +# +# steps: +# - uses: actions/checkout@v2 +# - name: Set up Python ${{ matrix.python-version }} +# uses: actions/setup-python@v2 +# with: +# python-version: ${{ matrix.python-version }} +# - name: Install dependencies +# run: | +# python -m pip install --upgrade pip +# python -m pip install flake8 pytest +# if [ -f requirements-dev.txt ]; then pip install -r requirements-dev.txt; fi +# pip install jax==0.3.25 +# pip install jaxlib==0.3.25 +# pip uninstall brainpy -y +# python setup.py install +# - name: Lint with flake8 +# run: | +# # stop the build if there are Python syntax errors or undefined names +# flake8 brainpy/ --count --select=E9,F63,F7,F82 --show-source --statistics +# # exit-zero treats all errors as warnings. 
The GitHub editor is 127 chars wide +# flake8 brainpy/ --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics +# - name: Test with pytest +# run: | +# cd examples +# pytest ../brainpy/ +# test_windows: runs-on: windows-latest strategy: fail-fast: false matrix: - python-version: ["3.8", "3.9", "3.10"] + python-version: ["3.8", "3.9", "3.10", "3.11"] steps: - uses: actions/checkout@v2 @@ -158,8 +158,7 @@ jobs: python -m pip install --upgrade pip python -m pip install flake8 pytest python -m pip install numpy>=1.21.0 - python -m pip install "jaxlib==0.3.25" -f https://whls.blob.core.windows.net/unstable/index.html --use-deprecated legacy-resolver - python -m pip install https://github.com/google/jax/archive/refs/tags/jax-v0.3.25.tar.gz + python -m pip install jaxlib -f https://whls.blob.core.windows.net/unstable/index.html --use-deprecated legacy-resolver python -m pip install -r requirements-dev.txt python -m pip install tqdm brainpylib pip uninstall brainpy -y @@ -175,37 +174,37 @@ jobs: cd examples pytest ../brainpy/ - test_windows_py37: - runs-on: windows-latest - strategy: - fail-fast: false - matrix: - python-version: ["3.7"] - - steps: - - uses: actions/checkout@v2 - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v2 - with: - python-version: ${{ matrix.python-version }} - - name: Install dependencies - run: | - python -m pip install --upgrade pip - python -m pip install flake8 pytest - python -m pip install numpy>=1.21.0 - python -m pip install "jaxlib==0.3.25" -f https://whls.blob.core.windows.net/unstable/index.html --use-deprecated legacy-resolver - python -m pip install https://github.com/google/jax/archive/refs/tags/jax-v0.3.25.tar.gz - python -m pip install -r requirements-dev.txt - python -m pip install tqdm brainpylib - pip uninstall brainpy -y - python setup.py install - - name: Lint with flake8 - run: | - # stop the build if there are Python syntax errors or undefined names - flake8 brainpy/ --count --select=E9,F63,F7,F82 --show-source --statistics - # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide - flake8 brainpy/ --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics - - name: Test with pytest - run: | - cd examples - pytest ../brainpy/ \ No newline at end of file +# test_windows_py37: +# runs-on: windows-latest +# strategy: +# fail-fast: false +# matrix: +# python-version: ["3.7"] +# +# steps: +# - uses: actions/checkout@v2 +# - name: Set up Python ${{ matrix.python-version }} +# uses: actions/setup-python@v2 +# with: +# python-version: ${{ matrix.python-version }} +# - name: Install dependencies +# run: | +# python -m pip install --upgrade pip +# python -m pip install flake8 pytest +# python -m pip install numpy>=1.21.0 +# python -m pip install "jaxlib==0.3.25" -f https://whls.blob.core.windows.net/unstable/index.html --use-deprecated legacy-resolver +# python -m pip install https://github.com/google/jax/archive/refs/tags/jax-v0.3.25.tar.gz +# python -m pip install -r requirements-dev.txt +# python -m pip install tqdm brainpylib +# pip uninstall brainpy -y +# python setup.py install +# - name: Lint with flake8 +# run: | +# # stop the build if there are Python syntax errors or undefined names +# flake8 brainpy/ --count --select=E9,F63,F7,F82 --show-source --statistics +# # exit-zero treats all errors as warnings. 
The GitHub editor is 127 chars wide +# flake8 brainpy/ --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics +# - name: Test with pytest +# run: | +# cd examples +# pytest ../brainpy/ \ No newline at end of file From 480842ec80a69df9209d98c10a15102aec3ce292 Mon Sep 17 00:00:00 2001 From: chaoming Date: Sun, 11 Jun 2023 23:10:04 +0800 Subject: [PATCH 14/14] [CI] update CI tests --- .github/workflows/CI.yml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 2324a5d05..3f065c501 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -158,7 +158,8 @@ jobs: python -m pip install --upgrade pip python -m pip install flake8 pytest python -m pip install numpy>=1.21.0 - python -m pip install jaxlib -f https://whls.blob.core.windows.net/unstable/index.html --use-deprecated legacy-resolver + python -m pip install "jaxlib==0.4.10" -f https://whls.blob.core.windows.net/unstable/index.html --use-deprecated legacy-resolver + python -m pip install jax==0.4.10 python -m pip install -r requirements-dev.txt python -m pip install tqdm brainpylib pip uninstall brainpy -y