From 1e90669a6edf11b0dfdf4c54b429dd6f31b56c1a Mon Sep 17 00:00:00 2001 From: chaoming Date: Sat, 28 Oct 2023 20:24:54 +0800 Subject: [PATCH 1/4] [math] numpy apis compatability --- brainpy/_src/math/compat_numpy.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/brainpy/_src/math/compat_numpy.py b/brainpy/_src/math/compat_numpy.py index d8da11c9e..305cd5987 100644 --- a/brainpy/_src/math/compat_numpy.py +++ b/brainpy/_src/math/compat_numpy.py @@ -381,7 +381,7 @@ def msort(a): nansum = _compatible_with_brainpy_array(jnp.nansum) ediff1d = _compatible_with_brainpy_array(jnp.ediff1d) cross = _compatible_with_brainpy_array(jnp.cross) -trapz = _compatible_with_brainpy_array(jnp.trapz) +trapz = _compatible_with_brainpy_array(jax.scipy.integrate.trapezoid) isfinite = _compatible_with_brainpy_array(jnp.isfinite) isinf = _compatible_with_brainpy_array(jnp.isinf) isnan = _compatible_with_brainpy_array(jnp.isnan) @@ -640,7 +640,7 @@ def size(a, axis=None): isposinf = _compatible_with_brainpy_array(jnp.isposinf) isrealobj = _compatible_with_brainpy_array(jnp.isrealobj) issubdtype = jnp.issubdtype -issubsctype = jnp.issubsctype +issubsctype = jnp.issubdtype iterable = _compatible_with_brainpy_array(jnp.iterable) packbits = _compatible_with_brainpy_array(jnp.packbits) piecewise = _compatible_with_brainpy_array(jnp.piecewise) From 744dce9bb5389232d294a9e72e0be1db8c9f5994 Mon Sep 17 00:00:00 2001 From: chaoming Date: Sat, 28 Oct 2023 20:25:13 +0800 Subject: [PATCH 2/4] [test] test compatibility --- .../_src/dyn/projections/tests/test_STDP.py | 24 ++++++++++++------- .../_src/dyn/rates/tests/test_reservoir.py | 2 +- brainpy/_src/dyn/rates/tests/test_rnncells.py | 24 +++++++++---------- brainpy/_src/tests/test_dyn_runner.py | 6 ++--- brainpy/_src/tests/test_pickle.py | 4 ++-- 5 files changed, 34 insertions(+), 26 deletions(-) diff --git a/brainpy/_src/dyn/projections/tests/test_STDP.py b/brainpy/_src/dyn/projections/tests/test_STDP.py index 457e97e51..e33644f26 100644 --- a/brainpy/_src/dyn/projections/tests/test_STDP.py +++ b/brainpy/_src/dyn/projections/tests/test_STDP.py @@ -1,7 +1,7 @@ # -*- coding: utf-8 -*- -import os -os.environ['JAX_TRACEBACK_FILTERING'] = 'off' +import matplotlib.pyplot as plt +import numpy as np from absl.testing import parameterized import brainpy as bp @@ -20,8 +20,9 @@ def __init__(self, num_pre, num_post): self.syn = bp.dyn.STDP_Song2000( pre=self.pre, delay=1., - comm=bp.dnn.EventCSRLinear(bp.conn.FixedProb(1, pre=self.pre.num, post=self.post.num), - weight=lambda s: bm.Variable(bm.random.rand(*s) * 0.1)), + # comm=bp.dnn.EventCSRLinear(bp.conn.FixedProb(1, pre=self.pre.num, post=self.post.num), + # weight=bp.init.Uniform(-0.1, 0.1)), + comm=bp.dnn.AllToAll(self.pre.num, self.post.num, weight=bp.init.Uniform(-0.1, 0.1)), syn=bp.dyn.Expon.desc(self.post.varshape, tau=5.), out=bp.dyn.COBA.desc(E=0.), post=self.post, @@ -39,7 +40,7 @@ def update(self, I_pre, I_post): Apre = self.syn.refs['pre_trace'].g Apost = self.syn.refs['post_trace'].g current = self.post.sum_inputs(self.post.V) - return self.pre.spike, self.post.spike, conductance, Apre, Apost, current, self.syn.comm.weight + return self.pre.spike, self.post.spike, conductance, Apre, Apost, current, self.syn.comm.weight.flatten() duration = 300. 
I_pre = bp.inputs.section_input([0, 30, 0, 30, 0, 30, 0, 30, 0, 30, 0, 30, 0], @@ -53,7 +54,14 @@ def run(i, I_pre, I_post): pre_spike, post_spike, g, Apre, Apost, current, W = net.step_run(i, I_pre, I_post) return pre_spike, post_spike, g, Apre, Apost, current, W - indices = bm.arange(0, duration, bm.dt) - bm.for_loop(run, [indices, I_pre, I_post], jit=True) - bm.clear_buffer_memory() + indices = np.arange(int(duration / bm.dt)) + pre_spike, post_spike, g, Apre, Apost, current, W = bm.for_loop(run, [indices, I_pre, I_post]) + + fig, gs = bp.visualize.get_figure(4, 1, 3, 10) + bp.visualize.line_plot(indices, g, ax=fig.add_subplot(gs[0, 0])) + bp.visualize.line_plot(indices, Apre, ax=fig.add_subplot(gs[1, 0])) + bp.visualize.line_plot(indices, Apost, ax=fig.add_subplot(gs[2, 0])) + bp.visualize.line_plot(indices, W, ax=fig.add_subplot(gs[3, 0])) + plt.show() + bm.clear_buffer_memory() diff --git a/brainpy/_src/dyn/rates/tests/test_reservoir.py b/brainpy/_src/dyn/rates/tests/test_reservoir.py index 7a1dd2343..371c7aa89 100644 --- a/brainpy/_src/dyn/rates/tests/test_reservoir.py +++ b/brainpy/_src/dyn/rates/tests/test_reservoir.py @@ -15,7 +15,7 @@ class Test_Reservoir(parameterized.TestCase): def test_Reservoir(self, mode): bm.random.seed() input = bm.random.randn(10, 3) - layer = bp.dnn.Reservoir(input_shape=3, + layer = bp.syn.Reservoir(input_shape=3, num_out=5, mode=mode) if mode in [bm.NonBatchingMode()]: diff --git a/brainpy/_src/dyn/rates/tests/test_rnncells.py b/brainpy/_src/dyn/rates/tests/test_rnncells.py index a55e958c3..5f86288f7 100644 --- a/brainpy/_src/dyn/rates/tests/test_rnncells.py +++ b/brainpy/_src/dyn/rates/tests/test_rnncells.py @@ -15,7 +15,7 @@ class Test_Rnncells(parameterized.TestCase): def test_RNNCell(self, mode): bm.random.seed() input = bm.random.randn(20, 10) - layer = bp.dnn.RNNCell(num_in=10, + layer = bp.dyn.RNNCell(num_in=10, num_out=64, mode=mode ) @@ -25,7 +25,7 @@ def test_RNNCell(self, mode): def test_RNNCell_NonBatching(self): bm.random.seed() input = bm.random.randn(10) - layer = bp.dnn.RNNCell(num_in=10, + layer = bp.dyn.RNNCell(num_in=10, num_out=32, mode=bm.NonBatchingMode()) output = layer(input) @@ -41,7 +41,7 @@ def test_RNNCell_NonBatching(self): def test_GRUCell(self, mode): bm.random.seed() input = bm.random.randn(50, 100) - layer = bp.dnn.GRUCell(num_in=100, + layer = bp.dyn.GRUCell(num_in=100, num_out=64, mode=mode) output = layer(input) @@ -50,7 +50,7 @@ def test_GRUCell(self, mode): def test_GRUCell_NonBatching(self): bm.random.seed() input = bm.random.randn(10) - layer = bp.dnn.GRUCell(num_in=10, + layer = bp.dyn.GRUCell(num_in=10, num_out=12, mode=bm.NonBatchingMode()) output = layer(input) @@ -66,7 +66,7 @@ def test_GRUCell_NonBatching(self): def test_LSTMCell(self, mode): bm.random.seed() input = bm.random.randn(50, 100) - layer = bp.dnn.LSTMCell(num_in=100, + layer = bp.dyn.LSTMCell(num_in=100, num_out=64, mode=mode) @@ -76,7 +76,7 @@ def test_LSTMCell(self, mode): def test_LSTMCell_NonBatching(self): bm.random.seed() input = bm.random.randn(10) - layer = bp.dnn.LSTMCell(num_in=10, + layer = bp.dyn.LSTMCell(num_in=10, num_out=5, mode=bm.NonBatchingMode()) output = layer(input) @@ -91,7 +91,7 @@ def test_LSTMCell_NonBatching(self): def test_Conv1dLSTMCell(self, mode): bm.random.seed() input = bm.random.randn(4, 100, 3) - layer = bp.dnn.Conv1dLSTMCell(input_shape=(100,), + layer = bp.dyn.Conv1dLSTMCell(input_shape=(100,), in_channels=3, out_channels=5, kernel_size=4, @@ -102,7 +102,7 @@ def test_Conv1dLSTMCell(self, mode): def 
test_Conv1dLSTMCell_NonBatching(self): bm.random.seed() input = bm.random.randn(10, 3) - layer = bp.dnn.Conv1dLSTMCell(input_shape=(10,), + layer = bp.dyn.Conv1dLSTMCell(input_shape=(10,), in_channels=3, out_channels=4, kernel_size=5, @@ -119,7 +119,7 @@ def test_Conv1dLSTMCell_NonBatching(self): def test_Conv2dLSTMCell(self, mode): bm.random.seed() input = bm.random.randn(4, 100, 100, 3) - layer = bp.dnn.Conv2dLSTMCell(input_shape=(100, 100), + layer = bp.dyn.Conv2dLSTMCell(input_shape=(100, 100), in_channels=3, out_channels=5, kernel_size=(4, 4), @@ -130,7 +130,7 @@ def test_Conv2dLSTMCell(self, mode): def test_Conv2dLSTMCell_NonBatching(self): bm.random.seed() input = bm.random.randn(10, 10, 3) - layer = bp.dnn.Conv2dLSTMCell(input_shape=(10, 10), + layer = bp.dyn.Conv2dLSTMCell(input_shape=(10, 10), in_channels=3, out_channels=4, kernel_size=5, @@ -147,7 +147,7 @@ def test_Conv2dLSTMCell_NonBatching(self): def test_Conv3dLSTMCell(self, mode): bm.random.seed() input = bm.random.randn(4, 100, 100, 100, 3) - layer = bp.dnn.Conv3dLSTMCell(input_shape=(100, 100, 100), + layer = bp.dyn.Conv3dLSTMCell(input_shape=(100, 100, 100), in_channels=3, out_channels=5, kernel_size=(4, 4, 4), @@ -158,7 +158,7 @@ def test_Conv3dLSTMCell(self, mode): def test_Conv3dLSTMCell_NonBatching(self): bm.random.seed() input = bm.random.randn(10, 10, 10, 3) - layer = bp.dnn.Conv3dLSTMCell(input_shape=(10, 10, 10), + layer = bp.dyn.Conv3dLSTMCell(input_shape=(10, 10, 10), in_channels=3, out_channels=4, kernel_size=5, diff --git a/brainpy/_src/tests/test_dyn_runner.py b/brainpy/_src/tests/test_dyn_runner.py index 0cc2bb90c..dd6865e64 100644 --- a/brainpy/_src/tests/test_dyn_runner.py +++ b/brainpy/_src/tests/test_dyn_runner.py @@ -13,7 +13,7 @@ def __init__(self): super(ExampleDS, self).__init__() self.i = bm.Variable(bm.zeros(1)) - def update(self, tdi): + def update(self): self.i += 1 ds = ExampleDS() @@ -26,8 +26,8 @@ def __init__(self): super(ExampleDS, self).__init__() self.i = bm.Variable(bm.zeros(1)) - def update(self, tdi): - self.i += 1 * tdi.dt + def update(self): + self.i += 1 * bp.share['dt'] runner = bp.DSRunner(ExampleDS(), dt=1., monitors=['i'], progress_bar=False) runner.run(100.) 
diff --git a/brainpy/_src/tests/test_pickle.py b/brainpy/_src/tests/test_pickle.py index 2ae6a1345..bc2c77f1c 100644 --- a/brainpy/_src/tests/test_pickle.py +++ b/brainpy/_src/tests/test_pickle.py @@ -13,8 +13,8 @@ def __init__(self, *args, **kwargs): self.pre = bp.neurons.LIF(10) self.post = bp.neurons.LIF(20) - self.syn = bp.TwoEndConn(self.pre, self.post, bp.conn.FixedProb(0.2)) - self.net = bp.Network(self.pre, self.post, self.syn) + self.syn = bp.synapses.TwoEndConn(self.pre, self.post, bp.conn.FixedProb(0.2)) + self.net = bp.DynSysGroup(self.pre, self.post, self.syn) def test_net(self): self.skipTest('Currently do not support') From e5a5830ff63b5f6a7926bc17ebea48b66a705ed1 Mon Sep 17 00:00:00 2001 From: chaoming Date: Mon, 30 Oct 2023 14:28:40 +0800 Subject: [PATCH 3/4] [math] the interface for operator registration --- brainpy/_src/math/__init__.py | 2 +- brainpy/_src/math/event/_csr_matvec.py | 4 +- brainpy/_src/math/event/_info_collection.py | 2 +- brainpy/_src/math/jitconn/_event_matvec.py | 2 +- brainpy/_src/math/jitconn/_matvec.py | 2 +- .../{op_registers => op_register}/__init__.py | 3 +- brainpy/_src/math/op_register/base.py | 208 ++++++++++++++++++ .../numba_approach/__init__.py | 87 -------- .../numba_approach/cpu_translation.py | 0 brainpy/_src/math/op_register/numba_based.py | 115 ++++++++++ brainpy/_src/math/op_register/taichi_based.py | 9 + .../tests/test_ei_net.py | 0 .../{op_registers => op_register}/utils.py | 0 brainpy/_src/math/sparse/_bsr_mm.py | 4 +- brainpy/_src/math/sparse/_bsr_mv.py | 4 +- brainpy/_src/math/sparse/_coo_mv.py | 2 +- brainpy/_src/math/sparse/_csr_mv.py | 4 +- brainpy/_src/math/sparse/_utils.py | 2 +- brainpy/math/op_register.py | 6 +- 19 files changed, 351 insertions(+), 105 deletions(-) rename brainpy/_src/math/{op_registers => op_register}/__init__.py (64%) create mode 100644 brainpy/_src/math/op_register/base.py rename brainpy/_src/math/{op_registers => op_register}/numba_approach/__init__.py (68%) rename brainpy/_src/math/{op_registers => op_register}/numba_approach/cpu_translation.py (100%) create mode 100644 brainpy/_src/math/op_register/numba_based.py create mode 100644 brainpy/_src/math/op_register/taichi_based.py rename brainpy/_src/math/{op_registers => op_register}/tests/test_ei_net.py (100%) rename brainpy/_src/math/{op_registers => op_register}/utils.py (100%) diff --git a/brainpy/_src/math/__init__.py b/brainpy/_src/math/__init__.py index 208f378e1..5158d8c1e 100644 --- a/brainpy/_src/math/__init__.py +++ b/brainpy/_src/math/__init__.py @@ -49,7 +49,7 @@ from . import random, linalg, fft # operators -from .op_registers import * +from .op_register import * from .pre_syn_post import * from .surrogate._compt import * from . 
import surrogate, event, sparse, jitconn diff --git a/brainpy/_src/math/event/_csr_matvec.py b/brainpy/_src/math/event/_csr_matvec.py index 377007847..a30421e4b 100644 --- a/brainpy/_src/math/event/_csr_matvec.py +++ b/brainpy/_src/math/event/_csr_matvec.py @@ -23,8 +23,8 @@ from jax.lib import xla_client from brainpy._src.math.interoperability import as_jax -from brainpy._src.math.op_registers import (compile_cpu_signature_with_numba, - register_general_batching) +from brainpy._src.math.op_register import (compile_cpu_signature_with_numba, + register_general_batching) from brainpy._src.math.sparse._csr_mv import csrmv as normal_csrmv from brainpy._src.math.sparse._utils import csr_to_coo from brainpy.errors import GPUOperatorNotFound diff --git a/brainpy/_src/math/event/_info_collection.py b/brainpy/_src/math/event/_info_collection.py index f355d3658..4f350e225 100644 --- a/brainpy/_src/math/event/_info_collection.py +++ b/brainpy/_src/math/event/_info_collection.py @@ -10,7 +10,7 @@ from jax.lib import xla_client from brainpy._src.math.interoperability import as_jax -from brainpy._src.math.op_registers import register_op_with_numba +from brainpy._src.math.op_register import register_op_with_numba from brainpy.errors import GPUOperatorNotFound from brainpy._src.math.ndarray import Array diff --git a/brainpy/_src/math/jitconn/_event_matvec.py b/brainpy/_src/math/jitconn/_event_matvec.py index af0e9dabe..e627c43a1 100644 --- a/brainpy/_src/math/jitconn/_event_matvec.py +++ b/brainpy/_src/math/jitconn/_event_matvec.py @@ -18,7 +18,7 @@ mv_prob_uniform, mv_prob_normal) from brainpy._src.math.ndarray import _get_dtype -from brainpy._src.math.op_registers import register_general_batching +from brainpy._src.math.op_register import register_general_batching from brainpy.errors import GPUOperatorNotFound try: diff --git a/brainpy/_src/math/jitconn/_matvec.py b/brainpy/_src/math/jitconn/_matvec.py index 336ee896c..714256e12 100644 --- a/brainpy/_src/math/jitconn/_matvec.py +++ b/brainpy/_src/math/jitconn/_matvec.py @@ -14,7 +14,7 @@ from brainpy._src.math.interoperability import as_jax from brainpy._src.math.ndarray import Array, _get_dtype -from brainpy._src.math.op_registers import register_general_batching +from brainpy._src.math.op_register import register_general_batching from brainpy.errors import GPUOperatorNotFound, MathError try: diff --git a/brainpy/_src/math/op_registers/__init__.py b/brainpy/_src/math/op_register/__init__.py similarity index 64% rename from brainpy/_src/math/op_registers/__init__.py rename to brainpy/_src/math/op_register/__init__.py index 3628c3279..4d5acf26a 100644 --- a/brainpy/_src/math/op_registers/__init__.py +++ b/brainpy/_src/math/op_register/__init__.py @@ -1,6 +1,5 @@ -from .numba_approach import (XLACustomOp, - CustomOpByNumba, +from .numba_approach import (CustomOpByNumba, register_op_with_numba, compile_cpu_signature_with_numba) from .utils import register_general_batching diff --git a/brainpy/_src/math/op_register/base.py b/brainpy/_src/math/op_register/base.py new file mode 100644 index 000000000..12871ad8e --- /dev/null +++ b/brainpy/_src/math/op_register/base.py @@ -0,0 +1,208 @@ +from functools import partial +from typing import Callable, Sequence, Tuple, Protocol, Optional + +import jax +import numpy as np +from jax.interpreters import xla, batching, ad, mlir +from numba.core.dispatcher import Dispatcher + +from brainpy._src.math.ndarray import Array +from brainpy._src.math.object_transform.base import BrainPyObject +from .numba_based import 
register_numba_cpu_translation_rule +from .taichi_based import (register_taichi_cpu_translation_rule, + register_taichi_gpu_translation_rule) +from .utils import register_general_batching + +__all__ = [ + 'XLACustomOp', +] + + +class ShapeDtype(Protocol): + + @property + def shape(self) -> Tuple[int, ...]: + ... + + @property + def dtype(self) -> np.dtype: + ... + + +class XLACustomOp(BrainPyObject): + """Creating a XLA custom call operator. + + >>> import numba as nb + >>> import taichi as ti + >>> import numpy as np + >>> import jax + >>> + >>> @nb.njit + >>> def numba_cpu_fun(a, b, out_a, out_b): + >>> out_a[:] = a + >>> out_b[:] = b + >>> + >>> @ti.kernel + >>> def taichi_gpu_fun(a, b, out_a, out_b): + >>> for i in range(a.size): + >>> out_a[i] = a[i] + >>> for i in range(b.size): + >>> out_b[i] = b[i] + >>> + >>> # option 1 + >>> prim = XLACustomOp(cpu_kernel=numba_cpu_fun, gpu_kernel=taichi_gpu_fun) + >>> a2, b2 = prim(np.random.random(1000), np.random.random(1000), + >>> outs=[jax.ShapeDtypeStruct(1000, dtype=np.float32), + >>> jax.ShapeDtypeStruct(1000, dtype=np.float32)]) + >>> + >>> # option 2 + >>> prim2 = XLACustomOp(cpu_kernel=numba_cpu_fun, gpu_kernel=taichi_gpu_fun, + >>> outs=[jax.ShapeDtypeStruct(1000, dtype=np.float32), + >>> jax.ShapeDtypeStruct(1000, dtype=np.float32)]) + >>> a3, b3 = prim2(np.random.random(1000), np.random.random(1000)) + + Args: + cpu_kernel: Callable. The function defines the computation on CPU backend. + gpu_kernel: Callable. The function defines the computation on GPU backend. + batching_translation: Callable. The batching translation rule of JAX. + jvp_translation: Callable. The JVP translation rule of JAX. + transpose_translation: Callable. The transpose translation rule of JAX. + outs: optional, sequence of `ShapeDtype`. The output information. + name: str. The primitive name. + """ + + def __init__( + self, + cpu_kernel: Callable = None, + gpu_kernel: Callable = None, + batching_translation: Callable = None, + jvp_translation: Callable = None, + transpose_translation: Callable = None, + outs: Optional[Sequence[ShapeDtype]] = None, + name: str = None, + ): + super().__init__(name) + + # primitive + self.primitive = jax.core.Primitive(self.name) + self.primitive.multiple_results = True + + # abstract evaluation + if outs is not None: + outs = tuple([_transform_to_shapedarray(o) for o in outs]) + self.outs = outs + self.primitive.def_abstract_eval(self._abstract_eval) + self.primitive.def_impl(partial(xla.apply_primitive, self.primitive)) + + # cpu function + if cpu_kernel is None: + pass + elif isinstance(cpu_kernel, Dispatcher): # numba + register_numba_cpu_translation_rule(self.primitive, cpu_kernel) + elif hasattr(cpu_kernel, '_is_wrapped_kernel') and cpu_kernel._is_wrapped_kernel: # taichi + register_taichi_cpu_translation_rule(self.primitive, cpu_kernel) + else: + raise ValueError(f'"cpu_kernel" must be a numba jitted function or a taichi kernel function. ' + f'But we got {cpu_kernel}') + + # gpu function + if gpu_kernel is None: + pass + elif hasattr(gpu_kernel, '_is_wrapped_kernel') and gpu_kernel._is_wrapped_kernel: # taichi + register_taichi_gpu_translation_rule(self.primitive, gpu_kernel) + else: + raise ValueError(f'"cpu_kernel" must be a taichi kernel function. 
' + f'But we got {gpu_kernel}') + + # batching rule + if batching_translation is None: + register_general_batching(self.primitive) + else: + batching.primitive_batchers[self.primitive] = batching_translation + + # jvp rule + if jvp_translation is not None: + ad.primitive_jvps[self.primitive] = jvp_translation + + # transpose rule + if transpose_translation is not None: + ad.primitive_transposes[self.primitive] = transpose_translation + + def _abstract_eval(self, *args, **kwargs): + if self.outs is None: + raise ValueError('"self.outs" must be defined, but got None.') + return self.outs + + def __call__(self, *ins, outs: Optional[Sequence[ShapeDtype]] = None): + if outs is not None: + self.outs = tuple([_transform_to_shapedarray(o) for o in outs]) + ins = jax.tree_util.tree_map(_transform_to_array, ins, is_leaf=_is_bp_array) + return self.primitive.bind(*ins) + + def def_abstract_eval(self, fun): + """Define the abstract evaluation function. + + Args: + fun: The abstract evaluation function. + """ + self.primitive.def_abstract_eval(fun) + + def def_batching_rule(self, fun): + """Define the batching rule. + + Args: + fun: The batching rule. + """ + batching.primitive_batchers[self.primitive] = fun + + def def_jvp_rule(self, fun): + """Define the JVP rule. + + Args: + fun: The JVP rule. + """ + ad.primitive_jvps[self.primitive] = fun + + def def_transpose_rule(self, fun): + """Define the transpose rule. + + Args: + fun: The transpose rule. + """ + ad.primitive_transposes[self.primitive] = fun + + def def_xla_translation(self, platform, fun): + """Define the XLA translation rule. + + Args: + platform: str. The computing platform. + fun: The XLA translation rule. + """ + xla.backend_specific_translations[platform][self.primitive] = fun + + def def_mlir_lowering(self, platform, fun): + """Define the MLIR lowering rule. + + Args: + platform: str. The computing platform. + fun: The lowering rule. + """ + mlir.register_lowering(self.primitive, fun, platform) + + +def _is_bp_array(a): + return isinstance(a, Array) + + +def _transform_to_array(a): + if isinstance(a, Array): + return a.value + elif isinstance(a, jax.Array): + return a + else: + return jax.numpy.asarray(a) + + +def _transform_to_shapedarray(a): + return jax.core.ShapedArray(a.shape, a.dtype) + diff --git a/brainpy/_src/math/op_registers/numba_approach/__init__.py b/brainpy/_src/math/op_register/numba_approach/__init__.py similarity index 68% rename from brainpy/_src/math/op_registers/numba_approach/__init__.py rename to brainpy/_src/math/op_register/numba_approach/__init__.py index ed960a738..76362215e 100644 --- a/brainpy/_src/math/op_registers/numba_approach/__init__.py +++ b/brainpy/_src/math/op_register/numba_approach/__init__.py @@ -17,7 +17,6 @@ __all__ = [ 'CustomOpByNumba', - 'XLACustomOp', 'register_op_with_numba', 'compile_cpu_signature_with_numba', ] @@ -84,92 +83,6 @@ def __call__(self, *args, **kwargs): return res -class XLACustomOp(BrainPyObject): - """Creating a XLA custom call operator. - - Parameters - ---------- - name: str - The name of operator. - eval_shape: callable - The function to evaluate the shape and dtype of the output according to the input. - This function should receive the abstract information of inputs, and return the - abstract information of the outputs. For example: - - >>> def eval_shape(inp1_info, inp2_info, inp3_info, ...): - >>> return out1_info, out2_info - con_compute: callable - The function to make the concrete computation. This function receives inputs, - and returns outputs. 
For example: - - >>> def con_compute(inp1, inp2, inp3, ...): - >>> return out1, out2 - cpu_func: callable - The function defines the computation on CPU backend. Same as ``con_compute``. - gpu_func: callable - The function defines the computation on GPU backend. Currently, this function is not supported. - apply_cpu_func_to_gpu: bool - Whether allows to apply CPU function on GPU backend. If True, the GPU data will move to CPU, - and after calculation, the returned outputs on CPU backend will move to GPU. - - .. deprecated:: 2.2.4.1 - No longer supported. - """ - - def __init__( - self, - eval_shape: Callable = None, - con_compute: Callable = None, - cpu_func: Callable = None, - gpu_func: Callable = None, - apply_cpu_func_to_gpu: bool = None, - name: str = None, - batching_translation: Callable = None, - jvp_translation: Callable = None, - transpose_translation: Callable = None, - multiple_results: bool = True, - ): - super(XLACustomOp, self).__init__(name=name) - - if apply_cpu_func_to_gpu is not None: - warnings.warn('"apply_cpu_func_to_gpu" has been removed.', UserWarning) - - # abstract evaluation function - if eval_shape is None: - raise ValueError('Must provide "eval_shape" for abstract evaluation.') - - # cpu function - if con_compute is None: - if cpu_func is None: - raise ValueError('Must provide one of "cpu_func" or "con_compute".') - else: - cpu_func = con_compute - - # gpu function - if gpu_func is None: - gpu_func = None - - # register OP - self.op = register_op_with_numba( - self.name, - cpu_func=cpu_func, - gpu_func_translation=gpu_func, - out_shapes=eval_shape, - batching_translation=batching_translation, - jvp_translation=jvp_translation, - transpose_translation=transpose_translation, - multiple_results=multiple_results, - ) - - def __call__(self, *args, **kwargs): - args = tree_map(lambda a: a.value if isinstance(a, Array) else a, - args, is_leaf=lambda a: isinstance(a, Array)) - kwargs = tree_map(lambda a: a.value if isinstance(a, Array) else a, - kwargs, is_leaf=lambda a: isinstance(a, Array)) - res = self.op.bind(*args, **kwargs) - return res - - def register_op_with_numba( op_name: str, cpu_func: Callable, diff --git a/brainpy/_src/math/op_registers/numba_approach/cpu_translation.py b/brainpy/_src/math/op_register/numba_approach/cpu_translation.py similarity index 100% rename from brainpy/_src/math/op_registers/numba_approach/cpu_translation.py rename to brainpy/_src/math/op_register/numba_approach/cpu_translation.py diff --git a/brainpy/_src/math/op_register/numba_based.py b/brainpy/_src/math/op_register/numba_based.py new file mode 100644 index 000000000..73e96f2b0 --- /dev/null +++ b/brainpy/_src/math/op_register/numba_based.py @@ -0,0 +1,115 @@ +# -*- coding: utf-8 -*- + +import ctypes +from functools import partial + +from jax.interpreters import xla +from jax.lib import xla_client +from numba import types, carray, cfunc + +__all__ = [ + 'register_numba_cpu_translation_rule', +] + +ctypes.pythonapi.PyCapsule_New.argtypes = [ + ctypes.c_void_p, # void* pointer + ctypes.c_char_p, # const char *name + ctypes.c_void_p, # PyCapsule_Destructor destructor +] +ctypes.pythonapi.PyCapsule_New.restype = ctypes.py_object + + +def _cpu_signature( + kernel, + input_dtypes, + input_shapes, + output_dtypes, + output_shapes, + debug: bool = False +): + # kernel_key = str(id(kernel)) + # input_keys = [f'{dtype}({shape})' for dtype, shape in zip(input_dtypes, input_shapes)] + # output_keys = [f'{dtype}({shape})' for dtype, shape in zip(output_dtypes, output_shapes)] + # key = 
f'{kernel_key}-ins=[{", ".join(input_keys)}]-outs=[{", ".join(output_keys)}]' + # if key not in __cache: + + code_scope = dict( + func_to_call=kernel, + input_shapes=input_shapes, + input_dtypes=input_dtypes, + output_shapes=output_shapes, + output_dtypes=output_dtypes, + carray=carray, + ) + + # inputs + args_in = [f'in{i} = carray(input_ptrs[{i}], input_shapes[{i}], dtype=input_dtypes[{i}])' + for i in range(len(input_shapes))] + args_out = [f'out{i} = carray(output_ptrs[{i}], output_shapes[{i}], dtype=output_dtypes[{i}])' + for i in range(len(output_shapes))] + args_call = [f'in{i}' for i in range(len(input_shapes))] + [f'out{i}' for i in range(len(output_shapes))] + + # function body + code_string = ''' +def xla_cpu_custom_call_target(output_ptrs, input_ptrs): + {args_in} + {args_out} + func_to_call({args_call}) + '''.format(args_in="\n ".join(args_in), + args_out="\n ".join(args_out), + args_call=", ".join(args_call)) + if debug: print(code_string) + exec(compile(code_string.strip(), '', 'exec'), code_scope) + + new_f = code_scope['xla_cpu_custom_call_target'] + xla_c_rule = cfunc(types.void(types.CPointer(types.voidptr), + types.CPointer(types.voidptr)))(new_f) + target_name = xla_c_rule.native_name.encode("ascii") + capsule = ctypes.pythonapi.PyCapsule_New( + xla_c_rule.address, # A CFFI pointer to a function + b"xla._CUSTOM_CALL_TARGET", # A binary string + None # PyCapsule object run at destruction + ) + xla_client.register_custom_call_target(target_name, capsule, "cpu") + + # else: + # target_name = __cache[key] + return target_name + + +def _numba_cpu_translation_rule(prim, kernel, debug: bool, c, *ins): + outs = prim.abstract_eval()[0] + + # output information + output_shapes = tuple(out.shape for out in outs) + output_dtypes = tuple(out.dtype for out in outs) + output_layouts = map(lambda shape: range(len(shape) - 1, -1, -1), output_shapes) + output_infos = [xla_client.Shape.array_shape(*arg) for arg in zip(output_dtypes, output_shapes, output_layouts)] + output_infos = xla_client.Shape.tuple_shape(output_infos) + + # input information + input_layouts = tuple(c.get_shape(arg) for arg in ins) + input_dtypes = tuple(inp.element_type() for inp in input_layouts) + input_shapes = tuple(inp.dimensions() for inp in input_layouts) + + # compiling + target_name = _cpu_signature(kernel, + input_dtypes, + input_shapes, + output_dtypes, + output_shapes, + debug=debug) + + # call + return xla_client.ops.CustomCallWithLayout( + c, + target_name, + operands=tuple(ins), + operand_shapes_with_layout=input_layouts, + shape_with_layout=output_infos, + ) + + +def register_numba_cpu_translation_rule(primitive, cpu_kernel, debug=False): + xla.backend_specific_translations['cpu'][primitive] = partial(_numba_cpu_translation_rule, + primitive, cpu_kernel, debug) diff --git a/brainpy/_src/math/op_register/taichi_based.py b/brainpy/_src/math/op_register/taichi_based.py new file mode 100644 index 000000000..c30d9f9b9 --- /dev/null +++ b/brainpy/_src/math/op_register/taichi_based.py @@ -0,0 +1,9 @@ + + +def register_taichi_cpu_translation_rule(primitive, cpu_kernel): + pass + + +def register_taichi_gpu_translation_rule(primitive, cpu_kernel): + pass + diff --git a/brainpy/_src/math/op_registers/tests/test_ei_net.py b/brainpy/_src/math/op_register/tests/test_ei_net.py similarity index 100% rename from brainpy/_src/math/op_registers/tests/test_ei_net.py rename to brainpy/_src/math/op_register/tests/test_ei_net.py diff --git a/brainpy/_src/math/op_registers/utils.py b/brainpy/_src/math/op_register/utils.py 
similarity index 100% rename from brainpy/_src/math/op_registers/utils.py rename to brainpy/_src/math/op_register/utils.py diff --git a/brainpy/_src/math/sparse/_bsr_mm.py b/brainpy/_src/math/sparse/_bsr_mm.py index fb1ce7039..42e885e6e 100644 --- a/brainpy/_src/math/sparse/_bsr_mm.py +++ b/brainpy/_src/math/sparse/_bsr_mm.py @@ -12,8 +12,8 @@ from jax.lib import xla_client from brainpy._src.math.interoperability import as_jax -from brainpy._src.math.op_registers import (compile_cpu_signature_with_numba, - register_general_batching) +from brainpy._src.math.op_register import (compile_cpu_signature_with_numba, + register_general_batching) from brainpy.errors import GPUOperatorNotFound try: diff --git a/brainpy/_src/math/sparse/_bsr_mv.py b/brainpy/_src/math/sparse/_bsr_mv.py index 331858c3b..7aa8f6e82 100644 --- a/brainpy/_src/math/sparse/_bsr_mv.py +++ b/brainpy/_src/math/sparse/_bsr_mv.py @@ -9,8 +9,8 @@ from jax.lib import xla_client from brainpy._src.math.interoperability import as_jax -from brainpy._src.math.op_registers import (compile_cpu_signature_with_numba, - register_general_batching) +from brainpy._src.math.op_register import (compile_cpu_signature_with_numba, + register_general_batching) from brainpy._src.math.sparse._utils import csr_to_coo from brainpy.errors import GPUOperatorNotFound diff --git a/brainpy/_src/math/sparse/_coo_mv.py b/brainpy/_src/math/sparse/_coo_mv.py index 85004c851..2885d9463 100644 --- a/brainpy/_src/math/sparse/_coo_mv.py +++ b/brainpy/_src/math/sparse/_coo_mv.py @@ -12,7 +12,7 @@ from brainpy._src.math.interoperability import as_jax from brainpy._src.math.ndarray import Array -from brainpy._src.math.op_registers import register_general_batching +from brainpy._src.math.op_register import register_general_batching __all__ = [ 'coomv', diff --git a/brainpy/_src/math/sparse/_csr_mv.py b/brainpy/_src/math/sparse/_csr_mv.py index 9a37f0902..fd09892c6 100644 --- a/brainpy/_src/math/sparse/_csr_mv.py +++ b/brainpy/_src/math/sparse/_csr_mv.py @@ -15,8 +15,8 @@ from brainpy._src.math.interoperability import as_jax from brainpy._src.math.ndarray import Array -from brainpy._src.math.op_registers import (compile_cpu_signature_with_numba, - register_general_batching) +from brainpy._src.math.op_register import (compile_cpu_signature_with_numba, + register_general_batching) from brainpy._src.math.sparse._utils import csr_to_coo from brainpy.errors import GPUOperatorNotFound diff --git a/brainpy/_src/math/sparse/_utils.py b/brainpy/_src/math/sparse/_utils.py index 68373cc03..a1dc9190e 100644 --- a/brainpy/_src/math/sparse/_utils.py +++ b/brainpy/_src/math/sparse/_utils.py @@ -10,7 +10,7 @@ from jaxlib import gpu_sparse from brainpy._src.math.interoperability import as_jax -from brainpy._src.math.op_registers import register_general_batching +from brainpy._src.math.op_register import register_general_batching __all__ = [ 'coo_to_csr', diff --git a/brainpy/math/op_register.py b/brainpy/math/op_register.py index 7fb7df73f..b30ce4414 100644 --- a/brainpy/math/op_register.py +++ b/brainpy/math/op_register.py @@ -1,11 +1,13 @@ # -*- coding: utf-8 -*- -from brainpy._src.math.op_registers import ( +from brainpy._src.math.op_register import ( CustomOpByNumba, - XLACustomOp, compile_cpu_signature_with_numba, ) +from brainpy._src.math.op_register.base import XLACustomOp + + From 1e857c702e0a9b00da720f67967b8c44c30b9171 Mon Sep 17 00:00:00 2001 From: chaoming Date: Mon, 30 Oct 2023 15:14:44 +0800 Subject: [PATCH 4/4] fix bugs --- README.md | 2 +- brainpy/_src/dnn/linear.py | 72 
+++++---- brainpy/_src/dyn/others/input.py | 8 +- brainpy/_src/dyn/projections/plasticity.py | 36 ++--- .../_src/dyn/projections/tests/test_STDP.py | 6 +- .../_src/dyn/rates/tests/test_reservoir.py | 2 +- .../math/op_register/tests/test_ei_net.py | 77 ---------- brainpy/_src/math/tests/test_op_register.py | 141 ------------------ brainpy/_src/mixin.py | 6 + examples/dynamics_simulation/stdp.py | 64 ++++++++ 10 files changed, 143 insertions(+), 271 deletions(-) delete mode 100644 brainpy/_src/math/op_register/tests/test_ei_net.py delete mode 100644 brainpy/_src/math/tests/test_op_register.py create mode 100644 examples/dynamics_simulation/stdp.py diff --git a/README.md b/README.md index fa553633f..716dbd900 100644 --- a/README.md +++ b/README.md @@ -77,7 +77,7 @@ We provide a Binder environment for BrainPy. You can use the following button to - **[BrainPy](https://github.com/brainpy/BrainPy)**: The solution for the general-purpose brain dynamics programming. - **[brainpy-examples](https://github.com/brainpy/examples)**: Comprehensive examples of BrainPy computation. - **[brainpy-datasets](https://github.com/brainpy/datasets)**: Neuromorphic and Cognitive Datasets for Brain Dynamics Modeling. -- [第一届神经计算建模与编程培训班 (BrainPy First Training Course on Neural Modeling and Programming)](https://github.com/brainpy/1st-neural-modeling-and-programming-course) +- [第一届神经计算建模与编程培训班 (First Training Course on Neural Modeling and Programming)](https://github.com/brainpy/1st-neural-modeling-and-programming-course) ## Citing diff --git a/brainpy/_src/dnn/linear.py b/brainpy/_src/dnn/linear.py index 2301bab7a..314ffb19c 100644 --- a/brainpy/_src/dnn/linear.py +++ b/brainpy/_src/dnn/linear.py @@ -3,6 +3,7 @@ from typing import Dict, Optional, Union, Callable +import numba import jax import numpy as np import jax.numpy as jnp @@ -227,6 +228,45 @@ def update(self, x): return x +def event_mm(pre_spike, post_inc, weight, w_min, w_max): + return weight + + +@numba.njit +def event_mm_imp(outs, ins): + pre_spike, post_inc, weight, w_min, w_max = ins + w_min = w_min[()] + w_max = w_max[()] + outs = outs + outs.fill(weight) + for i in range(pre_spike.shape[0]): + if pre_spike[i]: + outs[i] = np.clip(outs[i] + post_inc, w_min, w_max) + + +event_left_mm = bm.CustomOpByNumba(event_mm, event_mm_imp, multiple_results=False) + + +def event_mm2(post_spike, pre_inc, weight, w_min, w_max): + return weight + + +@numba.njit +def event_mm_imp2(outs, ins): + post_spike, pre_inc, weight, w_min, w_max = ins + w_min = w_min[()] + w_max = w_max[()] + outs = outs + outs.fill(weight) + for j in range(post_spike.shape[0]): + if post_spike[j]: + outs[:, j] = np.clip(outs[:, j] + pre_inc, w_min, w_max) + + +event_right_mm = bm.CustomOpByNumba(event_mm2, event_mm_imp2, multiple_results=False) + + + class AllToAll(Layer, SupportSTDP): """Synaptic matrix multiplication with All2All connections. 
@@ -289,20 +329,15 @@ def update(self, pre_val): post_val = pre_val @ self.weight return post_val - def update_STDP(self, dW, constraints=None): - if isinstance(self.weight, float): - raise ValueError(f'Cannot update the weight of a constant node.') - if not isinstance(dW, (bm.ndarray, jnp.ndarray, np.ndarray)): - raise ValueError(f'"delta_weight" must be a array, but got {type(dW)}') - if self.weight.shape != dW.shape: - raise ValueError(f'The shape of delta_weight {dW.shape} ' - f'should be the same as the shape of weight {self.weight.shape}.') + def stdp_update_on_pre(self, pre_spike, trace, w_min=None, w_max=None): if not isinstance(self.weight, bm.Variable): self.tracing_variable('weight', self.weight, self.weight.shape) - self.weight += dW - if constraints is not None: - self.weight.value = constraints(self.weight) + self.weight.value = event_left_mm(pre_spike, trace, self.weight, w_min, w_max) + def stdp_update_on_post(self, post_spike, trace, w_min=None, w_max=None): + if not isinstance(self.weight, bm.Variable): + self.tracing_variable('weight', self.weight, self.weight.shape) + self.weight.value = event_right_mm(post_spike, trace, self.weight, w_min, w_max) class OneToOne(Layer, SupportSTDP): @@ -338,21 +373,6 @@ def __init__( def update(self, pre_val): return pre_val * self.weight - def update_STDP(self, dW, constraints=None): - if isinstance(self.weight, float): - raise ValueError(f'Cannot update the weight of a constant node.') - if not isinstance(dW, (bm.ndarray, jnp.ndarray, np.ndarray)): - raise ValueError(f'"delta_weight" must be a array, but got {type(dW)}') - dW = dW.sum(axis=0) - if self.weight.shape != dW.shape: - raise ValueError(f'The shape of delta_weight {dW.shape} ' - f'should be the same as the shape of weight {self.weight.shape}.') - if not isinstance(self.weight, bm.Variable): - self.tracing_variable('weight', self.weight, self.weight.shape) - self.weight += dW - if constraints is not None: - self.weight.value = constraints(self.weight) - class MaskedLinear(Layer, SupportSTDP): r"""Synaptic matrix multiplication with masked dense computation. diff --git a/brainpy/_src/dyn/others/input.py b/brainpy/_src/dyn/others/input.py index 92a2390b4..60632dc9f 100644 --- a/brainpy/_src/dyn/others/input.py +++ b/brainpy/_src/dyn/others/input.py @@ -228,9 +228,5 @@ def update(self): def return_info(self): return self.spike - def reset_state(self, batch_size=None, **kwargs): - self.spike = variable_(partial(jnp.zeros, dtype=self.spk_type), - self.varshape, - batch_size, - axis_names=self.sharding, - batch_axis_name=bm.sharding.BATCH_AXIS) + def reset_state(self, batch_or_mode=None, **kwargs): + self.spike = self.init_variable(partial(jnp.zeros, dtype=self.spk_type), batch_or_mode) diff --git a/brainpy/_src/dyn/projections/plasticity.py b/brainpy/_src/dyn/projections/plasticity.py index 5894a1452..c51332e44 100644 --- a/brainpy/_src/dyn/projections/plasticity.py +++ b/brainpy/_src/dyn/projections/plasticity.py @@ -4,7 +4,6 @@ from brainpy._src.delay import register_delay_by_return from brainpy._src.dyn.synapses.abstract_models import Expon from brainpy._src.dynsys import DynamicalSystem, Projection -from brainpy._src.initialize import parameter from brainpy._src.mixin import (JointType, ParamDescriber, SupportAutoDelay, BindCondData, AlignPost, SupportSTDP) from brainpy.types import ArrayType @@ -111,7 +110,8 @@ def run(i, I_pre, I_post): A1: float. The increment of :math:`A_{pre}` produced by a spike. Must be a positive value. A2: float. 
The increment of :math:`A_{post}` produced by a spike. Must be a positive value. W_max: float. The maximum weight. - pre: DynamicalSystem. The pre-synaptic neuron group. + W_min: float. The minimum weight. + pre: DynamicalSystem. The pre-synaptic neuron group. delay: int, float. The pre spike delay length. (ms) syn: DynamicalSystem. The synapse model. comm: DynamicalSystem. The communication model, for example, dense or sparse connection layers. @@ -135,6 +135,7 @@ def __init__( A1: Union[float, ArrayType, Callable] = 0.96, A2: Union[float, ArrayType, Callable] = 0.53, W_max: Optional[float] = None, + W_min: Optional[float] = None, # others out_label: Optional[str] = None, name: Optional[str] = None, @@ -144,21 +145,21 @@ def __init__( # synaptic models check.is_instance(pre, JointType[DynamicalSystem, SupportAutoDelay]) - check.is_instance(syn, ParamDescriber[DynamicalSystem]) check.is_instance(comm, JointType[DynamicalSystem, SupportSTDP]) + check.is_instance(syn, ParamDescriber[DynamicalSystem]) check.is_instance(out, ParamDescriber[JointType[DynamicalSystem, BindCondData]]) check.is_instance(post, DynamicalSystem) self.pre_num = pre.num self.post_num = post.num self.comm = comm - self.syn = syn + self._is_align_post = issubclass(syn.cls, AlignPost) # delay initialization delay_cls = register_delay_by_return(pre) delay_cls.register_entry(self.name, delay) # synapse and output initialization - if issubclass(syn.cls, AlignPost): + if self._is_align_post: syn_cls, out_cls = align_post_add_bef_update(out_label, syn_desc=syn, out_desc=out, post=post, proj_name=self.name) else: @@ -171,24 +172,27 @@ def __init__( self.refs['delay'] = delay_cls self.refs['syn'] = syn_cls # invisible to ``self.node()`` self.refs['out'] = out_cls # invisible to ``self.node()`` + self.refs['comm'] = comm # tracing pre-synaptic spikes using Exponential model self.refs['pre_trace'] = _init_trace_by_align_pre2(pre, delay, Expon.desc(pre.num, tau=tau_s)) + # tracing post-synaptic spikes using Exponential model self.refs['post_trace'] = _init_trace_by_align_pre2(post, None, Expon.desc(post.num, tau=tau_t)) # synapse parameters self.W_max = W_max - self.tau_s = parameter(tau_s, sizes=self.pre_num) - self.tau_t = parameter(tau_t, sizes=self.post_num) - self.A1 = parameter(A1, sizes=self.pre_num) - self.A2 = parameter(A2, sizes=self.post_num) + self.W_min = W_min + self.tau_s = tau_s + self.tau_t = tau_t + self.A1 = A1 + self.A2 = A2 def update(self): # pre-synaptic spikes pre_spike = self.refs['delay'].at(self.name) # spike # pre-synaptic variables - if issubclass(self.syn.cls, AlignPost): + if self._is_align_post: # For AlignPost, we need "pre spikes @ comm matrix" for computing post-synaptic conductance x = pre_spike else: @@ -201,19 +205,17 @@ def update(self): post_spike = self.refs['post'].spike # weight updates - Apre = self.refs['pre_trace'].g Apost = self.refs['post_trace'].g - delta_w = - bm.outer(pre_spike, Apost * self.A2) + bm.outer(Apre * self.A1, post_spike) - self.comm.update_STDP(delta_w, constraints=self._weight_clip) + self.comm.stdp_update_on_pre(pre_spike, -Apost * self.A2, self.W_min, self.W_max) + Apre = self.refs['pre_trace'].g + self.comm.stdp_update_on_post(post_spike, Apre * self.A1, self.W_min, self.W_max) - # currents + # synaptic currents current = self.comm(x) - if issubclass(self.syn.cls, AlignPost): + if self._is_align_post: self.refs['syn'].add_current(current) # synapse post current else: self.refs['out'].bind_cond(current) # align pre return current - def _weight_clip(self, w): - 
return w if self.W_max is None else bm.minimum(w, self.W_max) diff --git a/brainpy/_src/dyn/projections/tests/test_STDP.py b/brainpy/_src/dyn/projections/tests/test_STDP.py index e33644f26..001afc02e 100644 --- a/brainpy/_src/dyn/projections/tests/test_STDP.py +++ b/brainpy/_src/dyn/projections/tests/test_STDP.py @@ -21,8 +21,8 @@ def __init__(self, num_pre, num_post): pre=self.pre, delay=1., # comm=bp.dnn.EventCSRLinear(bp.conn.FixedProb(1, pre=self.pre.num, post=self.post.num), - # weight=bp.init.Uniform(-0.1, 0.1)), - comm=bp.dnn.AllToAll(self.pre.num, self.post.num, weight=bp.init.Uniform(-0.1, 0.1)), + # weight=bp.init.Uniform(0., 0.1)), + comm=bp.dnn.AllToAll(self.pre.num, self.post.num, weight=bp.init.Uniform(.1, 0.1)), syn=bp.dyn.Expon.desc(self.post.varshape, tau=5.), out=bp.dyn.COBA.desc(E=0.), post=self.post, @@ -30,6 +30,8 @@ def __init__(self, num_pre, num_post): tau_t=33.7, A1=0.96, A2=0.53, + W_min=0., + W_max=1. ) def update(self, I_pre, I_post): diff --git a/brainpy/_src/dyn/rates/tests/test_reservoir.py b/brainpy/_src/dyn/rates/tests/test_reservoir.py index 371c7aa89..34d00c909 100644 --- a/brainpy/_src/dyn/rates/tests/test_reservoir.py +++ b/brainpy/_src/dyn/rates/tests/test_reservoir.py @@ -15,7 +15,7 @@ class Test_Reservoir(parameterized.TestCase): def test_Reservoir(self, mode): bm.random.seed() input = bm.random.randn(10, 3) - layer = bp.syn.Reservoir(input_shape=3, + layer = bp.dyn.Reservoir(input_shape=3, num_out=5, mode=mode) if mode in [bm.NonBatchingMode()]: diff --git a/brainpy/_src/math/op_register/tests/test_ei_net.py b/brainpy/_src/math/op_register/tests/test_ei_net.py deleted file mode 100644 index 28d106cb2..000000000 --- a/brainpy/_src/math/op_register/tests/test_ei_net.py +++ /dev/null @@ -1,77 +0,0 @@ -import brainpy.math as bm -import brainpy as bp -from jax.core import ShapedArray - - -def abs_eval(events, indices, indptr, *, weight, post_num): - return [ShapedArray((post_num,), bm.float32), ] - - -def con_compute(outs, ins): - post_val, = outs - post_val.fill(0) - events, indices, indptr, weight, _ = ins - weight = weight[()] - for i in range(events.size): - if events[i]: - for j in range(indptr[i], indptr[i + 1]): - index = indices[j] - post_val[index] += weight - - -event_sum = bm.XLACustomOp(eval_shape=abs_eval, cpu_func=con_compute) - - -class ExponentialV2(bp.synapses.TwoEndConn): - """Exponential synapse model using customized operator written in C++.""" - - def __init__(self, pre, post, conn, g_max=1., delay=0., tau=8.0, E=0.): - super(ExponentialV2, self).__init__(pre=pre, post=post, conn=conn) - self.check_pre_attrs('spike') - self.check_post_attrs('input', 'V') - - # parameters - self.E = E - self.tau = tau - self.delay = delay - self.g_max = g_max - self.pre2post = self.conn.require('pre2post') - - # variables - self.g = bm.Variable(bm.zeros(self.post.num)) - - # function - self.integral = bp.odeint(lambda g, t: -g / self.tau, method='exp_auto') - - def update(self): - self.g.value = self.integral(self.g, bp.share['t']) - self.g += event_sum(self.pre.spike, - self.pre2post[0], - self.pre2post[1], - weight=self.g_max, - post_num=self.post.num)[0] - self.post.input += self.g * (self.E - self.post.V) - - -class EINet(bp.DynSysGroup): - def __init__(self, scale): - super().__init__() - # neurons - pars = dict(V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., - V_initializer=bp.init.Normal(-55., 2.), method='exp_auto') - self.E = bp.neurons.LIF(int(3200 * scale), **pars) - self.I = bp.neurons.LIF(int(800 * scale), **pars) - - # 
synapses - self.E2E = ExponentialV2(self.E, self.E, bp.conn.FixedProb(prob=0.02), E=0., g_max=0.6 / scale, tau=5.) - self.E2I = ExponentialV2(self.E, self.I, bp.conn.FixedProb(prob=0.02), E=0., g_max=0.6 / scale, tau=5.) - self.I2E = ExponentialV2(self.I, self.E, bp.conn.FixedProb(prob=0.02), E=-80., g_max=6.7 / scale, tau=10.) - self.I2I = ExponentialV2(self.I, self.I, bp.conn.FixedProb(prob=0.02), E=-80., g_max=6.7 / scale, tau=10.) - - -def test1(): - bm.set_platform('cpu') - net2 = EINet(scale=0.1) - runner = bp.DSRunner(net2, inputs=[('E.input', 20.), ('I.input', 20.)]) - r = runner.predict(100., eval_time=True) - bm.clear_buffer_memory() diff --git a/brainpy/_src/math/tests/test_op_register.py b/brainpy/_src/math/tests/test_op_register.py deleted file mode 100644 index 6917202ad..000000000 --- a/brainpy/_src/math/tests/test_op_register.py +++ /dev/null @@ -1,141 +0,0 @@ -# -*- coding: utf-8 -*- - -import unittest - -import jax -import matplotlib.pyplot as plt - -import brainpy as bp -import brainpy.math as bm - - -bm.random.seed() -bm.set_platform('cpu') - - -def abs_eval(events, indices, indptr, post_val, values): - return [post_val] - - -def event_sum_op(outs, ins): - events, indices, indptr, post, values = ins - v = values[()] - outs, = outs - outs.fill(0) - for i in range(len(events)): - if events[i]: - for j in range(indptr[i], indptr[i + 1]): - index = indices[j] - outs[index] += v - - -event_sum2 = bm.XLACustomOp(name='event_sum2', cpu_func=event_sum_op, eval_shape=abs_eval) - - -class ExponentialSyn(bp.TwoEndConn): - def __init__(self, pre, post, conn, g_max=1., delay=0., tau=8.0, E=0., - method='exp_auto'): - super(ExponentialSyn, self).__init__(pre=pre, post=post, conn=conn) - self.check_pre_attrs('spike') - self.check_post_attrs('input', 'V') - - # parameters - self.E = E - self.tau = tau - self.delay = delay - self.g_max = g_max - self.pre2post = self.conn.require('pre2post') - - # variables - self.g = bm.Variable(bm.zeros(self.post.num)) - - # function - self.integral = bp.odeint(lambda g, t: -g / self.tau, method=method) - - def update(self, tdi): - self.g.value = self.integral(self.g, tdi['t'], dt=tdi['dt']) - self.g += bm.pre2post_event_sum(self.pre.spike, self.pre2post, self.post.num, self.g_max) - self.post.input += self.g * (self.E - self.post.V) - - -class ExponentialSyn3(bp.TwoEndConn): - def __init__(self, pre, post, conn, g_max=1., delay=0., tau=8.0, E=0., - method='exp_auto'): - super(ExponentialSyn3, self).__init__(pre=pre, post=post, conn=conn) - self.check_pre_attrs('spike') - self.check_post_attrs('input', 'V') - - # parameters - self.E = E - self.tau = tau - self.delay = delay - self.g_max = g_max - self.pre2post = self.conn.require('pre2post') - - # variables - self.g = bm.Variable(bm.zeros(self.post.num)) - - # function - self.integral = bp.odeint(lambda g, t: -g / self.tau, method=method) - - def update(self, tdi): - self.g.value = self.integral(self.g, tdi['t'], tdi['dt']) - # Customized operator - # ------------------------------------------------------------------------------------------------------------ - post_val = bm.zeros(self.post.num) - r = event_sum2(self.pre.spike, self.pre2post[0], self.pre2post[1], post_val, self.g_max) - self.g += r[0] - # ------------------------------------------------------------------------------------------------------------ - self.post.input += self.g * (self.E - self.post.V) - - -class EINet(bp.Network): - def __init__(self, syn_class, scale=1.0, method='exp_auto', ): - super(EINet, self).__init__() - - # network 
size - num_exc = int(3200 * scale) - num_inh = int(800 * scale) - - # neurons - pars = dict(V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5.) - self.E = bp.neurons.LIF(num_exc, **pars, method=method) - self.I = bp.neurons.LIF(num_inh, **pars, method=method) - self.E.V[:] = bm.random.randn(num_exc) * 2 - 55. - self.I.V[:] = bm.random.randn(num_inh) * 2 - 55. - - # synapses - we = 0.6 / scale # excitatory synaptic weight (voltage) - wi = 6.7 / scale # inhibitory synaptic weight - self.E2E = syn_class(self.E, self.E, bp.conn.FixedProb(0.02), E=0., g_max=we, tau=5., method=method) - self.E2I = syn_class(self.E, self.I, bp.conn.FixedProb(0.02), E=0., g_max=we, tau=5., method=method) - self.I2E = syn_class(self.I, self.E, bp.conn.FixedProb(0.02), E=-80., g_max=wi, tau=10., method=method) - self.I2I = syn_class(self.I, self.I, bp.conn.FixedProb(0.02), E=-80., g_max=wi, tau=10., method=method) - - -class TestOpRegister(unittest.TestCase): - def test_op(self): - bm.random.seed(123) - fig, gs = bp.visualize.get_figure(1, 2, 4, 5) - - net = EINet(ExponentialSyn, scale=0.1, method='euler') - runner = bp.DSRunner( - net, - inputs=[(net.E.input, 20.), (net.I.input, 20.)], - monitors={'E.spike': net.E.spike}, - ) - t, _ = runner.run(100., eval_time=True) - print(t) - ax = fig.add_subplot(gs[0, 0]) - bp.visualize.raster_plot(runner.mon.ts, runner.mon['E.spike'], ax=ax) - - net3 = EINet(ExponentialSyn3, scale=0.1, method='euler') - runner3 = bp.DSRunner( - net3, - inputs=[(net3.E.input, 20.), (net3.I.input, 20.)], - monitors={'E.spike': net3.E.spike}, - ) - t, _ = runner3.run(100., eval_time=True) - print(t) - plt.close() - bm.clear_buffer_memory() diff --git a/brainpy/_src/mixin.py b/brainpy/_src/mixin.py index 177b60aa6..f356f44b3 100644 --- a/brainpy/_src/mixin.py +++ b/brainpy/_src/mixin.py @@ -490,6 +490,12 @@ def update_STDP( ): raise NotImplementedError + def stdp_update_on_pre(self, pre_spike, trace, *args, **kwargs): + raise NotImplementedError + + def stdp_update_on_post(self, post_spike, trace, *args, **kwargs): + raise NotImplementedError + T = TypeVar('T') diff --git a/examples/dynamics_simulation/stdp.py b/examples/dynamics_simulation/stdp.py new file mode 100644 index 000000000..edaf90e44 --- /dev/null +++ b/examples/dynamics_simulation/stdp.py @@ -0,0 +1,64 @@ +""" +Reproduce the following STDP paper: + +- Song, S., Miller, K. & Abbott, L. Competitive Hebbian learning through spike-timing-dependent + synaptic plasticity. Nat Neurosci 3, 919–926 (2000). https://doi.org/10.1038/78829 +""" + +import matplotlib.pyplot as plt +import numpy as np + +import brainpy as bp +import brainpy.math as bm + + +class STDPNet(bp.DynSysGroup): + def __init__(self, num_poisson, num_lif=1, g_max=0.01): + super().__init__() + + self.g_max = g_max + + # neuron groups + self.noise = bp.dyn.PoissonGroup(num_poisson, freqs=15.) + self.group = bp.dyn.Lif(num_lif, V_reset=-60., V_rest=-74, V_th=-54, tau=10., + V_initializer=bp.init.Normal(-60., 1.)) + + # synapses + syn = bp.dyn.Expon.desc(num_lif, tau=5.) + out = bp.dyn.COBA.desc(E=0.) 
+ comm = bp.dnn.AllToAll(num_poisson, num_lif, bp.init.Uniform(0., g_max)) + self.syn = bp.dyn.STDP_Song2000(self.noise, None, syn, comm, out, self.group, + tau_s=20, tau_t=20, W_max=g_max, W_min=0., + A1=0.01 * g_max, A2=0.0105 * g_max) + + def update(self, *args, **kwargs): + self.noise() + self.syn() + self.group() + return self.syn.comm.weight.flatten()[:10] + + +def run_model(): + net = STDPNet(1000, 1) + indices = np.arange(int(100.0e3 / bm.dt)) # 100 s + ws = bm.for_loop(net.step_run, indices, progress_bar=True) + weight = bm.as_numpy(net.syn.comm.weight.flatten()) + + fig, gs = bp.visualize.get_figure(3, 1, 3, 10) + fig.add_subplot(gs[0, 0]) + plt.plot(weight / net.g_max, '.k') + plt.xlabel('Weight / gmax') + + fig.add_subplot(gs[1, 0]) + plt.hist(weight / net.g_max, 20) + plt.xlabel('Weight / gmax') + + fig.add_subplot(gs[2, 0]) + plt.plot(indices * bm.dt, bm.as_numpy(ws) / net.g_max) + plt.xlabel('Time (s)') + plt.ylabel('Weight / gmax') + plt.show() + + +if __name__ == '__main__': + run_model()
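Note on the plasticity interface introduced in PATCH 4: `STDP_Song2000` no longer builds a dense `delta_w` and calls `comm.update_STDP(dW, constraints=...)`; it now calls `comm.stdp_update_on_pre(pre_spike, -Apost * A2, W_min, W_max)` when pre-synaptic spikes arrive and `comm.stdp_update_on_post(post_spike, Apre * A1, W_min, W_max)` when post-synaptic spikes occur, leaving the communication layer responsible for applying and clipping the weight change. The sketch below shows a minimal dense communication layer implementing these two hooks with plain array operations instead of the `CustomOpByNumba` event kernels that `bp.dnn.AllToAll` uses in the patch. It is illustrative only: the class name `DenseSTDPComm` is made up, the import path of `SupportSTDP` follows `brainpy/_src/mixin.py` as touched by this patch, and the `where`/`clip` formulation is an assumption about equivalent semantics, not code taken from the patch.

import brainpy as bp
import brainpy.math as bm
from brainpy._src.mixin import SupportSTDP  # path per this patch; a public alias may also exist


class DenseSTDPComm(bp.DynamicalSystem, SupportSTDP):
  """Dense (num_pre, num_post) connection exposing the new STDP hooks (illustrative sketch)."""

  def __init__(self, num_pre, num_post, weight_init=bp.init.Uniform(0., 0.01)):
    super().__init__()
    # weight_init is an Initializer instance; calling it with a shape materializes the matrix
    self.weight = bm.Variable(weight_init((num_pre, num_post)))

  def update(self, pre_val):
    # ordinary forward pass used by the projection: pre-synaptic activity @ weight matrix
    return pre_val @ self.weight

  def stdp_update_on_pre(self, pre_spike, trace, w_min=None, w_max=None):
    # called with trace = -A2 * Apost: depress the outgoing rows of pre neurons that just spiked
    w_min = -float('inf') if w_min is None else w_min
    w_max = float('inf') if w_max is None else w_max
    dw = bm.where(pre_spike[:, None], trace[None, :], 0.)
    self.weight.value = bm.clip(self.weight + dw, w_min, w_max)

  def stdp_update_on_post(self, post_spike, trace, w_min=None, w_max=None):
    # called with trace = A1 * Apre: potentiate the incoming columns of post neurons that just spiked
    w_min = -float('inf') if w_min is None else w_min
    w_max = float('inf') if w_max is None else w_max
    dw = bm.where(post_spike[None, :], trace[:, None], 0.)
    self.weight.value = bm.clip(self.weight + dw, w_min, w_max)

Passing `comm=DenseSTDPComm(pre.num, post.num)` to `bp.dyn.STDP_Song2000` would then exercise the same code path as the `bp.dnn.AllToAll` layer in the updated test, minus the event-driven Numba optimization.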