From 55c69bb3659f555263096b3a118fccce73600fd7 Mon Sep 17 00:00:00 2001 From: Chaoming Wang Date: Sat, 17 Feb 2024 20:32:11 +0800 Subject: [PATCH 01/16] try to remove hard dependency with taichi and numba --- brainpy/_src/math/defaults.py | 17 +- brainpy/_src/math/environment.py | 21 +- brainpy/_src/math/event/__init__.py | 1 - brainpy/_src/math/event/_csr_matvec.py | 1226 +++----- brainpy/_src/math/event/_info_collection.py | 198 -- .../tests/event_info_VS_jax_operators.py | 275 -- .../_src/math/event/tests/test_event_csrmv.py | 7 + .../math/event/tests/test_event_csrmv_old.py | 8 +- brainpy/_src/math/event/tests/test_info.py | 62 - .../_src/math/event/tests/test_info_gpu.py | 14 - brainpy/_src/math/index_tricks.py | 305 -- brainpy/_src/math/jitconn/_event_matvec.py | 2621 +++++++---------- brainpy/_src/math/jitconn/_matvec.py | 2065 ++++--------- .../math/jitconn/tests/test_event_matvec.py | 6 + .../_src/math/jitconn/tests/test_matvec.py | 5 + brainpy/_src/math/op_register/numba_based.py | 1 - brainpy/_src/math/sparse/__init__.py | 4 +- brainpy/_src/math/sparse/tests/test_csrmv.py | 6 +- .../_src/math/sparse/tests/test_csrmv_old.py | 6 +- brainpy/_src/math/tifunc.py | 538 ++-- brainpy/errors.py | 15 +- brainpy/math/event.py | 1 - brainpy/math/sparse.py | 1 - requirements.txt | 2 - setup.py | 11 +- 25 files changed, 2318 insertions(+), 5098 deletions(-) delete mode 100644 brainpy/_src/math/event/_info_collection.py delete mode 100644 brainpy/_src/math/event/tests/event_info_VS_jax_operators.py delete mode 100644 brainpy/_src/math/event/tests/test_info.py delete mode 100644 brainpy/_src/math/event/tests/test_info_gpu.py delete mode 100644 brainpy/_src/math/index_tricks.py diff --git a/brainpy/_src/math/defaults.py b/brainpy/_src/math/defaults.py index 19aca92cf..dae0f1bcd 100644 --- a/brainpy/_src/math/defaults.py +++ b/brainpy/_src/math/defaults.py @@ -24,15 +24,20 @@ # '''Default integer data type.''' int_ = jnp.int64 if config.read('jax_enable_x64') else jnp.int32 -# '''Default integer data type in Taichi.''' -ti_int = ti.int64 if config.read('jax_enable_x64') else ti.int32 - # '''Default float data type.''' float_ = jnp.float64 if config.read('jax_enable_x64') else jnp.float32 -# '''Default float data type in Taichi.''' -ti_float = ti.float64 if config.read('jax_enable_x64') else ti.float32 - # '''Default complex data type.''' complex_ = jnp.complex128 if config.read('jax_enable_x64') else jnp.complex64 + +if ti is not None: + # '''Default integer data type in Taichi.''' + ti_int = ti.int64 if config.read('jax_enable_x64') else ti.int32 + + # '''Default float data type in Taichi.''' + ti_float = ti.float64 if config.read('jax_enable_x64') else ti.float32 + +else: + ti_int = None + ti_float = None \ No newline at end of file diff --git a/brainpy/_src/math/environment.py b/brainpy/_src/math/environment.py index 1c8b98a3b..757c19b8d 100644 --- a/brainpy/_src/math/environment.py +++ b/brainpy/_src/math/environment.py @@ -416,13 +416,16 @@ def set_float(dtype: type): """ if dtype in [jnp.float16, 'float16', 'f16']: defaults.__dict__['float_'] = jnp.float16 - defaults.__dict__['ti_float'] = ti.float16 + if ti is not None: + defaults.__dict__['ti_float'] = ti.float16 elif dtype in [jnp.float32, 'float32', 'f32']: defaults.__dict__['float_'] = jnp.float32 - defaults.__dict__['ti_float'] = ti.float32 + if ti is not None: + defaults.__dict__['ti_float'] = ti.float32 elif dtype in [jnp.float64, 'float64', 'f64']: defaults.__dict__['float_'] = jnp.float64 - defaults.__dict__['ti_float'] = ti.float64 + if ti is not None: + defaults.__dict__['ti_float'] = ti.float64 else: raise NotImplementedError @@ -448,16 +451,20 @@ def set_int(dtype: type): """ if dtype in [jnp.int8, 'int8', 'i8']: defaults.__dict__['int_'] = jnp.int8 - defaults.__dict__['ti_int'] = ti.int8 + if ti is not None: + defaults.__dict__['ti_int'] = ti.int8 elif dtype in [jnp.int16, 'int16', 'i16']: defaults.__dict__['int_'] = jnp.int16 - defaults.__dict__['ti_int'] = ti.int16 + if ti is not None: + defaults.__dict__['ti_int'] = ti.int16 elif dtype in [jnp.int32, 'int32', 'i32']: defaults.__dict__['int_'] = jnp.int32 - defaults.__dict__['ti_int'] = ti.int32 + if ti is not None: + defaults.__dict__['ti_int'] = ti.int32 elif dtype in [jnp.int64, 'int64', 'i64']: defaults.__dict__['int_'] = jnp.int64 - defaults.__dict__['ti_int'] = ti.int64 + if ti is not None: + defaults.__dict__['ti_int'] = ti.int64 else: raise NotImplementedError diff --git a/brainpy/_src/math/event/__init__.py b/brainpy/_src/math/event/__init__.py index 631129558..e61dc10cf 100644 --- a/brainpy/_src/math/event/__init__.py +++ b/brainpy/_src/math/event/__init__.py @@ -1,4 +1,3 @@ -from ._info_collection import * from ._csr_matvec import * diff --git a/brainpy/_src/math/event/_csr_matvec.py b/brainpy/_src/math/event/_csr_matvec.py index 6e03be463..f4f23fa93 100644 --- a/brainpy/_src/math/event/_csr_matvec.py +++ b/brainpy/_src/math/event/_csr_matvec.py @@ -10,27 +10,19 @@ """ -from functools import partial from typing import Union, Tuple import jax import jax.numpy as jnp -import numba import numpy as np -from jax.core import ShapedArray, Primitive -from jax.interpreters import ad, xla -from jax.lib import xla_client +from jax.interpreters import ad -from brainpy._src.dependency_check import (import_brainpylib_gpu_ops) from brainpy._src.dependency_check import import_taichi from brainpy._src.math.interoperability import as_jax -from brainpy._src.math.op_register import (compile_cpu_signature_with_numba, - register_general_batching, - XLACustomOp) -from brainpy._src.math.sparse._csr_mv import csrmv_brainpylib as normal_csrmv +from brainpy._src.math.op_register import XLACustomOp from brainpy._src.math.sparse._csr_mv import raw_csrmv_taichi as normal_csrmv_taichi from brainpy._src.math.sparse._utils import csr_to_coo -from brainpy.errors import GPUOperatorNotFound +from brainpy.errors import PackageMissingError __all__ = [ 'csrmv' @@ -81,535 +73,6 @@ def csrmv( return csrmv_taichi(data, indices, indptr, events, shape=shape, transpose=transpose) -### BRAINPYLIB ### - -def csrmv_brainpylib( - data: Union[float, jax.Array], - indices: jax.Array, - indptr: jax.Array, - events: jax.Array, - *, - shape: Tuple[int, int], - transpose: bool = False -) -> jax.Array: - """Product of a sparse CSR matrix and a dense event vector. - - This function supports JAX transformations, including `jit()`, `grad()`, - `vmap()` and `pmap()`. - - Parameters - ---------- - data: ndarray, float - An array of shape ``(nse,)``. - indices: ndarray - An array of shape ``(nse,)``. - indptr: ndarray - An array of shape ``(shape[0] + 1,)`` and dtype ``indices.dtype``. - events: ndarray - An array of shape ``(shape[0] if transpose else shape[1],)`` - and dtype ``data.dtype``. - shape: tuple - A length-2 tuple representing the matrix shape. - transpose: bool - A boolean specifying whether to transpose the sparse matrix - before computing. - If ``transpose=True``, the operator will compute based on the - event-driven property of the ``events`` vector. - - Returns - ------- - y : Array - The array of shape ``(shape[1] if transpose else shape[0],)`` representing - the matrix vector product. - """ - data = as_jax(data) - indices = as_jax(indices) - indptr = as_jax(indptr) - events = as_jax(events) - # checking - data = jnp.atleast_1d(data) - if np.ndim(data) == 1: - if data.shape[0] not in [1, indices.shape[0]]: - raise ValueError('The size of data should be 1 or be consistent with indices.' - f'But we got {data.shape} != {indices.shape}, {data.shape} != 1.') - else: - raise ValueError('data should be a scalar or 1D vector. ' - f'But we got {np.ndim(data)}-D array.') - if np.ndim(indices) != 1: - raise ValueError('indices should be a 1D vector with integer type.') - if np.ndim(indptr) != 1: - raise ValueError('indptr should be a 1D vector with integer type.') - if indices.dtype not in [jnp.int32, jnp.int64, jnp.uint32, jnp.uint64]: - raise ValueError('indices should be a 1D vector with int32 or int64 type.') - if indptr.dtype not in [jnp.int32, jnp.int64, jnp.uint32, jnp.uint64]: - raise ValueError('indptr should be a 1D vector with int32 or int64 type.') - if np.ndim(events) != 1: - raise ValueError('events should be a 1D vector.') - if len(shape) != 2: - raise ValueError('shape should be a length-2 tuple.') - if transpose: - if events.shape[0] != shape[0]: - raise ValueError(f'Shape mismatch, vec ({events.shape[0]},) @ mat {shape}.') - else: - if events.shape[0] != shape[1]: - raise ValueError(f'Shape mismatch, mat {shape} @ vec ({events.shape[0]},).') - - # computing - return event_csr_matvec_p.bind(data, indices, indptr, events, shape=shape, transpose=transpose) - - -# ---------------------------------------------------------- -# event csr matvec -# ---------------------------------------------------------- - -# operator for `event_csr_matvec` batching rule -# -------- - -def _batch_event_csr_matvec_abstract( - values, indices, indptr, events, *, batch_size, shape, transpose=False -): - return ShapedArray(dtype=values.dtype, shape=(batch_size, shape[1] if transpose else shape[0])) - - -@numba.njit(fastmath=True, parallel=True, nogil=True) -def _batch_event_csr_matvec_transpose_numba_imp(outs, ins): - res_val = outs - res_val.fill(0) - values, indices, indptr, events, batch_size, shape, _ = ins - batch_size = batch_size[()] - event_batch_dim = events.shape[0] - indices_batch_dim = indices.shape[0] - indptr_batch_dim = indptr.shape[0] - values_batch_dim = values.shape[0] - - if values.shape[1] == 1: # homogeneous value - for bi in numba.prange(batch_size): - event_bi = bi % event_batch_dim - indptr_bi = bi % indptr_batch_dim - indices_bi = bi % indices_batch_dim - values_bi = bi % values_batch_dim - for row_i in range(shape[0]): - if events[event_bi, row_i]: - value = values[values_bi, 0] - for j in range(indptr[indptr_bi, row_i], indptr[indptr_bi, row_i + 1]): - col_i = indices[indices_bi, j] - res_val[bi, col_i] += value - - else: # heterogeneous values - for bi in numba.prange(batch_size): - event_bi = bi % event_batch_dim - indptr_bi = bi % indptr_batch_dim - indices_bi = bi % indices_batch_dim - value_bi = bi % values_batch_dim - for row_i in range(shape[0]): - if events[event_bi, row_i]: - for j in range(indptr[indptr_bi, row_i], indptr[indptr_bi, row_i + 1]): - col_i = indices[indices_bi, j] - res_val[bi, col_i] += values[value_bi, j] - - -@numba.njit(fastmath=True, parallel=True, nogil=True) -def _batch_event_csr_matvec_numba_imp(outs, ins): - res_val = outs - res_val.fill(0) - values, indices, indptr, events, batch_size, shape, transpose = ins - batch_size = batch_size[()] - event_batch_dim = events.shape[0] - indices_batch_dim = indices.shape[0] - indptr_batch_dim = indptr.shape[0] - values_batch_dim = values.shape[0] - - if values.shape[1] == 1: # homogeneous value - for bi in numba.prange(batch_size): - event_bi = bi % event_batch_dim - indptr_bi = bi % indptr_batch_dim - indices_bi = bi % indices_batch_dim - value_bi = bi % values_batch_dim - value = values[value_bi, 0] - for row_i in numba.prange(shape[0]): - r = 0. - for j in range(indptr[indptr_bi, row_i], indptr[indptr_bi, row_i + 1]): - col_i = indices[indices_bi, j] - if events[event_bi, col_i]: - r += value - res_val[bi, row_i] = r - - else: # heterogeneous values - for bi in numba.prange(batch_size): - event_bi = bi % event_batch_dim - indptr_bi = bi % indptr_batch_dim - indices_bi = bi % indices_batch_dim - value_bi = bi % values_batch_dim - for row_i in numba.prange(shape[0]): - r = 0. - for j in range(indptr[indptr_bi, row_i], indptr[indptr_bi, row_i + 1]): - col_i = indices[indices_bi, j] - if events[event_bi, col_i]: - r += values[value_bi, j] - res_val[bi, row_i] = r - - -def _batch_event_csr_matvec_cpu_translation(c, values, indices, indptr, events, *, - batch_size, shape, transpose): - inputs = (values, indices, indptr, events) - description = dict(batch_size=batch_size, shape=shape, transpose=transpose) - if transpose: - name, inputs, in_layouts, out_layouts = compile_cpu_signature_with_numba( - c, - _batch_event_csr_matvec_transpose_numba_imp, - _batch_event_csr_matvec_abstract, - False, - inputs=inputs, - description=description - ) - else: - name, inputs, in_layouts, out_layouts = compile_cpu_signature_with_numba( - c, - _batch_event_csr_matvec_numba_imp, - _batch_event_csr_matvec_abstract, - False, - inputs=inputs, - description=description - ) - return xla_client.ops.CustomCallWithLayout( - c, - name, - operands=inputs, - operand_shapes_with_layout=in_layouts, - shape_with_layout=out_layouts, - ) - - -def _batch_event_csr_matvec_gpu_translation(c, values, indices, indptr, events, *, - batch_size, shape, transpose): - pass - - -def _batch_event_csr_matvec_jvp_values(values_dot, values, indices, indptr, events, *, - batch_size, shape, transpose): - return event_csr_matvec_batching_p.bind(values_dot, indices, indptr, events, - batch_size=batch_size, shape=shape, transpose=transpose) - - -def _batch_csr_matvec(values, indices, indptr, vectors, *, shape, transpose): - f = jax.vmap(partial(normal_csrmv, shape=shape, transpose=transpose), - in_axes=(0 if values.shape[0] > 1 else None, - 0 if indices.shape[0] > 1 else None, - 0 if indptr.shape[0] > 1 else None, - 0 if vectors.shape[0] > 1 else None)) - return f(values if values.shape[0] > 1 else values[0], - indices if indices.shape[0] > 1 else indices[0], - indptr if indptr.shape[0] > 1 else indptr[0], - vectors if vectors.shape[0] > 1 else vectors[0]) - - -def _batch_event_csr_matvec_jvp_events(events_dot, values, indices, indptr, events, *, - batch_size, shape, transpose): - return _batch_csr_matvec(values, indices, indptr, events_dot, - shape=shape, transpose=transpose) - - -def _batch_event_csr_matvec_transpose(ct, values, indices, indptr, events, *, - batch_size, shape, transpose): - if ad.is_undefined_primal(indices) or ad.is_undefined_primal(indptr): - raise ValueError("Cannot transpose with respect to sparse indices.") - - if ad.is_undefined_primal(events): - ct_events = ( - ad.Zero(events.aval) if type(ct) is ad.Zero else - _batch_csr_matvec(ct, indices, indptr, values, - shape=shape, transpose=not transpose) - ) - return values, indices, indptr, ct_events - else: - if values.aval.shape[1] == 1: # scalar - temp = event_csr_matvec_batching_p.bind(jnp.ones((1, 1)), indices, indptr, events, - batch_size=batch_size, shape=shape, - transpose=transpose) - ct_values = jax.vmap(jnp.inner)(ct, temp) - else: # heterogeneous values - if type(ct) is ad.Zero: - ct_values = ad.Zero(values.aval) - else: - - def _f(ct, indices, indptr, events, *, transpose): - row, col = csr_to_coo(indices, indptr) - ct_values = events[row] * ct[col] if transpose else events[col] * ct[row] - return ct_values - - f = jax.vmap(partial(_f, transpose=transpose), - in_axes=(0, - 0 if indices.shape[0] > 1 else None, - 0 if indptr.shape[0] > 1 else None, - 0 if events.shape[0] > 1 else None)) - ct_values = f(ct, - indices if indices.shape[0] > 1 else indices[0], - indptr if indptr.shape[0] > 1 else indptr[0], - events if events.shape[0] > 1 else events[0]) - return ct_values, indices, indptr, events - - -event_csr_matvec_batching_p = Primitive('event_csr_matvec_batching') -event_csr_matvec_batching_p.def_abstract_eval(_batch_event_csr_matvec_abstract) -event_csr_matvec_batching_p.def_impl(partial(xla.apply_primitive, event_csr_matvec_batching_p)) -# xla.backend_specific_translations['cpu'][event_csr_matvec_batching_p] = _batch_event_csr_matvec_cpu_translation -ad.defjvp(event_csr_matvec_batching_p, _batch_event_csr_matvec_jvp_values, - None, None, _batch_event_csr_matvec_jvp_events) -ad.primitive_transposes[event_csr_matvec_batching_p] = _batch_event_csr_matvec_transpose - - -# operator for `event_csr_matvec` # -# ------------------------------- # - - -def _event_csr_matvec_abstract(values, indices, indptr, events, *, shape, transpose=False): - return ShapedArray(dtype=values.dtype, shape=(shape[1] if transpose else shape[0],)) - - -@numba.njit(fastmath=True) -def _event_csr_matvec_transpose_numba_imp1_bool(outs, ins): - res_val = outs - res_val.fill(0) - values, indices, indptr, events, shape, _ = ins - if values.shape[0] > 1: # heter - for row_i, event in enumerate(events): - if event: - for j in range(indptr[row_i], indptr[row_i + 1]): - col_i = indices[j] - res_val[col_i] += values[j] - - else: # homo - values = values[0] - for row_i, event in enumerate(events): - if event: - for j in range(indptr[row_i], indptr[row_i + 1]): - col_i = indices[j] - res_val[col_i] += values - - -@numba.njit(fastmath=True) -def _event_csr_matvec_transpose_numba_imp2(outs, ins): - res_val = outs - res_val.fill(0) - values, indices, indptr, events, shape, _ = ins - if values.shape[0] > 1: # heter - for row_i, event in enumerate(events): - if event > 0.: - for j in range(indptr[row_i], indptr[row_i + 1]): - col_i = indices[j] - res_val[col_i] += values[j] - - else: # homo - values = values[0] - for row_i, event in enumerate(events): - if event > 0.: - for j in range(indptr[row_i], indptr[row_i + 1]): - col_i = indices[j] - res_val[col_i] += values - - -@numba.njit(fastmath=True, parallel=True, nogil=True) -def _event_csr_matvec_numba_imp1_bool(outs, ins): - res_val = outs - res_val.fill(0) - values, indices, indptr, events, shape, _ = ins - - if values.shape[0] > 1: # heter - for row_i in range(shape[0]): - r = 0. - for j in range(indptr[row_i], indptr[row_i + 1]): - col_i = indices[j] - if events[col_i]: - r += values[j] - res_val[row_i] = r - - else: # homo - values = values[0] - for row_i in numba.prange(shape[0]): - r = 0. - for j in range(indptr[row_i], indptr[row_i + 1]): - col_i = indices[j] - if events[col_i]: - r += values - res_val[row_i] = r - - -@numba.njit(fastmath=True, parallel=True, nogil=True) -def _event_csr_matvec_numba_imp2(outs, ins): - res_val = outs - res_val.fill(0) - values, indices, indptr, events, shape, _ = ins - - if values.shape[0] > 1: # heter - for row_i in range(shape[0]): - r = 0. - for j in range(indptr[row_i], indptr[row_i + 1]): - col_i = indices[j] - if events[col_i] > 0.: - r += values[j] - res_val[row_i] = r - - else: # homo - values = values[0] - for row_i in numba.prange(shape[0]): - r = 0. - for j in range(indptr[row_i], indptr[row_i + 1]): - col_i = indices[j] - if events[col_i] > 0.: - r += values - res_val[row_i] = r - - -def _event_csr_matvec_cpu_translation(c, values, indices, indptr, events, *, shape, transpose): - inputs = (values, indices, indptr, events) - event_type = c.get_shape(events) - description = dict(shape=shape, transpose=transpose) - if transpose: - if event_type.element_type() == jnp.bool_: - imp = _event_csr_matvec_transpose_numba_imp1_bool - else: - imp = _event_csr_matvec_transpose_numba_imp2 - name, inputs, in_layouts, out_layouts = compile_cpu_signature_with_numba( - c, - imp, - abs_eval_fn=_event_csr_matvec_abstract, - multiple_results=False, - inputs=inputs, - description=description - ) - else: - if event_type.element_type() == jnp.bool_: - imp = _event_csr_matvec_numba_imp1_bool - else: - imp = _event_csr_matvec_numba_imp2 - name, inputs, in_layouts, out_layouts = compile_cpu_signature_with_numba( - c, - imp, - abs_eval_fn=_event_csr_matvec_abstract, - multiple_results=False, - inputs=inputs, - description=description - ) - return xla_client.ops.CustomCallWithLayout( - c, name, - operands=inputs, - operand_shapes_with_layout=in_layouts, - shape_with_layout=out_layouts, - ) - - -def _event_csr_matvec_gpu_translation(c, data, indices, indptr, vector, *, shape, transpose): - gpu_ops = import_brainpylib_gpu_ops() - if gpu_ops is None: - raise GPUOperatorNotFound(event_csr_matvec_p.name) - - # shape checking - data_shape = c.get_shape(data) - indices_shape = c.get_shape(indices) - indptr_shape = c.get_shape(indptr) - vec_shape = c.get_shape(vector) - if data_shape.element_type() == jnp.float32: - ftype = b'_float' - elif data_shape.element_type() == jnp.float64: - ftype = b'_double' - else: - raise ValueError - assert indices_shape.element_type() == indptr_shape.element_type() - if indices_shape.element_type() == jnp.int32: - itype = b'_int' - elif indices_shape.element_type() == jnp.int64: - itype = b'_long' - else: - raise ValueError - data_name = b'_homo' if data_shape.dimensions() == (1,) else b'_heter' - tran_type = b'_transpose' if transpose else b'' - if vec_shape.element_type() == jnp.bool_: - vec_type = b'_bool' - else: - assert vec_shape.element_type() == data_shape.element_type() - vec_type = b'' - - # opaque - opaque = gpu_ops.build_double_size_descriptor(shape[0], shape[1]) - - # call - return xla_client.ops.CustomCallWithLayout( - c, - b'event_csrmv' + data_name + ftype + itype + vec_type + tran_type, - operands=(data, indices, indptr, vector), - operand_shapes_with_layout=(c.get_shape(data), - c.get_shape(indices), - c.get_shape(indptr), - c.get_shape(vector)), - shape_with_layout=xla_client.Shape.array_shape(data_shape.element_type(), - (shape[1] if transpose else shape[0],), - (0,)), - opaque=opaque, - ) - - -def _event_csr_matvec_batching_rule(args, axes, *, shape, transpose): - batch_size = 0 - args_processed = [] - for arg, axis in zip(args, axes): - if axis is None: - arg = jnp.expand_dims(jnp.atleast_1d(arg), 0) - else: - batch_size = arg.shape[axis] - if axis > 0: - arg = jnp.moveaxis(arg, axis, 0) - args_processed.append(arg) - - r = event_csr_matvec_batching_p.bind(*args_processed, - batch_size=batch_size, - shape=shape, - transpose=transpose) - return r, 0 - - -def _event_csr_matvec_jvp_values_brainpylib(values_dot, values, indices, indptr, events, *, shape, transpose): - return normal_csrmv(values_dot, indices, indptr, events, shape=shape, transpose=transpose) - - -def _event_csr_matvec_jvp_events_brainpylib(events_dot, values, indices, indptr, events, *, shape, transpose): - return normal_csrmv(values, indices, indptr, events_dot, shape=shape, transpose=transpose) - - -def _event_csr_matvec_transpose_brainpylib(ct, values, indices, indptr, events, *, shape, transpose): - if ad.is_undefined_primal(indices) or ad.is_undefined_primal(indptr): - raise ValueError("Cannot transpose with respect to sparse indices.") - if ad.is_undefined_primal(events): - ct_events = normal_csrmv(values, indices, indptr, ct, shape=shape, transpose=not transpose) - return values, indices, indptr, (ad.Zero(events) if type(ct) is ad.Zero else ct_events) - else: - if type(ct) is ad.Zero: - ct_values = ad.Zero(values) - else: - if values.aval.shape[0] == 1: # scalar - ct_values = csrmv_brainpylib(jnp.ones(1), indices, indptr, events, shape=shape, transpose=transpose) - ct_values = jnp.inner(ct, ct_values) - else: # heterogeneous values - row, col = csr_to_coo(indices, indptr) - ct_values = events[row] * ct[col] if transpose else events[col] * ct[row] - return ct_values, indices, indptr, events - - -event_csr_matvec_p = Primitive('event_csr_matvec') -event_csr_matvec_p.def_abstract_eval(_event_csr_matvec_abstract) -event_csr_matvec_p.def_impl(partial(xla.apply_primitive, event_csr_matvec_p)) -# xla.backend_specific_translations['cpu'][event_csr_matvec_p] = _event_csr_matvec_cpu_translation -# xla.backend_specific_translations['gpu'][event_csr_matvec_p] = _event_csr_matvec_gpu_translation -ad.defjvp(event_csr_matvec_p, _event_csr_matvec_jvp_values_brainpylib, None, None, - _event_csr_matvec_jvp_events_brainpylib) -ad.primitive_transposes[event_csr_matvec_p] = _event_csr_matvec_transpose_brainpylib -register_general_batching(event_csr_matvec_p) - - -# batching.primitive_batchers[event_csr_matvec_p] = _event_csr_matvec_batching_rule - - -### TAICHI ### - def csrmv_taichi( data: Union[float, jax.Array], indices: jax.Array, @@ -691,298 +154,6 @@ def csrmv_taichi( return raw_csrmv_taichi(data, indices, indptr, events, shape=shape, transpose=transpose)[0] -# ------------- -# CPU operators -# ------------- - -# 1. The benchmarking shows that the performance of the following transpose -# kernels is maximized when using serialized mode -# 2. Since our Taichi-JAX kernel does not support the non-differentiable/non-jittable -# arguments, we have to define each kernel separately when the -# non-differentiable/non-jittable arguments are different. - - -@ti.kernel -def _event_csr_matvec_transpose_bool_homo_cpu(values: ti.types.ndarray(ndim=1), - indices: ti.types.ndarray(ndim=1), - indptr: ti.types.ndarray(ndim=1), - events: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - value = values[0] - ti.loop_config(serialize=True) - for row_i in range(indptr.shape[0] - 1): - if events[row_i]: - for j in range(indptr[row_i], indptr[row_i + 1]): - out[indices[j]] += value - - -@ti.kernel -def _event_csr_matvec_transpose_bool_heter_cpu(values: ti.types.ndarray(ndim=1), - indices: ti.types.ndarray(ndim=1), - indptr: ti.types.ndarray(ndim=1), - events: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - ti.loop_config(serialize=True) - for row_i in range(indptr.shape[0] - 1): - if events[row_i]: - for j in range(indptr[row_i], indptr[row_i + 1]): - out[indices[j]] += values[j] - - -@ti.kernel -def _event_csr_matvec_transpose_homo_cpu(values: ti.types.ndarray(ndim=1), - indices: ti.types.ndarray(ndim=1), - indptr: ti.types.ndarray(ndim=1), - events: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - value = values[0] - ti.loop_config(serialize=True) - for row_i in range(indptr.shape[0] - 1): - if events[row_i] != 0.: - for j in range(indptr[row_i], indptr[row_i + 1]): - out[indices[j]] += value - - -@ti.kernel -def _event_csr_matvec_transpose_heter_cpu(values: ti.types.ndarray(ndim=1), - indices: ti.types.ndarray(ndim=1), - indptr: ti.types.ndarray(ndim=1), - events: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - ti.loop_config(serialize=True) - for row_i in range(indptr.shape[0] - 1): - if events[row_i] != 0.: - for j in range(indptr[row_i], indptr[row_i + 1]): - out[indices[j]] += values[j] - - -@ti.kernel -def _event_csr_matvec_bool_homo_cpu(values: ti.types.ndarray(ndim=1), - indices: ti.types.ndarray(ndim=1), - indptr: ti.types.ndarray(ndim=1), - events: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - value = values[0] - # ti.loop_config(serialize=True) - for row_i in range(indptr.shape[0] - 1): - r = 0. - for j in range(indptr[row_i], indptr[row_i + 1]): - if events[indices[j]]: - r += value - out[row_i] = r - - -@ti.kernel -def _event_csr_matvec_bool_heter_cpu(values: ti.types.ndarray(ndim=1), - indices: ti.types.ndarray(ndim=1), - indptr: ti.types.ndarray(ndim=1), - events: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - # ti.loop_config(serialize=True) - for row_i in range(indptr.shape[0] - 1): - r = 0. - for j in range(indptr[row_i], indptr[row_i + 1]): - if events[indices[j]]: - r += values[j] - out[row_i] = r - - -@ti.kernel -def _event_csr_matvec_homo_cpu(values: ti.types.ndarray(ndim=1), - indices: ti.types.ndarray(ndim=1), - indptr: ti.types.ndarray(ndim=1), - events: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - value = values[0] - # ti.loop_config(serialize=True) - for row_i in range(indptr.shape[0] - 1): - r = 0. - for j in range(indptr[row_i], indptr[row_i + 1]): - if events[indices[j]] != 0.: - r += value - out[row_i] = r - - -@ti.kernel -def _event_csr_matvec_heter_cpu(values: ti.types.ndarray(ndim=1), - indices: ti.types.ndarray(ndim=1), - indptr: ti.types.ndarray(ndim=1), - events: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - # ti.loop_config(serialize=True) - for row_i in range(indptr.shape[0] - 1): - r = 0. - for j in range(indptr[row_i], indptr[row_i + 1]): - if events[indices[j]] != 0.: - r += values[j] - out[row_i] = r - - -# ------------- -# GPU operators -# ------------- - -# 1. GPU kernels are different from the CPU ones, since the GPU kernels need -# to use warp-level parallelism to achieve the best performance. - - -@ti.kernel -def _event_csr_matvec_transpose_bool_homo_gpu(values: ti.types.ndarray(ndim=1), - indices: ti.types.ndarray(ndim=1), - indptr: ti.types.ndarray(ndim=1), - events: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - value = values[0] - for i in range((indptr.shape[0] - 1) * 32): - row_i = i >> 5 - index = i & 31 - if events[row_i]: - j = indptr[row_i] + index - end_index = indptr[row_i + 1] - while j < end_index: - out[indices[j]] += value - j += 32 - - -@ti.kernel -def _event_csr_matvec_transpose_homo_gpu(values: ti.types.ndarray(ndim=1), - indices: ti.types.ndarray(ndim=1), - indptr: ti.types.ndarray(ndim=1), - events: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - value = values[0] - for i in range((indptr.shape[0] - 1) * 32): - row_i = i >> 5 - index = i & 31 - if events[row_i] != 0.: - j = indptr[row_i] + index - end_index = indptr[row_i + 1] - while j < end_index: - out[indices[j]] += value - j += 32 - - -# TODO -# It is important to note that the following warp-based kernels -# should be improved, since the atomic_add for each thread is not -# very efficient. Instead, the warp-level reduction primitive -# should be used. -# see ``warp_reduce_sum()`` function in tifunc.py. -# However, currently Taichi does not support general warp-level primitives. - - -@ti.kernel -def _event_csr_matvec_bool_homo_gpu(values: ti.types.ndarray(ndim=1), - indices: ti.types.ndarray(ndim=1), - indptr: ti.types.ndarray(ndim=1), - events: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - value = values[0] - for i in range((indptr.shape[0] - 1) * 32): - row_i = i >> 5 - index = i & 31 - r = 0. - j = indptr[row_i] + index - end_index = indptr[row_i + 1] - while j < end_index: - if events[indices[j]]: - r += value - j += 32 - out[row_i] += r # TODO: warp-level primitive - - -@ti.kernel -def _event_csr_matvec_homo_gpu(values: ti.types.ndarray(ndim=1), - indices: ti.types.ndarray(ndim=1), - indptr: ti.types.ndarray(ndim=1), - events: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - value = values[0] - for i in range((indptr.shape[0] - 1) * 32): - row_i = i >> 5 - index = i & 31 - r = 0. - j = indptr[row_i] + index - end_index = indptr[row_i + 1] - while j < end_index: - if events[indices[j]] != 0.: - r += value - j += 32 - out[row_i] += r # TODO: warp-level primitive - - -@ti.kernel -def _event_csr_matvec_transpose_bool_heter_gpu(values: ti.types.ndarray(ndim=1), - indices: ti.types.ndarray(ndim=1), - indptr: ti.types.ndarray(ndim=1), - events: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - for i in range((indptr.shape[0] - 1) * 32): - row_i = i >> 5 - index = i & 31 - if events[row_i]: - j = indptr[row_i] + index - end_index = indptr[row_i + 1] - while j < end_index: - out[indices[j]] += values[j] - j += 32 - - -@ti.kernel -def _event_csr_matvec_transpose_heter_gpu(values: ti.types.ndarray(ndim=1), - indices: ti.types.ndarray(ndim=1), - indptr: ti.types.ndarray(ndim=1), - events: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - for i in range((indptr.shape[0] - 1) * 32): - row_i = i >> 5 - index = i & 31 - if events[row_i] != 0.: - j = indptr[row_i] + index - end_index = indptr[row_i + 1] - while j < end_index: - out[indices[j]] += values[j] - j += 32 - - -@ti.kernel -def _event_csr_matvec_bool_heter_gpu(values: ti.types.ndarray(ndim=1), - indices: ti.types.ndarray(ndim=1), - indptr: ti.types.ndarray(ndim=1), - events: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - for i in range((indptr.shape[0] - 1) * 32): - row_i = i >> 5 - index = i & 31 - r = 0. - j = indptr[row_i] + index - end_index = indptr[row_i + 1] - while j < end_index: - if events[indices[j]]: - r += values[j] - j += 32 - out[row_i] += r # TODO: warp-level primitive - - -@ti.kernel -def _event_csr_matvec_heter_gpu(values: ti.types.ndarray(ndim=1), - indices: ti.types.ndarray(ndim=1), - indptr: ti.types.ndarray(ndim=1), - events: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - for i in range((indptr.shape[0] - 1) * 32): - row_i = i >> 5 - index = i & 31 - r = 0. - j = indptr[row_i] + index - end_index = indptr[row_i + 1] - while j < end_index: - if events[indices[j]] != 0.: - r += values[j] - j += 32 - out[row_i] += r # TODO: warp-level primitive - - def raw_csrmv_taichi( data: Union[float, jax.Array], indices: jax.Array, @@ -992,6 +163,9 @@ def raw_csrmv_taichi( shape: Tuple[int, int], transpose: bool = False ): + if ti is None: + raise PackageMissingError(name='taichi==1.7.0', purpose='customized operators') + if transpose: if events.dtype == jnp.bool_: if data.shape[0] == 1: @@ -1025,65 +199,361 @@ def raw_csrmv_taichi( shape=shape) -def _event_csr_matvec_jvp_values_taichi(val_dot, values, indices, indptr, events, *, outs, transpose, shape): - return normal_csrmv_taichi(val_dot, indices, indptr, events, shape=shape, transpose=transpose) +if ti is not None: + + # ------------- + # CPU operators + # ------------- + + # 1. The benchmarking shows that the performance of the following transpose + # kernels is maximized when using serialized mode + # 2. Since our Taichi-JAX kernel does not support the non-differentiable/non-jittable + # arguments, we have to define each kernel separately when the + # non-differentiable/non-jittable arguments are different. + + @ti.kernel + def _event_csr_matvec_transpose_bool_homo_cpu(values: ti.types.ndarray(ndim=1), + indices: ti.types.ndarray(ndim=1), + indptr: ti.types.ndarray(ndim=1), + events: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1)): + value = values[0] + ti.loop_config(serialize=True) + for row_i in range(indptr.shape[0] - 1): + if events[row_i]: + for j in range(indptr[row_i], indptr[row_i + 1]): + out[indices[j]] += value + + + @ti.kernel + def _event_csr_matvec_transpose_bool_heter_cpu(values: ti.types.ndarray(ndim=1), + indices: ti.types.ndarray(ndim=1), + indptr: ti.types.ndarray(ndim=1), + events: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1)): + ti.loop_config(serialize=True) + for row_i in range(indptr.shape[0] - 1): + if events[row_i]: + for j in range(indptr[row_i], indptr[row_i + 1]): + out[indices[j]] += values[j] + + + @ti.kernel + def _event_csr_matvec_transpose_homo_cpu(values: ti.types.ndarray(ndim=1), + indices: ti.types.ndarray(ndim=1), + indptr: ti.types.ndarray(ndim=1), + events: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1)): + value = values[0] + ti.loop_config(serialize=True) + for row_i in range(indptr.shape[0] - 1): + if events[row_i] != 0.: + for j in range(indptr[row_i], indptr[row_i + 1]): + out[indices[j]] += value + + + @ti.kernel + def _event_csr_matvec_transpose_heter_cpu(values: ti.types.ndarray(ndim=1), + indices: ti.types.ndarray(ndim=1), + indptr: ti.types.ndarray(ndim=1), + events: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1)): + ti.loop_config(serialize=True) + for row_i in range(indptr.shape[0] - 1): + if events[row_i] != 0.: + for j in range(indptr[row_i], indptr[row_i + 1]): + out[indices[j]] += values[j] + + + @ti.kernel + def _event_csr_matvec_bool_homo_cpu(values: ti.types.ndarray(ndim=1), + indices: ti.types.ndarray(ndim=1), + indptr: ti.types.ndarray(ndim=1), + events: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1)): + value = values[0] + # ti.loop_config(serialize=True) + for row_i in range(indptr.shape[0] - 1): + r = 0. + for j in range(indptr[row_i], indptr[row_i + 1]): + if events[indices[j]]: + r += value + out[row_i] = r + + + @ti.kernel + def _event_csr_matvec_bool_heter_cpu(values: ti.types.ndarray(ndim=1), + indices: ti.types.ndarray(ndim=1), + indptr: ti.types.ndarray(ndim=1), + events: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1)): + # ti.loop_config(serialize=True) + for row_i in range(indptr.shape[0] - 1): + r = 0. + for j in range(indptr[row_i], indptr[row_i + 1]): + if events[indices[j]]: + r += values[j] + out[row_i] = r + + + @ti.kernel + def _event_csr_matvec_homo_cpu(values: ti.types.ndarray(ndim=1), + indices: ti.types.ndarray(ndim=1), + indptr: ti.types.ndarray(ndim=1), + events: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1)): + value = values[0] + # ti.loop_config(serialize=True) + for row_i in range(indptr.shape[0] - 1): + r = 0. + for j in range(indptr[row_i], indptr[row_i + 1]): + if events[indices[j]] != 0.: + r += value + out[row_i] = r + + + @ti.kernel + def _event_csr_matvec_heter_cpu(values: ti.types.ndarray(ndim=1), + indices: ti.types.ndarray(ndim=1), + indptr: ti.types.ndarray(ndim=1), + events: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1)): + # ti.loop_config(serialize=True) + for row_i in range(indptr.shape[0] - 1): + r = 0. + for j in range(indptr[row_i], indptr[row_i + 1]): + if events[indices[j]] != 0.: + r += values[j] + out[row_i] = r + + + # ------------- + # GPU operators + # ------------- + + # 1. GPU kernels are different from the CPU ones, since the GPU kernels need + # to use warp-level parallelism to achieve the best performance. + + @ti.kernel + def _event_csr_matvec_transpose_bool_homo_gpu(values: ti.types.ndarray(ndim=1), + indices: ti.types.ndarray(ndim=1), + indptr: ti.types.ndarray(ndim=1), + events: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1)): + value = values[0] + for i in range((indptr.shape[0] - 1) * 32): + row_i = i >> 5 + index = i & 31 + if events[row_i]: + j = indptr[row_i] + index + end_index = indptr[row_i + 1] + while j < end_index: + out[indices[j]] += value + j += 32 + + + @ti.kernel + def _event_csr_matvec_transpose_homo_gpu(values: ti.types.ndarray(ndim=1), + indices: ti.types.ndarray(ndim=1), + indptr: ti.types.ndarray(ndim=1), + events: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1)): + value = values[0] + for i in range((indptr.shape[0] - 1) * 32): + row_i = i >> 5 + index = i & 31 + if events[row_i] != 0.: + j = indptr[row_i] + index + end_index = indptr[row_i + 1] + while j < end_index: + out[indices[j]] += value + j += 32 + + + # TODO + # It is important to note that the following warp-based kernels + # should be improved, since the atomic_add for each thread is not + # very efficient. Instead, the warp-level reduction primitive + # should be used. + # see ``warp_reduce_sum()`` function in tifunc.py. + # However, currently Taichi does not support general warp-level primitives. + + @ti.kernel + def _event_csr_matvec_bool_homo_gpu(values: ti.types.ndarray(ndim=1), + indices: ti.types.ndarray(ndim=1), + indptr: ti.types.ndarray(ndim=1), + events: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1)): + value = values[0] + for i in range((indptr.shape[0] - 1) * 32): + row_i = i >> 5 + index = i & 31 + r = 0. + j = indptr[row_i] + index + end_index = indptr[row_i + 1] + while j < end_index: + if events[indices[j]]: + r += value + j += 32 + out[row_i] += r # TODO: warp-level primitive + + + @ti.kernel + def _event_csr_matvec_homo_gpu(values: ti.types.ndarray(ndim=1), + indices: ti.types.ndarray(ndim=1), + indptr: ti.types.ndarray(ndim=1), + events: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1)): + value = values[0] + for i in range((indptr.shape[0] - 1) * 32): + row_i = i >> 5 + index = i & 31 + r = 0. + j = indptr[row_i] + index + end_index = indptr[row_i + 1] + while j < end_index: + if events[indices[j]] != 0.: + r += value + j += 32 + out[row_i] += r # TODO: warp-level primitive + + + @ti.kernel + def _event_csr_matvec_transpose_bool_heter_gpu(values: ti.types.ndarray(ndim=1), + indices: ti.types.ndarray(ndim=1), + indptr: ti.types.ndarray(ndim=1), + events: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1)): + for i in range((indptr.shape[0] - 1) * 32): + row_i = i >> 5 + index = i & 31 + if events[row_i]: + j = indptr[row_i] + index + end_index = indptr[row_i + 1] + while j < end_index: + out[indices[j]] += values[j] + j += 32 + + + @ti.kernel + def _event_csr_matvec_transpose_heter_gpu(values: ti.types.ndarray(ndim=1), + indices: ti.types.ndarray(ndim=1), + indptr: ti.types.ndarray(ndim=1), + events: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1)): + for i in range((indptr.shape[0] - 1) * 32): + row_i = i >> 5 + index = i & 31 + if events[row_i] != 0.: + j = indptr[row_i] + index + end_index = indptr[row_i + 1] + while j < end_index: + out[indices[j]] += values[j] + j += 32 + + + @ti.kernel + def _event_csr_matvec_bool_heter_gpu(values: ti.types.ndarray(ndim=1), + indices: ti.types.ndarray(ndim=1), + indptr: ti.types.ndarray(ndim=1), + events: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1)): + for i in range((indptr.shape[0] - 1) * 32): + row_i = i >> 5 + index = i & 31 + r = 0. + j = indptr[row_i] + index + end_index = indptr[row_i + 1] + while j < end_index: + if events[indices[j]]: + r += values[j] + j += 32 + out[row_i] += r # TODO: warp-level primitive + + + @ti.kernel + def _event_csr_matvec_heter_gpu(values: ti.types.ndarray(ndim=1), + indices: ti.types.ndarray(ndim=1), + indptr: ti.types.ndarray(ndim=1), + events: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1)): + for i in range((indptr.shape[0] - 1) * 32): + row_i = i >> 5 + index = i & 31 + r = 0. + j = indptr[row_i] + index + end_index = indptr[row_i + 1] + while j < end_index: + if events[indices[j]] != 0.: + r += values[j] + j += 32 + out[row_i] += r # TODO: warp-level primitive -def _event_csr_matvec_jvp_events_taichi(evt_dot, values, indices, indptr, events, *, outs, transpose, shape): - return normal_csrmv_taichi(values, indices, indptr, evt_dot, shape=shape, transpose=transpose) + def _event_csr_matvec_jvp_values_taichi(val_dot, values, indices, indptr, events, *, outs, transpose, shape): + return normal_csrmv_taichi(val_dot, indices, indptr, events, shape=shape, transpose=transpose) -def _event_csr_matvec_transpose_taichi( - ct, values, indices, indptr, events, *, outs, transpose, shape -): - if ad.is_undefined_primal(indices) or ad.is_undefined_primal(indptr): - raise ValueError("Cannot transpose with respect to sparse indices.") - if ad.is_undefined_primal(events): - ct_events = normal_csrmv_taichi(values, indices, indptr, ct[0], shape=shape, transpose=transpose)[0] - return values, indices, indptr, (ad.Zero(events) if type(ct[0]) is ad.Zero else ct_events) - else: - if type(ct[0]) is ad.Zero: - ct_values = ad.Zero(values) + def _event_csr_matvec_jvp_events_taichi(evt_dot, values, indices, indptr, events, *, outs, transpose, shape): + return normal_csrmv_taichi(values, indices, indptr, evt_dot, shape=shape, transpose=transpose) + + + def _event_csr_matvec_transpose_taichi( + ct, values, indices, indptr, events, *, outs, transpose, shape + ): + if ad.is_undefined_primal(indices) or ad.is_undefined_primal(indptr): + raise ValueError("Cannot transpose with respect to sparse indices.") + if ad.is_undefined_primal(events): + ct_events = normal_csrmv_taichi(values, indices, indptr, ct[0], shape=shape, transpose=transpose)[0] + return values, indices, indptr, (ad.Zero(events) if type(ct[0]) is ad.Zero else ct_events) else: - if values.aval.shape[0] == 1: # scalar - ct_values = raw_csrmv_taichi(jnp.ones(1), indices, indptr, events, shape=shape, transpose=transpose)[0] - ct_values = jnp.inner(ct[0], ct_values) - else: # heterogeneous values - row, col = csr_to_coo(indices, indptr) - ct_values = events[row] * ct[0][col] if transpose else events[col] * ct[0][row] - return ct_values, indices, indptr, events + if type(ct[0]) is ad.Zero: + ct_values = ad.Zero(values) + else: + if values.aval.shape[0] == 1: # scalar + ct_values = raw_csrmv_taichi(jnp.ones(1), indices, indptr, events, shape=shape, transpose=transpose)[0] + ct_values = jnp.inner(ct[0], ct_values) + else: # heterogeneous values + row, col = csr_to_coo(indices, indptr) + ct_values = events[row] * ct[0][col] if transpose else events[col] * ct[0][row] + return ct_values, indices, indptr, events -def _define_op(cpu_kernel, gpu_kernel): - prim = XLACustomOp(cpu_kernel=cpu_kernel, gpu_kernel=gpu_kernel) - prim.defjvp(_event_csr_matvec_jvp_values_taichi, None, None, _event_csr_matvec_jvp_events_taichi) - prim.def_transpose_rule(_event_csr_matvec_transpose_taichi) - return prim + def _define_op(cpu_kernel, gpu_kernel): + prim = XLACustomOp(cpu_kernel=cpu_kernel, gpu_kernel=gpu_kernel) + prim.defjvp(_event_csr_matvec_jvp_values_taichi, None, None, _event_csr_matvec_jvp_events_taichi) + prim.def_transpose_rule(_event_csr_matvec_transpose_taichi) + return prim -# transpose bool homo -_event_csrmv_transpose_bool_homo_p = _define_op(_event_csr_matvec_transpose_bool_homo_cpu, - _event_csr_matvec_transpose_bool_homo_gpu) + # transpose bool homo + _event_csrmv_transpose_bool_homo_p = _define_op(_event_csr_matvec_transpose_bool_homo_cpu, + _event_csr_matvec_transpose_bool_homo_gpu) -# transpose homo -_event_csrmv_transpose_homo_p = _define_op(_event_csr_matvec_transpose_homo_cpu, _event_csr_matvec_transpose_homo_gpu) + # transpose homo + _event_csrmv_transpose_homo_p = _define_op(_event_csr_matvec_transpose_homo_cpu, + _event_csr_matvec_transpose_homo_gpu) -# not transpose bool homo -_event_csrmv_bool_homo_p = _define_op(_event_csr_matvec_bool_homo_cpu, _event_csr_matvec_bool_homo_gpu) + # not transpose bool homo + _event_csrmv_bool_homo_p = _define_op(_event_csr_matvec_bool_homo_cpu, + _event_csr_matvec_bool_homo_gpu) -# not transpose homo -_event_csrmv_homo_p = _define_op(_event_csr_matvec_homo_cpu, _event_csr_matvec_homo_gpu) + # not transpose homo + _event_csrmv_homo_p = _define_op(_event_csr_matvec_homo_cpu, + _event_csr_matvec_homo_gpu) -# transpose bool heter -_event_csrmv_transpose_bool_heter_p = _define_op(_event_csr_matvec_transpose_bool_heter_cpu, - _event_csr_matvec_transpose_bool_heter_gpu) + # transpose bool heter + _event_csrmv_transpose_bool_heter_p = _define_op(_event_csr_matvec_transpose_bool_heter_cpu, + _event_csr_matvec_transpose_bool_heter_gpu) -# transpose heter -_event_csrmv_transpose_heter_p = _define_op(_event_csr_matvec_transpose_heter_cpu, - _event_csr_matvec_transpose_heter_gpu) + # transpose heter + _event_csrmv_transpose_heter_p = _define_op(_event_csr_matvec_transpose_heter_cpu, + _event_csr_matvec_transpose_heter_gpu) -# not transpose bool heter -_event_csrmv_bool_heter_p = _define_op(_event_csr_matvec_bool_heter_cpu, _event_csr_matvec_bool_heter_gpu) + # not transpose bool heter + _event_csrmv_bool_heter_p = _define_op(_event_csr_matvec_bool_heter_cpu, + _event_csr_matvec_bool_heter_gpu) -# not transpose heter -_event_csrmv_heter_p = _define_op(_event_csr_matvec_heter_cpu, _event_csr_matvec_heter_gpu) + # not transpose heter + _event_csrmv_heter_p = _define_op(_event_csr_matvec_heter_cpu, + _event_csr_matvec_heter_gpu) diff --git a/brainpy/_src/math/event/_info_collection.py b/brainpy/_src/math/event/_info_collection.py deleted file mode 100644 index 7bb043e3e..000000000 --- a/brainpy/_src/math/event/_info_collection.py +++ /dev/null @@ -1,198 +0,0 @@ -# -*- coding: utf-8 -*- - -from typing import Tuple, Union - -import jax -import numba -from jax import dtypes, numpy as jnp -from jax.core import ShapedArray -from jax.lib import xla_client - -from brainpy._src.dependency_check import import_brainpylib_gpu_ops -from brainpy._src.dependency_check import import_taichi -from brainpy._src.math.interoperability import as_jax -from brainpy._src.math.ndarray import Array -from brainpy._src.math.op_register.base import XLACustomOp -from brainpy.errors import GPUOperatorNotFound - -ti = import_taichi() - -__all__ = [ - 'info' -] - - -def info(events: Union[Array, jax.Array]) -> Tuple[jax.Array, jax.Array]: - """Collect event information, including event indices, and event number. - - This function supports JAX transformations, including `jit()`, - `vmap()` and `pmap()`. - - Parameters - ---------- - events: jax.Array - The events. - - Returns - ------- - res: tuple - A tuple with two elements, denoting the event indices and the event number. - """ - events = as_jax(events) - if events.ndim != 1: - raise TypeError('Only support 1D boolean vector.') - return event_info_p(events) - - -def _batch_event_info_abstract(events): - assert events.ndim == 2 - # assert events.dtype == jnp.bool_ - event_ids = ShapedArray(dtype=dtypes.canonicalize_dtype(int), shape=events.shape) - event_num = ShapedArray(dtype=dtypes.canonicalize_dtype(int), shape=(events.shape[0],)) - return event_ids, event_num - - -@numba.njit(fastmath=True, parallel=True, nogil=True) -def _batch_event_info(outs, ins): - event_ids, event_num = outs - event_num.fill(0) - event_ids.fill(-1) - events = ins - for batch_idx in range(event_ids.shape[0]): - num = 0 - for i in range(event_ids.shape[1]): - if events[batch_idx, i]: - event_ids[batch_idx, num] = i - num += 1 - event_num[batch_idx] = num - - -@ti.kernel -def _batch_event_info_taichi(events: ti.types.ndarray(ndim=2), - event_ids: ti.types.ndarray(ndim=2), - event_num: ti.types.ndarray(ndim=1)): - for i, j in ti.grouped(ti.ndrange(event_ids.shape)): - event_ids[i, j] = -1 - for batch_idx in range(event_ids.shape[0]): - num = 0 - for i in range(event_ids.shape[1]): - if events[batch_idx, i]: - event_ids[batch_idx, num] = i - num += 1 - event_num[batch_idx] = num - - -def _batch_event_info_batching_rule(args, axes): - arg = jnp.moveaxis(args[0], axes[0], 0) - shape = arg.shape - arg = jnp.reshape(arg, (shape[0] * shape[1], shape[2])) - event_ids, event_num = batch_event_info_p(arg) - return ((jnp.reshape(event_ids, shape), jnp.reshape(event_num, shape[:2])), - (0, 0)) - - -def _event_info_gpu_translation(c, events): - gpu_ops = import_brainpylib_gpu_ops() - if gpu_ops is None: - raise GPUOperatorNotFound(event_info_p.name) - - e_shape = c.get_shape(events).dimensions() - e_type = c.get_shape(events).element_type() - if len(e_shape) == 1: - event_size = e_shape[0] - batch_size = 1 - event_ids_shape = xla_client.Shape.array_shape(dtypes.canonicalize_dtype(int), - (event_size,), - (0,)) - else: - batch_size, event_size = e_shape - event_ids_shape = xla_client.Shape.array_shape(dtypes.canonicalize_dtype(int), - (batch_size, event_size), - (1, 0)) - event_num_shape = xla_client.Shape.array_shape(dtypes.canonicalize_dtype(int), - (batch_size,), - (0,)) - opaque = gpu_ops.build_nonzero_descriptor(event_size, batch_size) - - if e_type == jnp.bool_: - type_name = b'_bool' - elif e_type == jnp.int32: - type_name = b'_int' - elif e_type == jnp.int64: - type_name = b'_long' - elif e_type == jnp.float32: - type_name = b'_float' - elif e_type == jnp.float64: - type_name = b'_double' - else: - raise ValueError - - return xla_client.ops.CustomCallWithLayout( - c, - b'nonzero' + type_name, - operands=(events,), - operand_shapes_with_layout=(c.get_shape(events),), - shape_with_layout=xla_client.Shape.tuple_shape((event_ids_shape, event_num_shape)), - opaque=opaque, - ) - - -batch_event_info_p = XLACustomOp( - name='batched_event_info', - cpu_kernel=_batch_event_info_taichi, - gpu_kernel=_batch_event_info_taichi, - outs=_batch_event_info_abstract, -) -batch_event_info_p.def_batching_rule(_batch_event_info_batching_rule) - - -def _event_info_abstract(events, **kwargs): - assert events.ndim == 1 - # assert events.dtype == jnp.bool_ - event_ids = ShapedArray(dtype=dtypes.canonicalize_dtype(int), shape=events.shape) - event_num = ShapedArray(dtype=dtypes.canonicalize_dtype(int), shape=(1,)) - return event_ids, event_num - - -# TODO: first parallel evaluate the sub-sections, then serially event the sub-results. -@numba.njit(fastmath=True) -def _event_info(outs, ins): - event_ids, event_num = outs - event_num.fill(0) - event_ids.fill(-1) - events = ins - num = 0 - for i in range(event_ids.shape[0]): - if events[i]: - event_ids[num] = i - num += 1 - event_num[0] = num - - -@ti.kernel -def _event_info_taichi(events: ti.types.ndarray(ndim=1), - event_ids: ti.types.ndarray(ndim=1), - event_num: ti.types.ndarray(ndim=1)): - for i in range(event_ids.shape[0]): - event_ids[i] = -1 - num = 0 - for i in range(event_ids.shape[0]): - if events[i]: - event_ids[num] = i - num += 1 - event_num[0] = num - - -def _event_info_batching_rule(args, axes): - arg = jnp.moveaxis(args[0], axes[0], 0) - return (batch_event_info_p(arg), (0, 0)) - - -event_info_p = XLACustomOp( - name='event_info', - cpu_kernel=_event_info_taichi, - gpu_kernel=_event_info_taichi, - outs=_event_info_abstract, - # gpu_func_translation=_event_info_gpu_translation, -) -event_info_p.def_batching_rule(_event_info_batching_rule) diff --git a/brainpy/_src/math/event/tests/event_info_VS_jax_operators.py b/brainpy/_src/math/event/tests/event_info_VS_jax_operators.py deleted file mode 100644 index 74cc6b7f9..000000000 --- a/brainpy/_src/math/event/tests/event_info_VS_jax_operators.py +++ /dev/null @@ -1,275 +0,0 @@ -from time import time - -from jax import jit, vmap, numpy as jnp - -import brainpy.math as bm - - -def compare_argsort_and_sum(platform='cpu'): - """ - CPU - --- - - shape = (100, 10000) - brainpylib 0.1872694492340088 s - JAX argsort + sum 5.297466516494751 s - - shape = (100, 100000) - brainpylib 2.333505153656006 s - JAX argsort + sum 65.20281910896301 s - - shape = (1000, 10000) - brainpylib 2.0739688873291016 s - JAX argsort + sum 53.70602822303772 s - - shape = (10000, 1000) - brainpylib 1.7262670993804932 s - JAX argsort + sum 43.92174816131592 s - - GPU - --- - shape = (100, 100000) - brainpylib 0.14670848846435547 s - JAX argsort + sum 1.001936435699463 s - - shape = (100, 1000000) - brainpylib 0.27660632133483887 s - JAX argsort + sum 16.390073776245117 s - - shape = (1000, 100000) - brainpylib 0.2619345188140869 s - JAX argsort + sum 9.715844869613647 s - - shape = (1000, 500000) - brainpylib 1.201209306716919 s - JAX argsort + sum 71.19761657714844 s - - """ - - bm.set_platform(platform) - - rng = bm.random.RandomState(123) - bp_event_info = jit(vmap(bm.event.info)) - jax_event_info = jit(vmap(lambda events: (jnp.argsort(events), jnp.sum(events)))) - - if platform == 'cpu': - all_shapes = [ - (100, 10000), - (100, 100000), - (1000, 10000), - (10000, 1000), - ] - else: - all_shapes = [ - (100, 100000), - (100, 1000000), - (1000, 100000), - (1000, 500000), - ] - - for shape in all_shapes: - print(f'shape = {shape}') - - events = rng.random(shape).value < 0.1 - event_ids1, event_num1 = bp_event_info(events) - event_ids2, event_num2 = jax_event_info(events) - assert jnp.allclose(event_num1, event_num2) - event_ids1.block_until_ready() - event_ids2.block_until_ready() - - t0 = time() - for _ in range(100): - a, b = bp_event_info(events) - r = a.block_until_ready() - print(f'brainpylib {time() - t0} s') - - t0 = time() - for _ in range(100): - a, b = jax_event_info(events) - r = a.block_until_ready() - print(f'JAX argsort + sum {time() - t0} s') - - print() - - -def compare_argsort(platform='cpu'): - """ - - CPU - --- - - shape = (100, 10000) - brainpylib 0.19738531112670898 s - JAX argsort 5.301469087600708 s - - shape = (100, 100000) - brainpylib 2.3321938514709473 s - JAX argsort 65.13460850715637 s - - shape = (1000, 10000) - brainpylib 2.0956876277923584 s - JAX argsort 53.863110065460205 s - - shape = (10000, 1000) - brainpylib 1.7127799987792969 s - JAX argsort 44.05547475814819 s - - GPU - --- - shape = (100, 100000) - brainpylib 0.1415419578552246 s - JAX argsort 0.9982438087463379 s - - shape = (100, 1000000) - brainpylib 0.3224947452545166 s - JAX argsort 16.504750967025757 s - - shape = (1000, 100000) - brainpylib 0.2781648635864258 s - JAX argsort 9.691488981246948 s - - shape = (1000, 500000) - brainpylib 1.2167487144470215 s - JAX argsort 71.68716263771057 s - - """ - - bm.set_platform(platform) - - rng = bm.random.RandomState(123) - bp_event_info = jit(vmap(bm.event.info)) - jax_event_info = jit(vmap(lambda events: jnp.argsort(events))) - - if platform == 'cpu': - all_shapes = [ - (100, 10000), - (100, 100000), - (1000, 10000), - (10000, 1000), - ] - else: - all_shapes = [ - (100, 100000), - (100, 1000000), - (1000, 100000), - (1000, 500000), - ] - - for shape in all_shapes: - print(f'shape = {shape}') - - events = rng.random(shape).value < 0.1 - event_ids1, event_num1 = bp_event_info(events) - event_ids1.block_until_ready() - event_ids2 = jax_event_info(events) - event_ids2.block_until_ready() - - t0 = time() - for _ in range(100): - a, b = bp_event_info(events) - r = a.block_until_ready() - print(f'brainpylib {time() - t0} s') - - t0 = time() - for _ in range(100): - a = jax_event_info(events) - r = a.block_until_ready() - print(f'JAX argsort {time() - t0} s') - - print() - - -def compare_where(platform='cpu'): - """ - - CPU - --- - - shape = (100, 10000) - brainpylib 0.20480966567993164 s - JAX where 0.7068588733673096 s - - shape = (100, 100000) - brainpylib 2.3373026847839355 s - JAX where 5.862265348434448 s - - shape = (1000, 10000) - brainpylib 2.105764865875244 s - JAX where 5.914586067199707 s - - shape = (10000, 1000) - brainpylib 1.724682331085205 s - JAX where 5.718563795089722 s - - GPU - --- - shape = (100, 100000) - brainpylib 0.15492558479309082 s - JAX where 0.3146538734436035 s - - shape = (100, 1000000) - brainpylib 0.3290700912475586 s - JAX where 1.7064015865325928 s - - shape = (1000, 100000) - brainpylib 0.2895216941833496 s - JAX where 1.6910102367401123 s - - shape = (1000, 500000) - brainpylib 1.173649787902832 s - JAX where 7.868000268936157 s - - """ - - bm.set_platform(platform) - - rng = bm.random.RandomState(123) - bp_event_info = jit(vmap(bm.event.info)) - jax_event_info = jit(vmap(lambda events: jnp.where(events, size=events.shape[0]))) - - if platform == 'cpu': - all_shapes = [ - (100, 10000), - (100, 100000), - (1000, 10000), - (10000, 1000), - ] - else: - all_shapes = [ - (100, 100000), - (100, 1000000), - (1000, 100000), - (1000, 500000), - ] - - for shape in all_shapes: - print(f'shape = {shape}') - - events = rng.random(shape).value < 0.1 - event_ids1, event_num1 = bp_event_info(events) - event_ids1.block_until_ready() - event_ids2, = jax_event_info(events) - event_ids2.block_until_ready() - - t0 = time() - for _ in range(100): - a, b = bp_event_info(events) - r = a.block_until_ready() - print(f'brainpylib {time() - t0} s') - - t0 = time() - for _ in range(100): - a, = jax_event_info(events) - r = a.block_until_ready() - print(f'JAX where {time() - t0} s') - - print() - - -if __name__ == '__main__': - # compare_argsort_and_sum('cpu') - # compare_argsort_and_sum('gpu') - # compare_argsort('cpu') - compare_argsort('gpu') - # compare_where('cpu') - # compare_where('gpu') diff --git a/brainpy/_src/math/event/tests/test_event_csrmv.py b/brainpy/_src/math/event/tests/test_event_csrmv.py index e0f38490f..1641c9db9 100644 --- a/brainpy/_src/math/event/tests/test_event_csrmv.py +++ b/brainpy/_src/math/event/tests/test_event_csrmv.py @@ -4,11 +4,18 @@ from functools import partial import jax +import pytest from absl.testing import parameterized import brainpy as bp import brainpy.math as bm +from brainpy._src.dependency_check import import_taichi + +if import_taichi() is None: + pytest.skip('no taichi', allow_module_level=True) + + seed = 1234 diff --git a/brainpy/_src/math/event/tests/test_event_csrmv_old.py b/brainpy/_src/math/event/tests/test_event_csrmv_old.py index 31a6527a2..fcb25a89c 100644 --- a/brainpy/_src/math/event/tests/test_event_csrmv_old.py +++ b/brainpy/_src/math/event/tests/test_event_csrmv_old.py @@ -4,19 +4,13 @@ from functools import partial import jax -from absl.testing import parameterized +import pytest import brainpy as bp import brainpy.math as bm -import platform -import pytest pytest.skip('Old implementation.', allow_module_level=True) -is_manual_test = False -# if platform.system() == 'Windows' and not is_manual_test: -# pytest.skip('brainpy.math package may need manual tests.', allow_module_level=True) - brainpylib_csr_matvec = partial(bm.event.csrmv, method='brainpylib') taichi_csr_matvec = partial(bm.event.csrmv, method='taichi') diff --git a/brainpy/_src/math/event/tests/test_info.py b/brainpy/_src/math/event/tests/test_info.py deleted file mode 100644 index c326b0f76..000000000 --- a/brainpy/_src/math/event/tests/test_info.py +++ /dev/null @@ -1,62 +0,0 @@ -# -*- coding: utf-8 -*- - -import jax.numpy as jnp -import unittest - -import brainpy.math as bm -from jax import vmap - -import pytest - - -class Test_event_info(unittest.TestCase): - def __init__(self, *args, platform='cpu', **kwargs): - super(Test_event_info, self).__init__(*args, **kwargs) - - print() - bm.set_platform(platform) - - def _base_test(self, length): - print(f'{self._base_test.__name__}: length = {length}') - - rng = bm.random.RandomState() - events = bm.as_jax(rng.random(length)) < 0.1 - event_ids, event_num = bm.event.info(events) - self.assertTrue(jnp.allclose(jnp.sum(events, keepdims=True), event_num)) - - bm.clear_buffer_memory() - - def _base_vmap(self, length): - print(f'{self._base_vmap.__name__}: length = {length}') - - rng = bm.random.RandomState() - events = bm.as_jax(rng.random((10, length))) < 0.1 - event_ids, event_num = vmap(bm.event.info)(events) - self.assertTrue(jnp.allclose(jnp.sum(events, axis=-1), event_num)) - - bm.clear_buffer_memory() - - def _base_vmap_vmap(self, length): - print(f'{self._base_vmap_vmap.__name__}: length = {length}') - - rng = bm.random.RandomState() - events = bm.as_jax(rng.random((10, length))) < 0.1 - event_ids, event_num = vmap(vmap(bm.event.info))(events) - self.assertTrue(jnp.allclose(jnp.sum(events, axis=-1), event_num)) - - bm.clear_buffer_memory() - - def test(self): - for length in [1, 3, 8, 10, 100, 200, 500, 1000, 10000, 100000]: - self._base_test(length) - - def test_vmap(self): - for length in [1, 3, 8, 10, 100, 200, 500, 1000, 10000, 100000]: - self._base_test(length) - - def test_vmap_vmap(self): - for length in [1, 3, 8, 10, 100, 200, 500, 1000, 10000, 100000]: - self._base_test(length) - - - diff --git a/brainpy/_src/math/event/tests/test_info_gpu.py b/brainpy/_src/math/event/tests/test_info_gpu.py deleted file mode 100644 index 55bdd15cd..000000000 --- a/brainpy/_src/math/event/tests/test_info_gpu.py +++ /dev/null @@ -1,14 +0,0 @@ -# -*- coding: utf-8 -*- - -import jax -import pytest - -import test_info - -if jax.default_backend() != 'gpu': - pytest.skip("No gpu available.", allow_module_level=True) - - -class Test_event_info_GPU(test_info.Test_event_info): - def __init__(self, *args, **kwargs): - super(Test_event_info_GPU, self).__init__(*args, **kwargs, platform='gpu') diff --git a/brainpy/_src/math/index_tricks.py b/brainpy/_src/math/index_tricks.py deleted file mode 100644 index 6c71b4b06..000000000 --- a/brainpy/_src/math/index_tricks.py +++ /dev/null @@ -1,305 +0,0 @@ -# -*- coding: utf-8 -*- - -import abc - -from jax import core -from .compat_numpy import arange, array, concatenate, expand_dims, linspace, meshgrid, stack, transpose -import numpy as np - -__all__ = ["c_", "index_exp", "mgrid", "ogrid", "r_", "s_"] - - -def _make_1d_grid_from_slice(s: slice, op_name: str): - start = core.concrete_or_error(None, s.start, - f"slice start of jnp.{op_name}") or 0 - stop = core.concrete_or_error(None, s.stop, - f"slice stop of jnp.{op_name}") - step = core.concrete_or_error(None, s.step, - f"slice step of jnp.{op_name}") or 1 - if np.iscomplex(step): - newobj = linspace(start, stop, int(abs(step))) - else: - newobj = arange(start, stop, step) - - return newobj - - -class _IndexGrid(abc.ABC): - """Creates multi-dimensional grids of indices.""" - sparse: bool - op_name: str - - def __getitem__(self, key): - if isinstance(key, slice): - return _make_1d_grid_from_slice(key, op_name=self.op_name) - output = (_make_1d_grid_from_slice(k, op_name=self.op_name) for k in key) - output = meshgrid(*output, indexing='ij', sparse=self.sparse) - return output if self.sparse else stack(output, 0) - - -class _Mgrid(_IndexGrid): - """Return dense multi-dimensional "meshgrid". - - LAX-backend implementation of :obj:`numpy.mgrid`. This is a convenience wrapper for - functionality provided by :func:`jax.numpy.meshgrid` with ``sparse=False``. - - See Also: - jnp.ogrid: open/sparse version of jnp.mgrid - - Examples: - Pass ``[start:stop:step]`` to generate values similar to :func:`jax.numpy.arange`: - - >>> import brainpy.math as bm - >>> bm.mgrid[0:4:1] - DeviceArray([0, 1, 2, 3], dtype=int32) - - Passing an imaginary step generates values similar to :func:`jax.numpy.linspace`: - - >>> bm.mgrid[0:1:4j] - DeviceArray([0. , 0.33333334, 0.6666667 , 1. ], dtype=float32) - - Multiple slices can be used to create broadcasted grids of indices: - - >>> bm.mgrid[:2, :3] - DeviceArray([[[0, 0, 0], - [1, 1, 1]], - [[0, 1, 2], - [0, 1, 2]]], dtype=int32) - """ - sparse = False - op_name = "mgrid" - - -mgrid = _Mgrid() - - -class _Ogrid(_IndexGrid): - """Return open multi-dimensional "meshgrid". - - LAX-backend implementation of :obj:`numpy.ogrid`. This is a convenience wrapper for - functionality provided by :func:`jax.numpy.meshgrid` with ``sparse=True``. - - See Also: - jnp.mgrid: dense version of jnp.ogrid - - Examples: - Pass ``[start:stop:step]`` to generate values similar to :func:`jax.numpy.arange`: - - >>> bm.ogrid[0:4:1] - DeviceArray([0, 1, 2, 3], dtype=int32) - - Passing an imaginary step generates values similar to :func:`jax.numpy.linspace`: - - >>> bm.ogrid[0:1:4j] - DeviceArray([0. , 0.33333334, 0.6666667 , 1. ], dtype=float32) - - Multiple slices can be used to create sparse grids of indices: - - >>> bm.ogrid[:2, :3] - [DeviceArray([[0], - [1]], dtype=int32), - DeviceArray([[0, 1, 2]], dtype=int32)] - """ - sparse = True - op_name = "ogrid" - - -ogrid = _Ogrid() - - -class _AxisConcat(abc.ABC): - """Concatenates slices, scalars and array-like objects along a given axis.""" - axis: int - ndmin: int - trans1d: int - op_name: str - - def __getitem__(self, key): - if not isinstance(key, tuple): - key = (key,) - - params = [self.axis, self.ndmin, self.trans1d, -1] - - if isinstance(key[0], str): - # split off the directive - directive, *key = key # pytype: disable=bad-unpacking - # check two special cases: matrix directives - if directive == "r": - params[-1] = 0 - elif directive == "c": - params[-1] = 1 - else: - vec = directive.split(",") - k = len(vec) - if k < 4: - vec += params[k:] - else: - # ignore everything after the first three comma-separated ints - vec = vec[:3] + params[-1] - try: - params = list(map(int, vec)) - except ValueError as err: - raise ValueError( - "could not understand directive {!r}".format(directive) - ) from err - - axis, ndmin, trans1d, matrix = params - - output = [] - for item in key: - if isinstance(item, slice): - newobj = _make_1d_grid_from_slice(item, op_name=self.op_name) - elif isinstance(item, str): - raise ValueError("string directive must be placed at the beginning") - else: - newobj = item - - newobj = array(newobj, copy=False, ndmin=ndmin) - - if trans1d != -1 and ndmin - np.ndim(item) > 0: - shape_obj = list(range(ndmin)) - # Calculate number of left shifts, with overflow protection by mod - num_lshifts = ndmin - abs(ndmin + trans1d + 1) % ndmin - shape_obj = tuple(shape_obj[num_lshifts:] + shape_obj[:num_lshifts]) - - newobj = transpose(newobj, shape_obj) - - output.append(newobj) - - res = concatenate(tuple(output), axis=axis) - - if matrix != -1 and res.ndim == 1: - # insert 2nd dim at axis 0 or 1 - res = expand_dims(res, matrix) - - return res - - def __len__(self): - return 0 - - -class RClass(_AxisConcat): - """Concatenate slices, scalars and array-like objects along the first axis. - - LAX-backend implementation of :obj:`numpy.r_`. - - See Also: - ``jnp.c_``: Concatenates slices, scalars and array-like objects along the last axis. - - Examples: - Passing slices in the form ``[start:stop:step]`` generates ``jnp.arange`` objects: - - >>> bm.r_[-1:5:1, 0, 0, bm.array([1,2,3])] - DeviceArray([-1, 0, 1, 2, 3, 4, 0, 0, 1, 2, 3], dtype=int32) - - An imaginary value for ``step`` will create a ``jnp.linspace`` object instead, - which includes the right endpoint: - - >>> bm.r_[-1:1:6j, 0, bm.array([1,2,3])] - DeviceArray([-1. , -0.6 , -0.20000002, 0.20000005, - 0.6 , 1. , 0. , 1. , - 2. , 3. ], dtype=float32) - - Use a string directive of the form ``"axis,dims,trans1d"`` as the first argument to - specify concatenation axis, minimum number of dimensions, and the position of the - upgraded array's original dimensions in the resulting array's shape tuple: - - >>> bm.r_['0,2', [1,2,3], [4,5,6]] # concatenate along first axis, 2D output - DeviceArray([[1, 2, 3], - [4, 5, 6]], dtype=int32) - - >>> bm.r_['0,2,0', [1,2,3], [4,5,6]] # push last input axis to the front - DeviceArray([[1], - [2], - [3], - [4], - [5], - [6]], dtype=int32) - - Negative values for ``trans1d`` offset the last axis towards the start - of the shape tuple: - - >>> bm.r_['0,2,-2', [1,2,3], [4,5,6]] - DeviceArray([[1], - [2], - [3], - [4], - [5], - [6]], dtype=int32) - - Use the special directives ``"r"`` or ``"c"`` as the first argument on flat inputs - to create an array with an extra row or column axis, respectively: - - >>> bm.r_['r',[1,2,3], [4,5,6]] - DeviceArray([[1, 2, 3, 4, 5, 6]], dtype=int32) - - >>> bm.r_['c',[1,2,3], [4,5,6]] - DeviceArray([[1], - [2], - [3], - [4], - [5], - [6]], dtype=int32) - - For higher-dimensional inputs (``dim >= 2``), both directives ``"r"`` and ``"c"`` - give the same result. - """ - axis = 0 - ndmin = 1 - trans1d = -1 - op_name = "r_" - - -r_ = RClass() - - -class CClass(_AxisConcat): - """Concatenate slices, scalars and array-like objects along the last axis. - - LAX-backend implementation of :obj:`numpy.c_`. - - See Also: - ``jnp.r_``: Concatenates slices, scalars and array-like objects along the first axis. - - Examples: - - >>> a = bm.arange(6).reshape((2,3)) - >>> bm.c_[a,a] - DeviceArray([[0, 1, 2, 0, 1, 2], - [3, 4, 5, 3, 4, 5]], dtype=int32) - - Use a string directive of the form ``"axis:dims:trans1d"`` as the first argument to specify - concatenation axis, minimum number of dimensions, and the position of the upgraded array's - original dimensions in the resulting array's shape tuple: - - >>> bm.c_['0,2', [1,2,3], [4,5,6]] - DeviceArray([[1], - [2], - [3], - [4], - [5], - [6]], dtype=int32) - - >>> bm.c_['0,2,-1', [1,2,3], [4,5,6]] - DeviceArray([[1, 2, 3], - [4, 5, 6]], dtype=int32) - - Use the special directives ``"r"`` or ``"c"`` as the first argument on flat inputs - to create an array with inputs stacked along the last axis: - - >>> jnp.c_['r',[1,2,3], [4,5,6]] - DeviceArray([[1, 4], - [2, 5], - [3, 6]], dtype=int32) - """ - axis = -1 - ndmin = 2 - trans1d = 0 - op_name = "c_" - - -c_ = CClass() - -s_ = np.s_ - -index_exp = np.index_exp diff --git a/brainpy/_src/math/jitconn/_event_matvec.py b/brainpy/_src/math/jitconn/_event_matvec.py index 3671755a9..33ee9f1b5 100644 --- a/brainpy/_src/math/jitconn/_event_matvec.py +++ b/brainpy/_src/math/jitconn/_event_matvec.py @@ -1,23 +1,15 @@ # -*- coding: utf-8 -*- -from functools import partial from typing import Tuple, Optional import jax import numpy as np -from jax import numpy as jnp, dtypes -from jax.core import ShapedArray, Primitive -from jax.interpreters import xla, ad -from jax.lib import xla_client +from jax import numpy as jnp -from brainpy._src.dependency_check import import_brainpylib_gpu_ops, import_brainpylib_cpu_ops, import_taichi +from brainpy._src.dependency_check import import_taichi from brainpy._src.math.interoperability import as_jax -from brainpy._src.math.jitconn._matvec import (mv_prob_homo_p, - mv_prob_uniform_p, - mv_prob_normal_p, - mv_prob_homo, +from brainpy._src.math.jitconn._matvec import (mv_prob_homo, mv_prob_uniform, - mv_prob_normal, _general_checking, raw_mv_prob_homo, raw_mv_prob_uniform, @@ -27,9 +19,8 @@ _mv_prob_normal_transpose, _reverse) from brainpy._src.math.ndarray import _get_dtype -from brainpy._src.math.op_register import register_general_batching, XLACustomOp -from brainpy._src.math.tifunc import (lfsr88_key, lfsr88_random_integers, lfsr88_uniform, lfsr88_normal) -from brainpy.errors import GPUOperatorNotFound +from brainpy._src.math.op_register import XLACustomOp +from brainpy.errors import PackageMissingError ti = import_taichi() @@ -50,7 +41,9 @@ def event_mv_prob_homo( transpose: bool = False, outdim_parallel: bool = True, ) -> jax.Array: - return event_mv_prob_homo_taichi(events, weight, conn_prob, seed, shape=shape, transpose=transpose, + return event_mv_prob_homo_taichi(events, weight, conn_prob, seed, + shape=shape, + transpose=transpose, outdim_parallel=outdim_parallel) @@ -68,7 +61,9 @@ def event_mv_prob_uniform( transpose: bool = False, outdim_parallel: bool = True, ) -> jax.Array: - return event_mv_prob_uniform_taichi(events, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, + return event_mv_prob_uniform_taichi(events, w_low, w_high, conn_prob, seed, + shape=shape, + transpose=transpose, outdim_parallel=outdim_parallel) @@ -86,651 +81,11 @@ def event_mv_prob_normal( transpose: bool = False, outdim_parallel: bool = True, ) -> jax.Array: - return event_mv_prob_uniform_taichi(events, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, - outdim_parallel=outdim_parallel) - - -### BRAINPYLIB ### - -def event_mv_prob_homo_brainpylib( - events: jax.Array, - weight: float, - conn_prob: float, - seed: Optional[int] = None, - *, - shape: Tuple[int, int], - transpose: bool = False, - outdim_parallel: bool = True, -) -> jax.Array: - events = as_jax(events) - weight = jnp.atleast_1d(as_jax(weight)) - conn_prob = jnp.atleast_1d(as_jax(conn_prob)) - clen = jnp.asarray(jnp.ceil(1 / conn_prob) * 2 - 1, dtype=jnp.int32) - with jax.ensure_compile_time_eval(): - if seed is None: - seed = int(np.random.randint(0, int(1e8))) - seed = jnp.atleast_1d(as_jax(seed, dtype=jnp.int32)) - r = event_mv_prob_homo_p.bind(events, - weight, - clen, - seed, - shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel)[0] - return r - - -event_mv_prob_homo_brainpylib.__doc__ = mv_prob_homo.__doc__ - - -def event_mv_prob_uniform_brainpylib( - events: jax.Array, - w_low: float, - w_high: float, - conn_prob: float, - seed: Optional[int] = None, - *, - shape: Tuple[int, int], - transpose: bool = False, - outdim_parallel: bool = True, -) -> jax.Array: - events = as_jax(events) - w_low = jnp.atleast_1d(as_jax(w_low)) - w_high = jnp.atleast_1d(as_jax(w_high)) - conn_prob = jnp.atleast_1d(as_jax(conn_prob)) - clen = jnp.asarray(jnp.ceil(1 / conn_prob) * 2 - 1, dtype=jnp.int32) - with jax.ensure_compile_time_eval(): - if seed is None: - seed = int(np.random.randint(0, int(1e8))) - seed = jnp.atleast_1d(as_jax(seed, dtype=jnp.int32)) - return event_mv_prob_uniform_p.bind(events, - w_low, - w_high, - clen, - seed, + return event_mv_prob_uniform_taichi(events, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, - outdim_parallel=outdim_parallel)[0] - - -event_mv_prob_uniform_brainpylib.__doc__ = mv_prob_uniform.__doc__ - - -def event_mv_prob_normal_brainpylib( - events: jax.Array, - w_mu: float, - w_sigma: float, - conn_prob: float, - seed: Optional[int] = None, - *, - shape: Tuple[int, int], - transpose: bool = False, - outdim_parallel: bool = True, -) -> jax.Array: - events = as_jax(events) - w_mu = jnp.atleast_1d(as_jax(w_mu)) - w_sigma = jnp.atleast_1d(as_jax(w_sigma)) - conn_prob = jnp.atleast_1d(as_jax(conn_prob)) - clen = jnp.asarray(jnp.ceil(1 / conn_prob) * 2 - 1, dtype=jnp.int32) - with jax.ensure_compile_time_eval(): - if seed is None: - seed = int(np.random.randint(0, int(1e8))) - seed = jnp.atleast_1d(as_jax(seed, dtype=jnp.int32)) - return event_mv_prob_normal_p.bind(events, - w_mu, - w_sigma, - clen, - seed, - shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel)[0] - - -event_mv_prob_normal_brainpylib.__doc__ = mv_prob_normal.__doc__ - - -def _event_matvec_prob_homo_abstract( - events, weight, clen, seed, *, shape, transpose, outdim_parallel -): - assert _get_dtype(events) in [jnp.bool_, jnp.float32, jnp.float64] - assert _get_dtype(weight) in [jnp.float32, jnp.float64], '"weight" must be float valued.' - assert _get_dtype(clen) in [jnp.int32, jnp.int64, jnp.uint32, jnp.uint64] - assert _get_dtype(seed) in [jnp.int32, jnp.int64, jnp.uint32, jnp.uint64] - - if events.ndim != 1: - raise ValueError('events should be a 1D vector.') - if len(shape) != 2: - raise ValueError('shape should be a length-2 tuple.') - if seed.ndim != 1: - raise ValueError('seed must be a 1D scalar.') - if clen.ndim != 1: - raise ValueError('conn_prob must be a 1D scalar.') - if weight.ndim != 1: - raise ValueError('weight must be a 1D scalar.') - - if not isinstance(outdim_parallel, bool): - raise ValueError('outdim_parallel must be boolean value.') - if not isinstance(transpose, bool): - raise ValueError('transpose must be boolean value.') - - if transpose: - if events.shape[0] != shape[0]: - raise ValueError(f'Shape mismatch, vec ({events.shape[0]},) @ mat {shape}.') - else: - if events.shape[0] != shape[1]: - raise ValueError(f'Shape mismatch, mat {shape} @ vec ({events.shape[0]},).') - out = ShapedArray(dtype=weight.dtype, shape=(shape[1] if transpose else shape[0],)) - return [out] - - -def _event_matvec_prob_homo_cpu_translation( - c, events, weight, clen, seed, *, shape, transpose, outdim_parallel -): - import_brainpylib_cpu_ops() - n_row, n_col = (shape[1], shape[0]) if transpose else shape - out_dtype, event_type, type_name = _get_types(c.get_shape(events)) - - if outdim_parallel: - fn = b'cpu_event_matvec_prob_homo' + type_name + event_type - else: - fn = b'cpu_event_matvec_atomic_prob_homo' + type_name + event_type - - return xla_client.ops.CustomCallWithLayout( - c, - fn, - operands=(events, - weight, - clen, - seed, - xla_client.ops.ConstantLiteral(c, n_row), - xla_client.ops.ConstantLiteral(c, n_col)), - operand_shapes_with_layout=(c.get_shape(events), - c.get_shape(weight), - c.get_shape(clen), - c.get_shape(seed), - xla_client.Shape.array_shape(np.dtype(np.uint32), (), ()), - xla_client.Shape.array_shape(np.dtype(np.uint32), (), ())), - shape_with_layout=xla_client.Shape.tuple_shape( - ( - xla_client.Shape.array_shape(out_dtype, (shape[1] if transpose else shape[0],), (0,)), - ) - ), - ) - - -def _event_matvec_prob_homo_gpu_translation( - c, events, weight, clen, seed, *, shape, transpose, outdim_parallel -): - gpu_ops = import_brainpylib_gpu_ops() - if gpu_ops is None: - raise GPUOperatorNotFound(event_mv_prob_homo_p.name) - - out_dtype, event_type, type_name = _get_types(c.get_shape(events)) - - opaque = gpu_ops.build_double_size_descriptor(shape[1] if transpose else shape[0], - shape[0] if transpose else shape[1], ) - - if outdim_parallel: - fn = b'gpu_jit_event_csrmv_prob_homo_v2' + type_name + event_type - else: - fn = b'gpu_jit_event_csrmv_atomic_prob_homo_v2' + type_name + event_type - - return xla_client.ops.CustomCallWithLayout( - c, - fn, - operands=(events, weight, clen, seed), - operand_shapes_with_layout=(c.get_shape(events), - c.get_shape(weight), - c.get_shape(clen), - c.get_shape(seed)), - shape_with_layout=xla_client.Shape.tuple_shape( - ( - xla_client.Shape.array_shape(out_dtype, (shape[1] if transpose else shape[0],), (0,)), - ) - ), - opaque=opaque, - ) - - -def _event_matvec_prob_homo_jvp( - primals, tangents, *, shape, transpose, outdim_parallel -): - events, weight, clen, seed = primals - event_dot, weight_dot, clen_dot, seed_dot = tangents - r = event_mv_prob_homo_p.bind(events, - weight, - clen, - seed, - shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel) - assert type(weight_dot) is ad.Zero - assert type(clen_dot) is ad.Zero - assert type(seed_dot) is ad.Zero - if type(weight_dot) is ad.Zero: - if type(event_dot) is ad.Zero: - raise ValueError - dr = mv_prob_homo_p.bind(event_dot, - weight, - clen, - seed, - shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel) - elif type(event_dot) is ad.Zero: - dr = mv_prob_homo_p.bind(events, - weight_dot, - clen, - seed, - shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel) - else: - dr = mv_prob_homo_p.bind(event_dot, - weight_dot, - clen, - seed, - shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel) - return r, dr - - -def _event_matvec_prob_homo_transpose( - ct, events, weight, clen, seed, *, shape, transpose, outdim_parallel -): - assert type(events) is ad.UndefinedPrimal - assert type(weight) is not ad.UndefinedPrimal - assert type(clen) is not ad.UndefinedPrimal - assert type(seed) is not ad.UndefinedPrimal - - r = mv_prob_homo_p.bind(ct[0], - weight, - clen, - seed, - shape=shape, - transpose=not transpose, - outdim_parallel=not outdim_parallel)[0] - return r, weight, clen, seed - - -event_mv_prob_homo_p = Primitive('event_mv_prob_homo') -event_mv_prob_homo_p.multiple_results = True -event_mv_prob_homo_p.def_abstract_eval(_event_matvec_prob_homo_abstract) -event_mv_prob_homo_p.def_impl(partial(xla.apply_primitive, event_mv_prob_homo_p)) -# xla.backend_specific_translations['cpu'][event_mv_prob_homo_p] = _event_matvec_prob_homo_cpu_translation -# xla.backend_specific_translations['gpu'][event_mv_prob_homo_p] = _event_matvec_prob_homo_gpu_translation -ad.primitive_jvps[event_mv_prob_homo_p] = _event_matvec_prob_homo_jvp -ad.primitive_transposes[event_mv_prob_homo_p] = _event_matvec_prob_homo_transpose -register_general_batching(event_mv_prob_homo_p) - - -def _event_matvec_prob_uniform_abstract( - events, w_low, w_high, clen, seed, *, shape, transpose, outdim_parallel -): - assert _get_dtype(events) in [jnp.bool_, jnp.float32, jnp.float64] - _w_low_dtype = _get_dtype(w_low) - _w_high_dtype = _get_dtype(w_low) - assert _w_low_dtype == _w_high_dtype, '"w_low" and "w_high" must be same typed.' - assert _w_low_dtype in [jnp.float32, jnp.float64], '"w_low" must be float valued.' - assert _w_high_dtype in [jnp.float32, jnp.float64], '"w_high" must be float valued.' - assert _get_dtype(clen) in [jnp.int32, jnp.int64, jnp.uint32, jnp.uint64] - assert _get_dtype(seed) in [jnp.int32, jnp.int64, jnp.uint32, jnp.uint64] - - if events.ndim != 1: - raise ValueError('events should be a 1D vector.') - if len(shape) != 2: - raise ValueError('shape should be a length-2 tuple.') - if w_low.ndim != 1: - raise ValueError('w_low must be a 1D scalar.') - if w_high.ndim != 1: - raise ValueError('w_high must be a 1D scalar.') - if clen.ndim != 1: - raise ValueError('clen must be a 1D scalar.') - if seed.ndim != 1: - raise ValueError('seed must be a 1D scalar.') - - if not isinstance(transpose, bool): - raise ValueError('transpose must be a boolean value.') - if not isinstance(outdim_parallel, bool): - raise ValueError('outdim_parallel must be a boolean value.') - assert w_low.dtype == w_high.dtype - - if transpose: - if events.shape[0] != shape[0]: - raise ValueError(f'Shape mismatch, vec ({events.shape[0]},) @ mat {shape}.') - else: - if events.shape[0] != shape[1]: - raise ValueError(f'Shape mismatch, mat {shape} @ vec ({events.shape[0]},).') - - out = ShapedArray(dtype=w_low.dtype, shape=(shape[1] if transpose else shape[0],)) - return [out] - - -def _event_matvec_prob_uniform_cpu_translation( - c, events, w_low, w_high, clen, seed, *, shape, transpose, outdim_parallel -): - import_brainpylib_cpu_ops() - n_row, n_col = (shape[1], shape[0]) if transpose else shape - - out_dtype, event_type, type_name = _get_types(c.get_shape(events)) - - if outdim_parallel: - fn = b'cpu_event_matvec_prob_uniform' + type_name + event_type - else: - fn = b'cpu_event_matvec_atomic_prob_uniform' + type_name + event_type - return xla_client.ops.CustomCallWithLayout( - c, - fn, - operands=(events, - w_low, - w_high, - clen, - seed, - xla_client.ops.ConstantLiteral(c, n_row), - xla_client.ops.ConstantLiteral(c, n_col)), - operand_shapes_with_layout=(c.get_shape(events), - c.get_shape(w_low), - c.get_shape(w_high), - c.get_shape(clen), - c.get_shape(seed), - xla_client.Shape.array_shape(np.dtype(np.uint32), (), ()), - xla_client.Shape.array_shape(np.dtype(np.uint32), (), ())), - shape_with_layout=xla_client.Shape.tuple_shape( - ( - xla_client.Shape.array_shape(out_dtype, (shape[1] if transpose else shape[0],), (0,)), - ) - ), - ) - - -def _event_matvec_prob_uniform_gpu_translation( - c, events, w_low, w_high, clen, seed, *, shape, transpose, outdim_parallel -): - gpu_ops = import_brainpylib_gpu_ops() - if gpu_ops is None: - raise GPUOperatorNotFound(event_mv_prob_uniform_p.name) - - out_dtype, event_type, type_name = _get_types(c.get_shape(events)) - - opaque = gpu_ops.build_double_size_descriptor(shape[1] if transpose else shape[0], - shape[0] if transpose else shape[1]) - if outdim_parallel: - fn = b'gpu_jit_event_csrmv_prob_uniform_v2' + type_name + event_type - else: - fn = b'gpu_jit_event_csrmv_atomic_prob_uniform_v2' + type_name + event_type - return xla_client.ops.CustomCallWithLayout( - c, - fn, - operands=(events, w_low, w_high, clen, seed), - operand_shapes_with_layout=(c.get_shape(events), - c.get_shape(w_low), - c.get_shape(w_high), - c.get_shape(clen), - c.get_shape(seed),), - shape_with_layout=xla_client.Shape.tuple_shape( - ( - xla_client.Shape.array_shape(out_dtype, (shape[1] if transpose else shape[0],), (0,)), - ) - ), - opaque=opaque, - ) - - -def _event_matvec_prob_uniform_jvp( - primals, tangents, *, shape, transpose, outdim_parallel -): - events, w_low, w_high, clen, seed = primals - events_dot, w_low_dot, w_high_dot, clen_dot, seed_dot = tangents - r = event_mv_prob_uniform_p.bind(events, - w_low, - w_high, - clen, - seed, - shape=shape, - outdim_parallel=outdim_parallel, - transpose=transpose) - assert type(w_low_dot) is ad.Zero - assert type(w_high_dot) is ad.Zero - assert type(clen_dot) is ad.Zero - assert type(seed_dot) is ad.Zero - r_dot = mv_prob_uniform_p.bind(events_dot, - w_low, - w_high, - clen, - seed, - shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel) - return r, r_dot - - -def _event_matvec_prob_uniform_transpose( - ct, events, w_low, w_high, clen, seed, *, shape, transpose, outdim_parallel -): - assert type(events) is ad.UndefinedPrimal - assert type(w_low) is not ad.UndefinedPrimal - assert type(w_high) is not ad.UndefinedPrimal - assert type(clen) is not ad.UndefinedPrimal - assert type(seed) is not ad.UndefinedPrimal - - r = mv_prob_uniform_p.bind(ct[0], - w_low, - w_high, - clen, - seed, - shape=shape, - transpose=not transpose, - outdim_parallel=not outdim_parallel)[0] - return r, w_low, w_high, clen, seed - - -event_mv_prob_uniform_p = Primitive('event_mv_prob_uniform') -event_mv_prob_uniform_p.multiple_results = True -event_mv_prob_uniform_p.def_abstract_eval(_event_matvec_prob_uniform_abstract) -event_mv_prob_uniform_p.def_impl(partial(xla.apply_primitive, event_mv_prob_uniform_p)) -# xla.backend_specific_translations['cpu'][event_mv_prob_uniform_p] = _event_matvec_prob_uniform_cpu_translation -# xla.backend_specific_translations['gpu'][event_mv_prob_uniform_p] = _event_matvec_prob_uniform_gpu_translation -register_general_batching(event_mv_prob_uniform_p) -ad.primitive_jvps[event_mv_prob_uniform_p] = _event_matvec_prob_uniform_jvp -ad.primitive_transposes[event_mv_prob_uniform_p] = _event_matvec_prob_uniform_transpose - - -def _event_matvec_prob_normal_abstract( - events, w_mu, w_sigma, clen, seed, *, shape, transpose, outdim_parallel -): - assert _get_dtype(events) in [jnp.bool_, jnp.float32, jnp.float64] - _w_mu_dtype = _get_dtype(w_mu) - _w_sigma_dtype = _get_dtype(w_sigma) - assert _w_mu_dtype == _w_sigma_dtype, '"w_mu" and "w_sigma" must be same typed.' - assert _w_mu_dtype in [jnp.float32, jnp.float64], '"w_mu" must be float valued.' - assert _get_dtype(clen) in [jnp.int32, jnp.int64, jnp.uint32, jnp.uint64] - assert _get_dtype(seed) in [jnp.int32, jnp.int64, jnp.uint32, jnp.uint64] - - if w_mu.ndim != 1: - raise ValueError('w_mu should be a 1D scalar.') - if w_sigma.ndim != 1: - raise ValueError('w_sigma should be a 1D scalar.') - if clen.ndim != 1: - raise ValueError('clen should be a 1D scalar.') - if events.ndim != 1: - raise ValueError('events should be a 1D vector.') - if seed.ndim != 1: - raise ValueError('seed must be a 1D scalar.') - assert w_mu.dtype == w_sigma.dtype - - if len(shape) != 2: - raise ValueError('shape should be a length-2 tuple.') - if not isinstance(transpose, bool): - raise ValueError('transpose must be a boolean value.') - if not isinstance(outdim_parallel, bool): - raise ValueError('outdim_parallel must be a boolean value.') - - if transpose: - if events.shape[0] != shape[0]: - raise ValueError(f'Shape mismatch, vec ({events.shape[0]},) @ mat {shape}.') - else: - if events.shape[0] != shape[1]: - raise ValueError(f'Shape mismatch, mat {shape} @ vec ({events.shape[0]},).') - - out = ShapedArray(dtype=w_mu.dtype, shape=(shape[1] if transpose else shape[0],)) - return [out] - - -def _get_types(event_shape): - event_type = event_shape.element_type() - if event_type == jnp.bool_: - event_type = b'_bool' - out_dtype = dtypes.canonicalize_dtype(float) - elif event_type == jnp.float32: - event_type = b'_float' - out_dtype = event_shape.element_type() - elif event_type == jnp.float64: - event_type = b'_double' - out_dtype = event_shape.element_type() - else: - raise TypeError - - if out_dtype == jnp.float32: - type_name = b'_float' - elif out_dtype == jnp.float64: - type_name = b'_double' - else: - raise TypeError - - return out_dtype, event_type, type_name - - -def _event_matvec_prob_normal_cpu_translation( - c, events, w_mu, w_sigma, clen, seed, *, shape, transpose, outdim_parallel -): - import_brainpylib_cpu_ops() - n_row, n_col = (shape[1], shape[0]) if transpose else shape - - out_dtype, event_type, type_name = _get_types(c.get_shape(events)) - - if outdim_parallel: - fn = b'cpu_event_matvec_prob_normal' + type_name + event_type - else: - fn = b'cpu_event_matvec_atomic_prob_normal' + type_name + event_type - return xla_client.ops.CustomCallWithLayout( - c, - fn, - operands=(events, - w_mu, - w_sigma, - clen, - seed, - xla_client.ops.ConstantLiteral(c, n_row), - xla_client.ops.ConstantLiteral(c, n_col)), - operand_shapes_with_layout=(c.get_shape(events), - c.get_shape(w_mu), - c.get_shape(w_sigma), - c.get_shape(clen), - c.get_shape(seed), - xla_client.Shape.array_shape(np.dtype(np.uint32), (), ()), - xla_client.Shape.array_shape(np.dtype(np.uint32), (), ())), - shape_with_layout=xla_client.Shape.tuple_shape( - ( - xla_client.Shape.array_shape(out_dtype, (shape[1] if transpose else shape[0],), (0,)), - ) - ), - ) - - -def _event_matvec_prob_normal_gpu_translation( - c, events, w_mu, w_sigma, clen, seed, *, shape, transpose, outdim_parallel -): - gpu_ops = import_brainpylib_gpu_ops() - if gpu_ops is None: - raise GPUOperatorNotFound(event_mv_prob_normal_p.name) - - out_dtype, event_type, type_name = _get_types(c.get_shape(events)) - - opaque = gpu_ops.build_double_size_descriptor(shape[1] if transpose else shape[0], - shape[0] if transpose else shape[1]) - if outdim_parallel: - fn = b'gpu_jit_event_csrmv_prob_normal_v2' + type_name + event_type - else: - fn = b'gpu_jit_event_csrmv_atomic_prob_normal_v2' + type_name + event_type - return xla_client.ops.CustomCallWithLayout( - c, - fn, - operands=(events, w_mu, w_sigma, clen, seed), - operand_shapes_with_layout=(c.get_shape(events), - c.get_shape(w_mu), - c.get_shape(w_sigma), - c.get_shape(clen), - c.get_shape(seed)), - shape_with_layout=xla_client.Shape.tuple_shape( - ( - xla_client.Shape.array_shape(out_dtype, (shape[1] if transpose else shape[0],), (0,)), - ) - ), - opaque=opaque, - ) - + outdim_parallel=outdim_parallel) -def _event_matvec_prob_normal_jvp( - primals, tangents, *, shape, transpose, outdim_parallel -): - events, w_mu, w_sigma, clen, seed = primals - events_dot, w_mu_dot, w_sigma_dot, clen_dot, seed_dot = tangents - r = event_mv_prob_normal_p.bind(events, - w_mu, - w_sigma, - clen, - seed, - shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel) - assert type(w_mu_dot) is ad.Zero - assert type(w_sigma_dot) is ad.Zero - assert type(clen_dot) is ad.Zero - assert type(seed_dot) is ad.Zero - r_dot = mv_prob_normal_p.bind(events_dot, - w_mu, - w_sigma, - clen, - seed, - shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel) - return r, r_dot - - -def _event_matvec_prob_normal_transpose( - ct, events, w_mu, w_sigma, clen, seed, *, shape, transpose, outdim_parallel -): - assert type(events) is ad.UndefinedPrimal - assert type(w_mu) is not ad.UndefinedPrimal - assert type(w_sigma) is not ad.UndefinedPrimal - assert type(clen) is not ad.UndefinedPrimal - assert type(seed) is not ad.UndefinedPrimal - - r = mv_prob_normal_p.bind(ct[0], - w_mu, - w_sigma, - clen, - seed, - shape=shape, - transpose=not transpose, - outdim_parallel=not outdim_parallel)[0] - return r, w_mu, w_sigma, clen, seed - - -event_mv_prob_normal_p = Primitive('event_mv_prob_normal') -event_mv_prob_normal_p.multiple_results = True -event_mv_prob_normal_p.def_abstract_eval(_event_matvec_prob_normal_abstract) -event_mv_prob_normal_p.def_impl(partial(xla.apply_primitive, event_mv_prob_normal_p)) -# xla.backend_specific_translations['cpu'][event_mv_prob_normal_p] = _event_matvec_prob_normal_cpu_translation -# xla.backend_specific_translations['gpu'][event_mv_prob_normal_p] = _event_matvec_prob_normal_gpu_translation -register_general_batching(event_mv_prob_normal_p) -ad.primitive_jvps[event_mv_prob_normal_p] = _event_matvec_prob_normal_jvp -ad.primitive_transposes[event_mv_prob_normal_p] = _event_matvec_prob_normal_transpose - - -### TAICHI ### def event_mv_prob_homo_taichi( events: jax.Array, @@ -790,6 +145,9 @@ def event_mv_prob_homo_taichi( out: Array, ndarray The output of :math:`y = M @ v`. """ + if ti is None: + raise PackageMissingError(name='taichi==1.7.0', purpose='customized operators') + events = as_jax(events) if isinstance(weight, float): weight = as_jax(weight) weight = jnp.atleast_1d(as_jax(weight)) @@ -799,8 +157,10 @@ def event_mv_prob_homo_taichi( with jax.ensure_compile_time_eval(): seed = np.random.randint(0, int(1e8), 1) seed = jnp.atleast_1d(jnp.asarray(seed, dtype=jnp.uint32)) - return raw_event_mv_prob_homo(events, weight, conn_len, seed, shape=shape, - transpose=transpose, outdim_parallel=outdim_parallel)[0] + return raw_event_mv_prob_homo(events, weight, conn_len, seed, + shape=shape, + transpose=transpose, + outdim_parallel=outdim_parallel)[0] def event_mv_prob_uniform_taichi( @@ -864,6 +224,9 @@ def event_mv_prob_uniform_taichi( out: Array, ndarray The output of :math:`y = M @ v`. """ + if ti is None: + raise PackageMissingError(name='taichi==1.7.0', purpose='customized operators') + events = as_jax(events) if isinstance(w_low, float): w_low = as_jax(w_low) if isinstance(w_high, float): w_high = as_jax(w_high) @@ -940,6 +303,9 @@ def event_mv_prob_normal_taichi( out: Array, ndarray The output of :math:`y = M @ v`. """ + if ti is None: + raise PackageMissingError(name='taichi==1.7.0', purpose='customized operators') + events = as_jax(events) if isinstance(w_mu, float): w_mu = as_jax(w_mu) if isinstance(w_sigma, float): w_sigma = as_jax(w_sigma) @@ -955,1034 +321,1033 @@ def event_mv_prob_normal_taichi( transpose=transpose, outdim_parallel=outdim_parallel)[0] -# ------------- -# CPU function -# ------------- -# For each non-zero event value, it generates a random key using a -# function lfsr88_key and then uses this key to compute random integers -# and update the out array based on the computed indices and weight. -# -# The function is likely designed to be parallelized. - - -@ti.kernel -def _event_mv_prob_homo_bool_cpu( - events: ti.types.ndarray(ndim=1), - weight: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = events.shape[0] - weight0 = weight[0] - clen0 = clen[0] - seed0 = seed[0] - - for i_col in range(num_col): - if events[i_col]: - key = lfsr88_key(seed0 + i_col) - key, i_row = lfsr88_random_integers(key, 0, clen0 - 1) - while i_row < num_row: - out[i_row] += weight0 +if ti is not None: + from brainpy._src.math.tifunc import (lfsr88_key, lfsr88_random_integers, lfsr88_uniform, lfsr88_normal) + + # ------------- + # CPU function + # ------------- + # For each non-zero event value, it generates a random key using a + # function lfsr88_key and then uses this key to compute random integers + # and update the out array based on the computed indices and weight. + # + # The function is likely designed to be parallelized. + + @ti.kernel + def _event_mv_prob_homo_bool_cpu( + events: ti.types.ndarray(ndim=1), + weight: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) + ): + num_row = out.shape[0] + num_col = events.shape[0] + weight0 = weight[0] + clen0 = clen[0] + seed0 = seed[0] + + for i_col in range(num_col): + if events[i_col]: + key = lfsr88_key(seed0 + i_col) + key, i_row = lfsr88_random_integers(key, 0, clen0 - 1) + while i_row < num_row: + out[i_row] += weight0 + key, inc = lfsr88_random_integers(key, 1, clen0) + i_row += inc + + + @ti.kernel + def _event_mv_prob_homo_outdim_parallel_bool_cpu( + events: ti.types.ndarray(ndim=1), + weight: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) + ): + num_row = out.shape[0] + num_col = events.shape[0] + weight0 = weight[0] + clen0 = clen[0] + seed0 = seed[0] + + for i_row in range(num_row): + r = 0. + key = lfsr88_key(seed0 + i_row) + key, i_col = lfsr88_random_integers(key, 0, clen0 - 1) + while i_col < num_col: + if events[i_col]: + r += weight0 key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc - - -@ti.kernel -def _event_mv_prob_homo_outdim_parallel_bool_cpu( - events: ti.types.ndarray(ndim=1), - weight: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = events.shape[0] - weight0 = weight[0] - clen0 = clen[0] - seed0 = seed[0] - - for i_row in range(num_row): - r = 0. - key = lfsr88_key(seed0 + i_row) - key, i_col = lfsr88_random_integers(key, 0, clen0 - 1) - while i_col < num_col: + i_col += inc + out[i_row] = r + + + # ------------- + # GPU function + # ------------- + # Contrary to the CPU functions, for each column, + # this function will 32 threads (one warp) to make + # the just-in-time random generation parallelized. + + @ti.kernel + def _event_mv_prob_homo_bool_gpu( + events: ti.types.ndarray(ndim=1), + weight: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) + ): + num_row = out.shape[0] + num_col = events.shape[0] + weight0 = weight[0] + clen0 = clen[0] + seed0 = seed[0] + step = ti.uint32(ti.max((num_row + 1) >> 5, 1)) + + for i in range(num_col * 32): + i_col = i >> 5 if events[i_col]: - r += weight0 - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - out[i_row] = r - - -# ------------- -# GPU function -# ------------- -# Contrary to the CPU functions, for each column, -# this function will 32 threads (one warp) to make -# the just-in-time random generation parallelized. - - -@ti.kernel -def _event_mv_prob_homo_bool_gpu( - events: ti.types.ndarray(ndim=1), - weight: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = events.shape[0] - weight0 = weight[0] - clen0 = clen[0] - seed0 = seed[0] - step = ti.uint32(ti.max((num_row + 1) >> 5, 1)) - - for i in range(num_col * 32): - i_col = i >> 5 - if events[i_col]: - index = i & 31 - i_row = step * index - 1 - end = ti.min(i_row + step, num_row) - key = lfsr88_key(seed0 + i) - key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc - while i_row < end: - out[i_row] += weight0 + index = i & 31 + i_row = step * index - 1 + end = ti.min(i_row + step, num_row) + key = lfsr88_key(seed0 + i) key, inc = lfsr88_random_integers(key, 1, clen0) i_row += inc - - -@ti.kernel -def _event_mv_prob_homo_outdim_parallel_bool_gpu( - events: ti.types.ndarray(ndim=1), - weight: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = events.shape[0] - weight0 = weight[0] - clen0 = clen[0] - seed0 = seed[0] - step = ti.u32(ti.max((num_row + 1) >> 5, 1)) - - for i in range(num_row * 32): - i_row = i >> 5 - index = i & 31 - i_col = step * index - 1 - end_col = ti.min(i_col + step, num_col) - r = 0. - key = lfsr88_key(seed0 + i) - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - while i_col < end_col: - r += weight0 * events[i_col] # TODO: speed comparison without if else + while i_row < end: + out[i_row] += weight0 + key, inc = lfsr88_random_integers(key, 1, clen0) + i_row += inc + + + @ti.kernel + def _event_mv_prob_homo_outdim_parallel_bool_gpu( + events: ti.types.ndarray(ndim=1), + weight: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) + ): + num_row = out.shape[0] + num_col = events.shape[0] + weight0 = weight[0] + clen0 = clen[0] + seed0 = seed[0] + step = ti.u32(ti.max((num_row + 1) >> 5, 1)) + + for i in range(num_row * 32): + i_row = i >> 5 + index = i & 31 + i_col = step * index - 1 + end_col = ti.min(i_col + step, num_col) + r = 0. + key = lfsr88_key(seed0 + i) key, inc = lfsr88_random_integers(key, 1, clen0) i_col += inc - out[i_row] += r # TODO: warp-level reduction - - -def _reverse(shape): - return shape[::-1] - - -# ------------- -# CPU function -# ------------- -# For each non-zero event value, it generates a random key using a -# function lfsr88_key and then uses this key to compute random integers -# and update the out array based on the computed indices and weight. -# -# The function is likely designed to be parallelized. - - -@ti.kernel -def _event_mv_prob_homo_cpu( - events: ti.types.ndarray(ndim=1), - weight: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = events.shape[0] - weight0 = weight[0] - clen0 = clen[0] - seed0 = seed[0] - - for i_col in range(num_col): - if events[i_col] != 0.: - key = lfsr88_key(seed0 + i_col) - key, i_row = lfsr88_random_integers(key, 0, clen0 - 1) - while i_row < num_row: - out[i_row] += weight0 + while i_col < end_col: + r += weight0 * events[i_col] # TODO: speed comparison without if else key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc - - -@ti.kernel -def _event_mv_prob_homo_outdim_parallel_cpu( - events: ti.types.ndarray(ndim=1), - weight: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = events.shape[0] - weight0 = weight[0] - clen0 = clen[0] - seed0 = seed[0] - - for i_row in range(num_row): - r = 0. - key = lfsr88_key(seed0 + i_row) - key, i_col = lfsr88_random_integers(key, 0, clen0 - 1) - while i_col < num_col: + i_col += inc + out[i_row] += r # TODO: warp-level reduction + + + def _reverse(shape): + return shape[::-1] + + + # ------------- + # CPU function + # ------------- + # For each non-zero event value, it generates a random key using a + # function lfsr88_key and then uses this key to compute random integers + # and update the out array based on the computed indices and weight. + # + # The function is likely designed to be parallelized. + + @ti.kernel + def _event_mv_prob_homo_cpu( + events: ti.types.ndarray(ndim=1), + weight: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) + ): + num_row = out.shape[0] + num_col = events.shape[0] + weight0 = weight[0] + clen0 = clen[0] + seed0 = seed[0] + + for i_col in range(num_col): if events[i_col] != 0.: - r += weight0 - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - out[i_row] = r # TODO: warp-level reduction - - -# ------------- -# GPU function -# ------------- -# Contrary to the CPU functions, for each column, -# this function will 32 threads (one warp) to make -# the just-in-time random generation parallelized. - - -@ti.kernel -def _event_mv_prob_homo_gpu( - events: ti.types.ndarray(ndim=1), - weight: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = events.shape[0] - weight0 = weight[0] - clen0 = clen[0] - seed0 = seed[0] - step = ti.uint32(ti.max((num_row + 1) >> 5, 1)) - - for i in range(num_col * 32): - i_col = i >> 5 - if events[i_col] != 0.: + key = lfsr88_key(seed0 + i_col) + key, i_row = lfsr88_random_integers(key, 0, clen0 - 1) + while i_row < num_row: + out[i_row] += weight0 + key, inc = lfsr88_random_integers(key, 1, clen0) + i_row += inc + + + @ti.kernel + def _event_mv_prob_homo_outdim_parallel_cpu( + events: ti.types.ndarray(ndim=1), + weight: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) + ): + num_row = out.shape[0] + num_col = events.shape[0] + weight0 = weight[0] + clen0 = clen[0] + seed0 = seed[0] + + for i_row in range(num_row): + r = 0. + key = lfsr88_key(seed0 + i_row) + key, i_col = lfsr88_random_integers(key, 0, clen0 - 1) + while i_col < num_col: + if events[i_col] != 0.: + r += weight0 + key, inc = lfsr88_random_integers(key, 1, clen0) + i_col += inc + out[i_row] = r # TODO: warp-level reduction + + + # ------------- + # GPU function + # ------------- + # Contrary to the CPU functions, for each column, + # this function will 32 threads (one warp) to make + # the just-in-time random generation parallelized. + + @ti.kernel + def _event_mv_prob_homo_gpu( + events: ti.types.ndarray(ndim=1), + weight: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) + ): + num_row = out.shape[0] + num_col = events.shape[0] + weight0 = weight[0] + clen0 = clen[0] + seed0 = seed[0] + step = ti.uint32(ti.max((num_row + 1) >> 5, 1)) + + for i in range(num_col * 32): + i_col = i >> 5 + if events[i_col] != 0.: + index = i & 31 + i_row = step * index - 1 + end = ti.min(i_row + step, num_row) + key = lfsr88_key(seed0 + i) + key, inc = lfsr88_random_integers(key, 1, clen0) + i_row += inc + while i_row < end: + out[i_row] += weight0 + key, inc = lfsr88_random_integers(key, 1, clen0) + i_row += inc + + + @ti.kernel + def _event_mv_prob_homo_outdim_parallel_gpu( + events: ti.types.ndarray(ndim=1), + weight: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) + ): + num_row = out.shape[0] + num_col = events.shape[0] + weight0 = weight[0] + clen0 = clen[0] + seed0 = seed[0] + step = ti.uint32(ti.max((num_row + 1) >> 5, 1)) + + for i in range(num_row * 32): + i_row = i >> 5 index = i & 31 - i_row = step * index - 1 - end = ti.min(i_row + step, num_row) + i_col = step * index - 1 + end_col = ti.min(i_col + step, num_col) + r = 0. key = lfsr88_key(seed0 + i) key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc - while i_row < end: - out[i_row] += weight0 + i_col += inc + while i_col < end_col: + r += weight0 * events[i_col] # TODO: speed comparison with if else key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc + i_col += inc + out[i_row] += r # TODO: warp-level reduction -@ti.kernel -def _event_mv_prob_homo_outdim_parallel_gpu( - events: ti.types.ndarray(ndim=1), - weight: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = events.shape[0] - weight0 = weight[0] - clen0 = clen[0] - seed0 = seed[0] - step = ti.uint32(ti.max((num_row + 1) >> 5, 1)) - - for i in range(num_row * 32): - i_row = i >> 5 - index = i & 31 - i_col = step * index - 1 - end_col = ti.min(i_col + step, num_col) - r = 0. - key = lfsr88_key(seed0 + i) - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - while i_col < end_col: - r += weight0 * events[i_col] # TODO: speed comparison with if else - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - out[i_row] += r # TODO: warp-level reduction + def _event_mv_prob_homo_jvp_events( + evt_dot, events, weight, clen, seed, *, outs, shape, transpose, outdim_parallel + ): + shape = _reverse(shape) if transpose else shape + return raw_mv_prob_homo(evt_dot, weight, clen, seed, + shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) -def _event_mv_prob_homo_jvp_events( - evt_dot, events, weight, clen, seed, *, outs, shape, transpose, outdim_parallel -): - shape = _reverse(shape) if transpose else shape - return raw_mv_prob_homo(evt_dot, weight, clen, seed, - shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) + def _event_mv_prob_homo_jvp_weight( + w_dot, events, weight, clen, seed, *, outs, shape, transpose, outdim_parallel + ): + shape = _reverse(shape) if transpose else shape + return raw_mv_prob_homo(events, w_dot, clen, seed, + shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) -def _event_mv_prob_homo_jvp_weight( - w_dot, events, weight, clen, seed, *, outs, shape, transpose, outdim_parallel -): - shape = _reverse(shape) if transpose else shape - return raw_mv_prob_homo(events, w_dot, clen, seed, - shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) + def _event_checking(vector, clen, seed, shape, outdim_parallel, transpose, *weights): + assert _get_dtype(vector) in [jnp.bool_, jnp.float16, jnp.float32, jnp.float64] + return _general_checking(vector, clen, seed, shape, outdim_parallel, transpose, *weights) + + + def raw_event_mv_prob_homo( + events: jax.Array, + weight: jax.Array, # vector with size 1 + conn_len: jax.Array, # vector with size 1 + seed: jax.Array, # vector with size 1 + *, + shape: Tuple[int, int], + transpose: bool = False, + outdim_parallel: bool = True, + ) -> jax.Array: + mat_shape, out_shape = _event_checking(events, conn_len, seed, shape, outdim_parallel, transpose, weight) + + if outdim_parallel: + if events.dtype == jnp.bool_: + prim = _event_mv_prob_homo_outdim_parallel_bool_p + else: + prim = _event_mv_prob_homo_outdim_parallel_p + else: + if events.dtype == jnp.bool_: + prim = _event_mv_prob_homo_bool_p + else: + prim = _event_mv_prob_homo_p + + return prim(events, + weight, + conn_len, + seed, + outs=[jax.ShapeDtypeStruct(shape=out_shape, dtype=weight.dtype)], + shape=mat_shape, + transpose=transpose, + outdim_parallel=outdim_parallel) + + + def _define_event_mv_prob_homo_prim(cpu_kernel, gpu_kernel): + prim = XLACustomOp(cpu_kernel=cpu_kernel, gpu_kernel=gpu_kernel) + prim.defjvp(_event_mv_prob_homo_jvp_events, + _event_mv_prob_homo_jvp_weight, + None, + None) + prim.def_transpose_rule(_mv_prob_homo_transpose) + return prim + + + # outdim_parallel = True, events.dtype = jnp.bool_ + _event_mv_prob_homo_outdim_parallel_bool_p = _define_event_mv_prob_homo_prim( + cpu_kernel=_event_mv_prob_homo_outdim_parallel_bool_cpu, + gpu_kernel=_event_mv_prob_homo_outdim_parallel_bool_gpu + ) + # outdim_parallel = False, events.dtype = jnp.bool_ + _event_mv_prob_homo_bool_p = _define_event_mv_prob_homo_prim( + cpu_kernel=_event_mv_prob_homo_bool_cpu, + gpu_kernel=_event_mv_prob_homo_bool_gpu + ) -def _event_checking(vector, clen, seed, shape, outdim_parallel, transpose, *weights): - assert _get_dtype(vector) in [jnp.bool_, jnp.float16, jnp.float32, jnp.float64] - return _general_checking(vector, clen, seed, shape, outdim_parallel, transpose, *weights) + # outdim_parallel = True, events.dtype != jnp.bool_ + _event_mv_prob_homo_outdim_parallel_p = _define_event_mv_prob_homo_prim( + cpu_kernel=_event_mv_prob_homo_outdim_parallel_cpu, + gpu_kernel=_event_mv_prob_homo_outdim_parallel_gpu + ) + # outdim_parallel = False, events.dtype != jnp.bool_ + _event_mv_prob_homo_p = _define_event_mv_prob_homo_prim( + cpu_kernel=_event_mv_prob_homo_cpu, + gpu_kernel=_event_mv_prob_homo_gpu + ) -def raw_event_mv_prob_homo( - events: jax.Array, - weight: jax.Array, # vector with size 1 - conn_len: jax.Array, # vector with size 1 - seed: jax.Array, # vector with size 1 - *, - shape: Tuple[int, int], - transpose: bool = False, - outdim_parallel: bool = True, -) -> jax.Array: - mat_shape, out_shape = _event_checking(events, conn_len, seed, shape, outdim_parallel, transpose, weight) - if outdim_parallel: - if events.dtype == jnp.bool_: - prim = _event_mv_prob_homo_outdim_parallel_bool_p - else: - prim = _event_mv_prob_homo_outdim_parallel_p - else: - if events.dtype == jnp.bool_: - prim = _event_mv_prob_homo_bool_p - else: - prim = _event_mv_prob_homo_p - - return prim(events, - weight, - conn_len, - seed, - outs=[jax.ShapeDtypeStruct(shape=out_shape, dtype=weight.dtype)], - shape=mat_shape, - transpose=transpose, - outdim_parallel=outdim_parallel) - - -def _define_event_mv_prob_homo_prim(cpu_kernel, gpu_kernel): - prim = XLACustomOp(cpu_kernel=cpu_kernel, gpu_kernel=gpu_kernel) - prim.defjvp(_event_mv_prob_homo_jvp_events, - _event_mv_prob_homo_jvp_weight, - None, - None) - prim.def_transpose_rule(_mv_prob_homo_transpose) - return prim - - -# outdim_parallel = True, events.dtype = jnp.bool_ -_event_mv_prob_homo_outdim_parallel_bool_p = _define_event_mv_prob_homo_prim( - cpu_kernel=_event_mv_prob_homo_outdim_parallel_bool_cpu, - gpu_kernel=_event_mv_prob_homo_outdim_parallel_bool_gpu -) - -# outdim_parallel = False, events.dtype = jnp.bool_ -_event_mv_prob_homo_bool_p = _define_event_mv_prob_homo_prim( - cpu_kernel=_event_mv_prob_homo_bool_cpu, - gpu_kernel=_event_mv_prob_homo_bool_gpu -) - -# outdim_parallel = True, events.dtype != jnp.bool_ -_event_mv_prob_homo_outdim_parallel_p = _define_event_mv_prob_homo_prim( - cpu_kernel=_event_mv_prob_homo_outdim_parallel_cpu, - gpu_kernel=_event_mv_prob_homo_outdim_parallel_gpu -) - -# outdim_parallel = False, events.dtype != jnp.bool_ -_event_mv_prob_homo_p = _define_event_mv_prob_homo_prim( - cpu_kernel=_event_mv_prob_homo_cpu, - gpu_kernel=_event_mv_prob_homo_gpu -) - - -@ti.kernel -def _event_mv_prob_uniform_bool_cpu( - events: ti.types.ndarray(ndim=1), - w_min: ti.types.ndarray(ndim=1), - w_max: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = events.shape[0] - w_min0 = w_min[0] - w_max0 = w_max[0] - clen0 = clen[0] - seed0 = seed[0] - - for i_col in range(num_col): - if events[i_col]: - key = lfsr88_key(seed0 + i_col) - key, i_row = lfsr88_random_integers(key, 0, clen0 - 1) - while i_row < num_row: + @ti.kernel + def _event_mv_prob_uniform_bool_cpu( + events: ti.types.ndarray(ndim=1), + w_min: ti.types.ndarray(ndim=1), + w_max: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) + ): + num_row = out.shape[0] + num_col = events.shape[0] + w_min0 = w_min[0] + w_max0 = w_max[0] + clen0 = clen[0] + seed0 = seed[0] + + for i_col in range(num_col): + if events[i_col]: + key = lfsr88_key(seed0 + i_col) + key, i_row = lfsr88_random_integers(key, 0, clen0 - 1) + while i_row < num_row: + key, row_v = lfsr88_uniform(key, w_min0, w_max0) + out[i_row] += row_v + key, inc = lfsr88_random_integers(key, 1, clen0) + i_row += inc + + + @ti.kernel + def _event_mv_prob_uniform_outdim_parallel_bool_cpu( + events: ti.types.ndarray(ndim=1), + w_min: ti.types.ndarray(ndim=1), + w_max: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) + ): + num_row = out.shape[0] + num_col = events.shape[0] + w_min0 = w_min[0] + w_max0 = w_max[0] + clen0 = clen[0] + seed0 = seed[0] + + for i_row in range(num_row): + r = 0. + key = lfsr88_key(seed0 + i_row) + key, i_col = lfsr88_random_integers(key, 0, clen0 - 1) + while i_col < num_col: key, row_v = lfsr88_uniform(key, w_min0, w_max0) - out[i_row] += row_v + if events[i_col]: + r += row_v key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc - - -@ti.kernel -def _event_mv_prob_uniform_outdim_parallel_bool_cpu( - events: ti.types.ndarray(ndim=1), - w_min: ti.types.ndarray(ndim=1), - w_max: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = events.shape[0] - w_min0 = w_min[0] - w_max0 = w_max[0] - clen0 = clen[0] - seed0 = seed[0] - - for i_row in range(num_row): - r = 0. - key = lfsr88_key(seed0 + i_row) - key, i_col = lfsr88_random_integers(key, 0, clen0 - 1) - while i_col < num_col: - key, row_v = lfsr88_uniform(key, w_min0, w_max0) + i_col += inc + out[i_row] = r + + + @ti.kernel + def _event_mv_prob_uniform_bool_gpu( + events: ti.types.ndarray(ndim=1), + w_min: ti.types.ndarray(ndim=1), + w_max: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) + ): + num_row = out.shape[0] + num_col = events.shape[0] + w_min0 = w_min[0] + w_max0 = w_max[0] + clen0 = clen[0] + seed0 = seed[0] + step = ti.uint32(ti.max((num_row + 1) >> 5, 1)) + + for i in range(num_col * 32): + i_col = i >> 5 if events[i_col]: - r += row_v - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - out[i_row] = r - - -@ti.kernel -def _event_mv_prob_uniform_bool_gpu( - events: ti.types.ndarray(ndim=1), - w_min: ti.types.ndarray(ndim=1), - w_max: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = events.shape[0] - w_min0 = w_min[0] - w_max0 = w_max[0] - clen0 = clen[0] - seed0 = seed[0] - step = ti.uint32(ti.max((num_row + 1) >> 5, 1)) - - for i in range(num_col * 32): - i_col = i >> 5 - if events[i_col]: + index = i & 31 + i_row = step * index - 1 + end = ti.min(i_row + step, num_row) + key = lfsr88_key(seed0 + i) + key, inc = lfsr88_random_integers(key, 1, clen0) + i_row += inc + while i_row < end: + key, row_v = lfsr88_uniform(key, w_min0, w_max0) + out[i_row] += row_v + key, inc = lfsr88_random_integers(key, 1, clen0) + i_row += inc + + + @ti.kernel + def _event_mv_prob_uniform_outdim_parallel_bool_gpu( + events: ti.types.ndarray(ndim=1), + w_min: ti.types.ndarray(ndim=1), + w_max: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) + ): + num_row = out.shape[0] + num_col = events.shape[0] + w_min0 = w_min[0] + w_max0 = w_max[0] + clen0 = clen[0] + seed0 = seed[0] + step = ti.u32(ti.max((num_row + 1) >> 5, 1)) + + for i in range(num_row * 32): + i_row = i >> 5 index = i & 31 - i_row = step * index - 1 - end = ti.min(i_row + step, num_row) + i_col = step * index - 1 + end_col = ti.min(i_col + step, num_col) + r = 0. key = lfsr88_key(seed0 + i) key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc - while i_row < end: + i_col += inc + while i_col < end_col: key, row_v = lfsr88_uniform(key, w_min0, w_max0) - out[i_row] += row_v + r += row_v * events[i_col] # TODO: speed comparison without if else key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc - - -@ti.kernel -def _event_mv_prob_uniform_outdim_parallel_bool_gpu( - events: ti.types.ndarray(ndim=1), - w_min: ti.types.ndarray(ndim=1), - w_max: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = events.shape[0] - w_min0 = w_min[0] - w_max0 = w_max[0] - clen0 = clen[0] - seed0 = seed[0] - step = ti.u32(ti.max((num_row + 1) >> 5, 1)) - - for i in range(num_row * 32): - i_row = i >> 5 - index = i & 31 - i_col = step * index - 1 - end_col = ti.min(i_col + step, num_col) - r = 0. - key = lfsr88_key(seed0 + i) - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - while i_col < end_col: - key, row_v = lfsr88_uniform(key, w_min0, w_max0) - r += row_v * events[i_col] # TODO: speed comparison without if else - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - out[i_row] += r # TODO: warp-level reduction - - -@ti.kernel -def _event_mv_prob_uniform_cpu( - events: ti.types.ndarray(ndim=1), - w_min: ti.types.ndarray(ndim=1), - w_max: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = events.shape[0] - w_min0 = w_min[0] - w_max0 = w_max[0] - clen0 = clen[0] - seed0 = seed[0] - - for i_col in range(num_col): - if events[i_col] != 0.: - key = lfsr88_key(seed0 + i_col) - key, i_row = lfsr88_random_integers(key, 0, clen0 - 1) - while i_row < num_row: + i_col += inc + out[i_row] += r # TODO: warp-level reduction + + + @ti.kernel + def _event_mv_prob_uniform_cpu( + events: ti.types.ndarray(ndim=1), + w_min: ti.types.ndarray(ndim=1), + w_max: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) + ): + num_row = out.shape[0] + num_col = events.shape[0] + w_min0 = w_min[0] + w_max0 = w_max[0] + clen0 = clen[0] + seed0 = seed[0] + + for i_col in range(num_col): + if events[i_col] != 0.: + key = lfsr88_key(seed0 + i_col) + key, i_row = lfsr88_random_integers(key, 0, clen0 - 1) + while i_row < num_row: + key, row_v = lfsr88_uniform(key, w_min0, w_max0) + out[i_row] += row_v + key, inc = lfsr88_random_integers(key, 1, clen0) + i_row += inc + + + @ti.kernel + def _event_mv_prob_uniform_outdim_parallel_cpu( + events: ti.types.ndarray(ndim=1), + w_min: ti.types.ndarray(ndim=1), + w_max: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) + ): + num_row = out.shape[0] + num_col = events.shape[0] + w_min0 = w_min[0] + w_max0 = w_max[0] + clen0 = clen[0] + seed0 = seed[0] + + for i_row in range(num_row): + r = 0. + key = lfsr88_key(seed0 + i_row) + key, i_col = lfsr88_random_integers(key, 0, clen0 - 1) + while i_col < num_col: key, row_v = lfsr88_uniform(key, w_min0, w_max0) - out[i_row] += row_v + if events[i_col] != 0.: + r += row_v key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc - - -@ti.kernel -def _event_mv_prob_uniform_outdim_parallel_cpu( - events: ti.types.ndarray(ndim=1), - w_min: ti.types.ndarray(ndim=1), - w_max: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = events.shape[0] - w_min0 = w_min[0] - w_max0 = w_max[0] - clen0 = clen[0] - seed0 = seed[0] - - for i_row in range(num_row): - r = 0. - key = lfsr88_key(seed0 + i_row) - key, i_col = lfsr88_random_integers(key, 0, clen0 - 1) - while i_col < num_col: - key, row_v = lfsr88_uniform(key, w_min0, w_max0) + i_col += inc + out[i_row] = r # TODO: warp-level reduction + + + @ti.kernel + def _event_mv_prob_uniform_gpu( + events: ti.types.ndarray(ndim=1), + w_min: ti.types.ndarray(ndim=1), + w_max: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) + ): + num_row = out.shape[0] + num_col = events.shape[0] + w_min0 = w_min[0] + w_max0 = w_max[0] + clen0 = clen[0] + seed0 = seed[0] + step = ti.uint32(ti.max((num_row + 1) >> 5, 1)) + + for i in range(num_col * 32): + i_col = i >> 5 if events[i_col] != 0.: - r += row_v - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - out[i_row] = r # TODO: warp-level reduction - - -@ti.kernel -def _event_mv_prob_uniform_gpu( - events: ti.types.ndarray(ndim=1), - w_min: ti.types.ndarray(ndim=1), - w_max: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = events.shape[0] - w_min0 = w_min[0] - w_max0 = w_max[0] - clen0 = clen[0] - seed0 = seed[0] - step = ti.uint32(ti.max((num_row + 1) >> 5, 1)) - - for i in range(num_col * 32): - i_col = i >> 5 - if events[i_col] != 0.: + index = i & 31 + i_row = step * index - 1 + end = ti.min(i_row + step, num_row) + key = lfsr88_key(seed0 + i) + key, inc = lfsr88_random_integers(key, 1, clen0) + i_row += inc + while i_row < end: + key, row_v = lfsr88_uniform(key, w_min0, w_max0) + out[i_row] += row_v + key, inc = lfsr88_random_integers(key, 1, clen0) + i_row += inc + + + @ti.kernel + def _event_mv_prob_uniform_outdim_parallel_gpu( + events: ti.types.ndarray(ndim=1), + w_min: ti.types.ndarray(ndim=1), + w_max: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) + ): + num_row = out.shape[0] + num_col = events.shape[0] + w_min0 = w_min[0] + w_max0 = w_max[0] + clen0 = clen[0] + seed0 = seed[0] + step = ti.uint32(ti.max((num_row + 1) >> 5, 1)) + + for i in range(num_row * 32): + i_row = i >> 5 index = i & 31 - i_row = step * index - 1 - end = ti.min(i_row + step, num_row) + i_col = step * index - 1 + end_col = ti.min(i_col + step, num_col) + r = 0. key = lfsr88_key(seed0 + i) key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc - while i_row < end: + i_col += inc + while i_col < end_col: key, row_v = lfsr88_uniform(key, w_min0, w_max0) - out[i_row] += row_v + r += row_v * events[i_col] # TODO: speed comparison with if else key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc - - -@ti.kernel -def _event_mv_prob_uniform_outdim_parallel_gpu( - events: ti.types.ndarray(ndim=1), - w_min: ti.types.ndarray(ndim=1), - w_max: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = events.shape[0] - w_min0 = w_min[0] - w_max0 = w_max[0] - clen0 = clen[0] - seed0 = seed[0] - step = ti.uint32(ti.max((num_row + 1) >> 5, 1)) - - for i in range(num_row * 32): - i_row = i >> 5 - index = i & 31 - i_col = step * index - 1 - end_col = ti.min(i_col + step, num_col) - r = 0. - key = lfsr88_key(seed0 + i) - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - while i_col < end_col: - key, row_v = lfsr88_uniform(key, w_min0, w_max0) - r += row_v * events[i_col] # TODO: speed comparison with if else - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - out[i_row] += r # TODO: warp-level reduction - - -def _event_mv_prob_uniform_jvp_events( - evt_dot, events, w_low, w_high, clen, seed, *, outs, shape, transpose, outdim_parallel -): - shape = _reverse(shape) if transpose else shape - return raw_mv_prob_uniform(evt_dot, w_low, w_high, clen, seed, - shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) - - -def _event_mv_prob_uniform_jvp_w_low( - w_dot, events, w_low, w_high, clen, seed, *, outs, shape, transpose, outdim_parallel -): - shape = _reverse(shape) if transpose else shape - return raw_mv_prob_uniform(events, w_dot, w_high, clen, seed, - shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) + i_col += inc + out[i_row] += r # TODO: warp-level reduction + + + def _event_mv_prob_uniform_jvp_events( + evt_dot, events, w_low, w_high, clen, seed, *, outs, shape, transpose, outdim_parallel + ): + shape = _reverse(shape) if transpose else shape + return raw_mv_prob_uniform(evt_dot, w_low, w_high, clen, seed, + shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) + + + def _event_mv_prob_uniform_jvp_w_low( + w_dot, events, w_low, w_high, clen, seed, *, outs, shape, transpose, outdim_parallel + ): + shape = _reverse(shape) if transpose else shape + return raw_mv_prob_uniform(events, w_dot, w_high, clen, seed, + shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) + + + def _event_mv_prob_uniform_jvp_w_high( + w_dot, events, w_low, w_high, clen, seed, *, outs, shape, transpose, outdim_parallel + ): + shape = _reverse(shape) if transpose else shape + return raw_mv_prob_uniform(events, w_low, w_dot, clen, seed, + shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) + + + def raw_event_mv_prob_uniform( + events: jax.Array, + w_low: jax.Array, # vector with size 1 + w_high: jax.Array, # vector with size 1 + conn_len: jax.Array, # vector with size 1 + seed: jax.Array, # vector with size 1 + *, + shape: Tuple[int, int], + transpose: bool = False, + outdim_parallel: bool = True, + ) -> jax.Array: + mat_shape, out_shape = _event_checking(events, conn_len, seed, shape, outdim_parallel, transpose, w_low, w_high) + + if outdim_parallel: + if events.dtype == jnp.bool_: + prim = _event_mv_prob_uniform_outdim_parallel_bool_p + else: + prim = _event_mv_prob_uniform_outdim_parallel_p + else: + if events.dtype == jnp.bool_: + prim = _event_mv_prob_uniform_bool_p + else: + prim = _event_mv_prob_uniform_p + + return prim(events, + w_low, + w_high, + conn_len, + seed, + outs=[jax.ShapeDtypeStruct(shape=out_shape, dtype=w_low.dtype)], + shape=mat_shape, + transpose=transpose, + outdim_parallel=outdim_parallel) + + + def _define_event_mv_prob_uniform_prim(cpu_kernel, gpu_kernel): + prim = XLACustomOp(cpu_kernel=cpu_kernel, gpu_kernel=gpu_kernel) + prim.defjvp(_event_mv_prob_uniform_jvp_events, + _event_mv_prob_uniform_jvp_w_low, + _event_mv_prob_uniform_jvp_w_high, + None, + None) + prim.def_transpose_rule(_mv_prob_uniform_transpose) + return prim + + + # outdim_parallel = True, events.dtype = jnp.bool_ + _event_mv_prob_uniform_outdim_parallel_bool_p = _define_event_mv_prob_uniform_prim( + cpu_kernel=_event_mv_prob_uniform_outdim_parallel_bool_cpu, + gpu_kernel=_event_mv_prob_uniform_outdim_parallel_bool_gpu + ) + # outdim_parallel = False, events.dtype = jnp.bool_ + _event_mv_prob_uniform_bool_p = _define_event_mv_prob_uniform_prim( + cpu_kernel=_event_mv_prob_uniform_bool_cpu, + gpu_kernel=_event_mv_prob_uniform_bool_gpu + ) -def _event_mv_prob_uniform_jvp_w_high( - w_dot, events, w_low, w_high, clen, seed, *, outs, shape, transpose, outdim_parallel -): - shape = _reverse(shape) if transpose else shape - return raw_mv_prob_uniform(events, w_low, w_dot, clen, seed, - shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) + # outdim_parallel = True, events.dtype != jnp.bool_ + _event_mv_prob_uniform_outdim_parallel_p = _define_event_mv_prob_uniform_prim( + cpu_kernel=_event_mv_prob_uniform_outdim_parallel_cpu, + gpu_kernel=_event_mv_prob_uniform_outdim_parallel_gpu + ) + # outdim_parallel = False, events.dtype != jnp.bool_ + _event_mv_prob_uniform_p = _define_event_mv_prob_uniform_prim( + cpu_kernel=_event_mv_prob_uniform_cpu, + gpu_kernel=_event_mv_prob_uniform_gpu + ) -def raw_event_mv_prob_uniform( - events: jax.Array, - w_low: jax.Array, # vector with size 1 - w_high: jax.Array, # vector with size 1 - conn_len: jax.Array, # vector with size 1 - seed: jax.Array, # vector with size 1 - *, - shape: Tuple[int, int], - transpose: bool = False, - outdim_parallel: bool = True, -) -> jax.Array: - mat_shape, out_shape = _event_checking(events, conn_len, seed, shape, outdim_parallel, transpose, w_low, w_high) - if outdim_parallel: - if events.dtype == jnp.bool_: - prim = _event_mv_prob_uniform_outdim_parallel_bool_p - else: - prim = _event_mv_prob_uniform_outdim_parallel_p - else: - if events.dtype == jnp.bool_: - prim = _event_mv_prob_uniform_bool_p - else: - prim = _event_mv_prob_uniform_p - - return prim(events, - w_low, - w_high, - conn_len, - seed, - outs=[jax.ShapeDtypeStruct(shape=out_shape, dtype=w_low.dtype)], - shape=mat_shape, - transpose=transpose, - outdim_parallel=outdim_parallel) - - -def _define_event_mv_prob_uniform_prim(cpu_kernel, gpu_kernel): - prim = XLACustomOp(cpu_kernel=cpu_kernel, gpu_kernel=gpu_kernel) - prim.defjvp(_event_mv_prob_uniform_jvp_events, - _event_mv_prob_uniform_jvp_w_low, - _event_mv_prob_uniform_jvp_w_high, - None, - None) - prim.def_transpose_rule(_mv_prob_uniform_transpose) - return prim - - -# outdim_parallel = True, events.dtype = jnp.bool_ -_event_mv_prob_uniform_outdim_parallel_bool_p = _define_event_mv_prob_uniform_prim( - cpu_kernel=_event_mv_prob_uniform_outdim_parallel_bool_cpu, - gpu_kernel=_event_mv_prob_uniform_outdim_parallel_bool_gpu -) - -# outdim_parallel = False, events.dtype = jnp.bool_ -_event_mv_prob_uniform_bool_p = _define_event_mv_prob_uniform_prim( - cpu_kernel=_event_mv_prob_uniform_bool_cpu, - gpu_kernel=_event_mv_prob_uniform_bool_gpu -) - -# outdim_parallel = True, events.dtype != jnp.bool_ -_event_mv_prob_uniform_outdim_parallel_p = _define_event_mv_prob_uniform_prim( - cpu_kernel=_event_mv_prob_uniform_outdim_parallel_cpu, - gpu_kernel=_event_mv_prob_uniform_outdim_parallel_gpu -) - -# outdim_parallel = False, events.dtype != jnp.bool_ -_event_mv_prob_uniform_p = _define_event_mv_prob_uniform_prim( - cpu_kernel=_event_mv_prob_uniform_cpu, - gpu_kernel=_event_mv_prob_uniform_gpu -) - - -@ti.kernel -def _event_mv_prob_normal_bool_cpu( - events: ti.types.ndarray(ndim=1), - w_mu: ti.types.ndarray(ndim=1), - w_sigma: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = events.shape[0] - w_mu0 = w_mu[0] - w_sigma0 = w_sigma[0] - clen0 = clen[0] - seed0 = seed[0] - - for i_col in range(num_col): - if events[i_col]: - key = lfsr88_key(seed0 + i_col) - key, i_row = lfsr88_random_integers(key, 0, clen0 - 1) - while i_row < num_row: + @ti.kernel + def _event_mv_prob_normal_bool_cpu( + events: ti.types.ndarray(ndim=1), + w_mu: ti.types.ndarray(ndim=1), + w_sigma: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) + ): + num_row = out.shape[0] + num_col = events.shape[0] + w_mu0 = w_mu[0] + w_sigma0 = w_sigma[0] + clen0 = clen[0] + seed0 = seed[0] + + for i_col in range(num_col): + if events[i_col]: + key = lfsr88_key(seed0 + i_col) + key, i_row = lfsr88_random_integers(key, 0, clen0 - 1) + while i_row < num_row: + key, row_v = lfsr88_normal(key, w_mu0, w_sigma0) + out[i_row] += row_v + key, inc = lfsr88_random_integers(key, 1, clen0) + i_row += inc + + + @ti.kernel + def _event_mv_prob_normal_outdim_parallel_bool_cpu( + events: ti.types.ndarray(ndim=1), + w_mu: ti.types.ndarray(ndim=1), + w_sigma: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) + ): + num_row = out.shape[0] + num_col = events.shape[0] + w_mu0 = w_mu[0] + w_sigma0 = w_sigma[0] + clen0 = clen[0] + seed0 = seed[0] + + for i_row in range(num_row): + r = 0. + key = lfsr88_key(seed0 + i_row) + key, i_col = lfsr88_random_integers(key, 0, clen0 - 1) + while i_col < num_col: key, row_v = lfsr88_normal(key, w_mu0, w_sigma0) - out[i_row] += row_v + if events[i_col]: + r += row_v key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc - - -@ti.kernel -def _event_mv_prob_normal_outdim_parallel_bool_cpu( - events: ti.types.ndarray(ndim=1), - w_mu: ti.types.ndarray(ndim=1), - w_sigma: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = events.shape[0] - w_mu0 = w_mu[0] - w_sigma0 = w_sigma[0] - clen0 = clen[0] - seed0 = seed[0] - - for i_row in range(num_row): - r = 0. - key = lfsr88_key(seed0 + i_row) - key, i_col = lfsr88_random_integers(key, 0, clen0 - 1) - while i_col < num_col: - key, row_v = lfsr88_normal(key, w_mu0, w_sigma0) + i_col += inc + out[i_row] = r + + + @ti.kernel + def _event_mv_prob_normal_bool_gpu( + events: ti.types.ndarray(ndim=1), + w_mu: ti.types.ndarray(ndim=1), + w_sigma: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) + ): + num_row = out.shape[0] + num_col = events.shape[0] + w_mu0 = w_mu[0] + w_sigma0 = w_sigma[0] + clen0 = clen[0] + seed0 = seed[0] + step = ti.uint32(ti.max((num_row + 1) >> 5, 1)) + + for i in range(num_col * 32): + i_col = i >> 5 if events[i_col]: - r += row_v - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - out[i_row] = r - - -@ti.kernel -def _event_mv_prob_normal_bool_gpu( - events: ti.types.ndarray(ndim=1), - w_mu: ti.types.ndarray(ndim=1), - w_sigma: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = events.shape[0] - w_mu0 = w_mu[0] - w_sigma0 = w_sigma[0] - clen0 = clen[0] - seed0 = seed[0] - step = ti.uint32(ti.max((num_row + 1) >> 5, 1)) - - for i in range(num_col * 32): - i_col = i >> 5 - if events[i_col]: + index = i & 31 + i_row = step * index - 1 + end = ti.min(i_row + step, num_row) + key = lfsr88_key(seed0 + i) + key, inc = lfsr88_random_integers(key, 1, clen0) + i_row += inc + while i_row < end: + key, row_v = lfsr88_normal(key, w_mu0, w_sigma0) + out[i_row] += row_v + key, inc = lfsr88_random_integers(key, 1, clen0) + i_row += inc + + + @ti.kernel + def _event_mv_prob_normal_outdim_parallel_bool_gpu( + events: ti.types.ndarray(ndim=1), + w_mu: ti.types.ndarray(ndim=1), + w_sigma: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) + ): + num_row = out.shape[0] + num_col = events.shape[0] + w_mu0 = w_mu[0] + w_sigma0 = w_sigma[0] + clen0 = clen[0] + seed0 = seed[0] + step = ti.u32(ti.max((num_row + 1) >> 5, 1)) + + for i in range(num_row * 32): + i_row = i >> 5 index = i & 31 - i_row = step * index - 1 - end = ti.min(i_row + step, num_row) + i_col = step * index - 1 + end_col = ti.min(i_col + step, num_col) + r = 0. key = lfsr88_key(seed0 + i) key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc - while i_row < end: + i_col += inc + while i_col < end_col: key, row_v = lfsr88_normal(key, w_mu0, w_sigma0) - out[i_row] += row_v + r += row_v * events[i_col] # TODO: speed comparison without if else key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc - - -@ti.kernel -def _event_mv_prob_normal_outdim_parallel_bool_gpu( - events: ti.types.ndarray(ndim=1), - w_mu: ti.types.ndarray(ndim=1), - w_sigma: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = events.shape[0] - w_mu0 = w_mu[0] - w_sigma0 = w_sigma[0] - clen0 = clen[0] - seed0 = seed[0] - step = ti.u32(ti.max((num_row + 1) >> 5, 1)) - - for i in range(num_row * 32): - i_row = i >> 5 - index = i & 31 - i_col = step * index - 1 - end_col = ti.min(i_col + step, num_col) - r = 0. - key = lfsr88_key(seed0 + i) - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - while i_col < end_col: - key, row_v = lfsr88_normal(key, w_mu0, w_sigma0) - r += row_v * events[i_col] # TODO: speed comparison without if else - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - out[i_row] += r # TODO: warp-level reduction - - -@ti.kernel -def _event_mv_prob_normal_cpu( - events: ti.types.ndarray(ndim=1), - w_mu: ti.types.ndarray(ndim=1), - w_sigma: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = events.shape[0] - w_mu0 = w_mu[0] - w_sigma0 = w_sigma[0] - clen0 = clen[0] - seed0 = seed[0] - - for i_col in range(num_col): - if events[i_col] != 0.: - key = lfsr88_key(seed0 + i_col) - key, i_row = lfsr88_random_integers(key, 0, clen0 - 1) - while i_row < num_row: + i_col += inc + out[i_row] += r # TODO: warp-level reduction + + + @ti.kernel + def _event_mv_prob_normal_cpu( + events: ti.types.ndarray(ndim=1), + w_mu: ti.types.ndarray(ndim=1), + w_sigma: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) + ): + num_row = out.shape[0] + num_col = events.shape[0] + w_mu0 = w_mu[0] + w_sigma0 = w_sigma[0] + clen0 = clen[0] + seed0 = seed[0] + + for i_col in range(num_col): + if events[i_col] != 0.: + key = lfsr88_key(seed0 + i_col) + key, i_row = lfsr88_random_integers(key, 0, clen0 - 1) + while i_row < num_row: + key, row_v = lfsr88_normal(key, w_mu0, w_sigma0) + out[i_row] += row_v + key, inc = lfsr88_random_integers(key, 1, clen0) + i_row += inc + + + @ti.kernel + def _event_mv_prob_normal_outdim_parallel_cpu( + events: ti.types.ndarray(ndim=1), + w_mu: ti.types.ndarray(ndim=1), + w_sigma: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) + ): + num_row = out.shape[0] + num_col = events.shape[0] + w_mu0 = w_mu[0] + w_sigma0 = w_sigma[0] + clen0 = clen[0] + seed0 = seed[0] + + for i_row in range(num_row): + r = 0. + key = lfsr88_key(seed0 + i_row) + key, i_col = lfsr88_random_integers(key, 0, clen0 - 1) + while i_col < num_col: key, row_v = lfsr88_normal(key, w_mu0, w_sigma0) - out[i_row] += row_v + if events[i_col] != 0.: + r += row_v key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc - - -@ti.kernel -def _event_mv_prob_normal_outdim_parallel_cpu( - events: ti.types.ndarray(ndim=1), - w_mu: ti.types.ndarray(ndim=1), - w_sigma: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = events.shape[0] - w_mu0 = w_mu[0] - w_sigma0 = w_sigma[0] - clen0 = clen[0] - seed0 = seed[0] - - for i_row in range(num_row): - r = 0. - key = lfsr88_key(seed0 + i_row) - key, i_col = lfsr88_random_integers(key, 0, clen0 - 1) - while i_col < num_col: - key, row_v = lfsr88_normal(key, w_mu0, w_sigma0) + i_col += inc + out[i_row] = r + + + @ti.kernel + def _event_mv_prob_normal_gpu( + events: ti.types.ndarray(ndim=1), + w_mu: ti.types.ndarray(ndim=1), + w_sigma: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) + ): + num_row = out.shape[0] + num_col = events.shape[0] + w_mu0 = w_mu[0] + w_sigma0 = w_sigma[0] + clen0 = clen[0] + seed0 = seed[0] + step = ti.uint32(ti.max((num_row + 1) >> 5, 1)) + + for i in range(num_col * 32): + i_col = i >> 5 if events[i_col] != 0.: - r += row_v - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - out[i_row] = r - - -@ti.kernel -def _event_mv_prob_normal_gpu( - events: ti.types.ndarray(ndim=1), - w_mu: ti.types.ndarray(ndim=1), - w_sigma: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = events.shape[0] - w_mu0 = w_mu[0] - w_sigma0 = w_sigma[0] - clen0 = clen[0] - seed0 = seed[0] - step = ti.uint32(ti.max((num_row + 1) >> 5, 1)) - - for i in range(num_col * 32): - i_col = i >> 5 - if events[i_col] != 0.: + index = i & 31 + i_row = step * index - 1 + end = ti.min(i_row + step, num_row) + key = lfsr88_key(seed0 + i) + key, inc = lfsr88_random_integers(key, 1, clen0) + i_row += inc + while i_row < end: + key, row_v = lfsr88_normal(key, w_mu0, w_sigma0) + out[i_row] += row_v + key, inc = lfsr88_random_integers(key, 1, clen0) + i_row += inc + + + @ti.kernel + def _event_mv_prob_normal_outdim_parallel_gpu( + events: ti.types.ndarray(ndim=1), + w_mu: ti.types.ndarray(ndim=1), + w_sigma: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) + ): + num_row = out.shape[0] + num_col = events.shape[0] + w_mu0 = w_mu[0] + w_sigma0 = w_sigma[0] + clen0 = clen[0] + seed0 = seed[0] + step = ti.uint32(ti.max((num_row + 1) >> 5, 1)) + + for i in range(num_row * 32): + i_row = i >> 5 index = i & 31 - i_row = step * index - 1 - end = ti.min(i_row + step, num_row) + i_col = step * index - 1 + end_col = ti.min(i_col + step, num_col) + r = 0. key = lfsr88_key(seed0 + i) key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc - while i_row < end: + i_col += inc + while i_col < end_col: key, row_v = lfsr88_normal(key, w_mu0, w_sigma0) - out[i_row] += row_v + r += row_v * events[i_col] # TODO: speed comparison with if else key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc - - -@ti.kernel -def _event_mv_prob_normal_outdim_parallel_gpu( - events: ti.types.ndarray(ndim=1), - w_mu: ti.types.ndarray(ndim=1), - w_sigma: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = events.shape[0] - w_mu0 = w_mu[0] - w_sigma0 = w_sigma[0] - clen0 = clen[0] - seed0 = seed[0] - step = ti.uint32(ti.max((num_row + 1) >> 5, 1)) - - for i in range(num_row * 32): - i_row = i >> 5 - index = i & 31 - i_col = step * index - 1 - end_col = ti.min(i_col + step, num_col) - r = 0. - key = lfsr88_key(seed0 + i) - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - while i_col < end_col: - key, row_v = lfsr88_normal(key, w_mu0, w_sigma0) - r += row_v * events[i_col] # TODO: speed comparison with if else - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - out[i_row] += r # TODO: warp-level reduction - - -def _event_mv_prob_normal_jvp_events( - evt_dot, events, w_mu, w_sigma, clen, seed, *, outs, shape, transpose, outdim_parallel -): - shape = _reverse(shape) if transpose else shape - return raw_mv_prob_normal(evt_dot, w_mu, w_sigma, clen, seed, - shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) - - -def _event_mv_prob_normal_jvp_w_mu( - w_dot, events, w_mu, w_sigma, clen, seed, *, outs, shape, transpose, outdim_parallel -): - shape = _reverse(shape) if transpose else shape - return raw_mv_prob_normal(events, w_dot, w_sigma, clen, seed, - shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) - - -def _event_mv_prob_normal_jvp_w_sigma( - w_dot, events, w_mu, w_sigma, clen, seed, *, outs, shape, transpose, outdim_parallel -): - shape = _reverse(shape) if transpose else shape - return raw_mv_prob_normal(events, w_mu, w_dot, clen, seed, - shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) + i_col += inc + out[i_row] += r # TODO: warp-level reduction + + + def _event_mv_prob_normal_jvp_events( + evt_dot, events, w_mu, w_sigma, clen, seed, *, outs, shape, transpose, outdim_parallel + ): + shape = _reverse(shape) if transpose else shape + return raw_mv_prob_normal(evt_dot, w_mu, w_sigma, clen, seed, + shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) + + + def _event_mv_prob_normal_jvp_w_mu( + w_dot, events, w_mu, w_sigma, clen, seed, *, outs, shape, transpose, outdim_parallel + ): + shape = _reverse(shape) if transpose else shape + return raw_mv_prob_normal(events, w_dot, w_sigma, clen, seed, + shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) + + + def _event_mv_prob_normal_jvp_w_sigma( + w_dot, events, w_mu, w_sigma, clen, seed, *, outs, shape, transpose, outdim_parallel + ): + shape = _reverse(shape) if transpose else shape + return raw_mv_prob_normal(events, w_mu, w_dot, clen, seed, + shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) + + + def raw_event_mv_prob_normal( + events: jax.Array, + w_mu: jax.Array, # vector with size 1 + w_sigma: jax.Array, # vector with size 1 + conn_len: jax.Array, # vector with size 1 + seed: jax.Array, # vector with size 1 + *, + shape: Tuple[int, int], + transpose: bool = False, + outdim_parallel: bool = True, + ) -> jax.Array: + mat_shape, out_shape = _event_checking(events, conn_len, seed, shape, outdim_parallel, transpose, w_mu, w_sigma) + + if outdim_parallel: + if events.dtype == jnp.bool_: + prim = _event_mv_prob_normal_outdim_parallel_bool_p + else: + prim = _event_mv_prob_normal_outdim_parallel_p + else: + if events.dtype == jnp.bool_: + prim = _event_mv_prob_normal_bool_p + else: + prim = _event_mv_prob_normal_p + + return prim(events, + w_mu, + w_sigma, + conn_len, + seed, + outs=[jax.ShapeDtypeStruct(shape=out_shape, dtype=w_mu.dtype)], + shape=mat_shape, + transpose=transpose, + outdim_parallel=outdim_parallel) + + + def _define_event_mv_prob_normal_prim(cpu_kernel, gpu_kernel): + prim = XLACustomOp(cpu_kernel=cpu_kernel, gpu_kernel=gpu_kernel) + prim.defjvp(_event_mv_prob_normal_jvp_events, + _event_mv_prob_normal_jvp_w_mu, + _event_mv_prob_normal_jvp_w_sigma, + None, + None) + prim.def_transpose_rule(_mv_prob_normal_transpose) + return prim + + + # outdim_parallel = True, events.dtype = jnp.bool_ + _event_mv_prob_normal_outdim_parallel_bool_p = _define_event_mv_prob_normal_prim( + cpu_kernel=_event_mv_prob_normal_outdim_parallel_bool_cpu, + gpu_kernel=_event_mv_prob_normal_outdim_parallel_bool_gpu + ) + # outdim_parallel = False, events.dtype = jnp.bool_ + _event_mv_prob_normal_bool_p = _define_event_mv_prob_normal_prim( + cpu_kernel=_event_mv_prob_normal_bool_cpu, + gpu_kernel=_event_mv_prob_normal_bool_gpu + ) -def raw_event_mv_prob_normal( - events: jax.Array, - w_mu: jax.Array, # vector with size 1 - w_sigma: jax.Array, # vector with size 1 - conn_len: jax.Array, # vector with size 1 - seed: jax.Array, # vector with size 1 - *, - shape: Tuple[int, int], - transpose: bool = False, - outdim_parallel: bool = True, -) -> jax.Array: - mat_shape, out_shape = _event_checking(events, conn_len, seed, shape, outdim_parallel, transpose, w_mu, w_sigma) + # outdim_parallel = True, events.dtype != jnp.bool_ + _event_mv_prob_normal_outdim_parallel_p = _define_event_mv_prob_normal_prim( + cpu_kernel=_event_mv_prob_normal_outdim_parallel_cpu, + gpu_kernel=_event_mv_prob_normal_outdim_parallel_gpu + ) - if outdim_parallel: - if events.dtype == jnp.bool_: - prim = _event_mv_prob_normal_outdim_parallel_bool_p - else: - prim = _event_mv_prob_normal_outdim_parallel_p - else: - if events.dtype == jnp.bool_: - prim = _event_mv_prob_normal_bool_p - else: - prim = _event_mv_prob_normal_p - - return prim(events, - w_mu, - w_sigma, - conn_len, - seed, - outs=[jax.ShapeDtypeStruct(shape=out_shape, dtype=w_mu.dtype)], - shape=mat_shape, - transpose=transpose, - outdim_parallel=outdim_parallel) - - -def _define_event_mv_prob_normal_prim(cpu_kernel, gpu_kernel): - prim = XLACustomOp(cpu_kernel=cpu_kernel, gpu_kernel=gpu_kernel) - prim.defjvp(_event_mv_prob_normal_jvp_events, - _event_mv_prob_normal_jvp_w_mu, - _event_mv_prob_normal_jvp_w_sigma, - None, - None) - prim.def_transpose_rule(_mv_prob_normal_transpose) - return prim - - -# outdim_parallel = True, events.dtype = jnp.bool_ -_event_mv_prob_normal_outdim_parallel_bool_p = _define_event_mv_prob_normal_prim( - cpu_kernel=_event_mv_prob_normal_outdim_parallel_bool_cpu, - gpu_kernel=_event_mv_prob_normal_outdim_parallel_bool_gpu -) - -# outdim_parallel = False, events.dtype = jnp.bool_ -_event_mv_prob_normal_bool_p = _define_event_mv_prob_normal_prim( - cpu_kernel=_event_mv_prob_normal_bool_cpu, - gpu_kernel=_event_mv_prob_normal_bool_gpu -) - -# outdim_parallel = True, events.dtype != jnp.bool_ -_event_mv_prob_normal_outdim_parallel_p = _define_event_mv_prob_normal_prim( - cpu_kernel=_event_mv_prob_normal_outdim_parallel_cpu, - gpu_kernel=_event_mv_prob_normal_outdim_parallel_gpu -) - -# outdim_parallel = False, events.dtype != jnp.bool_ -_event_mv_prob_normal_p = _define_event_mv_prob_normal_prim( - cpu_kernel=_event_mv_prob_normal_cpu, - gpu_kernel=_event_mv_prob_normal_gpu -) + # outdim_parallel = False, events.dtype != jnp.bool_ + _event_mv_prob_normal_p = _define_event_mv_prob_normal_prim( + cpu_kernel=_event_mv_prob_normal_cpu, + gpu_kernel=_event_mv_prob_normal_gpu + ) diff --git a/brainpy/_src/math/jitconn/_matvec.py b/brainpy/_src/math/jitconn/_matvec.py index 0caa9c996..84abb9805 100644 --- a/brainpy/_src/math/jitconn/_matvec.py +++ b/brainpy/_src/math/jitconn/_matvec.py @@ -1,22 +1,18 @@ # -*- coding: utf-8 -*- -from functools import partial from typing import Tuple, Optional, Union import jax import numpy as np -from jax import numpy as jnp, dtypes -from jax.core import ShapedArray, Primitive -from jax.interpreters import xla, ad -from jax.lib import xla_client +from jax import numpy as jnp +from jax.interpreters import ad -from brainpy._src.dependency_check import import_brainpylib_gpu_ops, import_brainpylib_cpu_ops, import_taichi +from brainpy._src.dependency_check import import_taichi from brainpy._src.math.interoperability import as_jax from brainpy._src.math.ndarray import Array, _get_dtype -from brainpy._src.math.op_register import register_general_batching, XLACustomOp -from brainpy._src.math.tifunc import (lfsr88_key, lfsr88_random_integers, lfsr88_uniform, lfsr88_normal) -from brainpy.errors import GPUOperatorNotFound +from brainpy._src.math.op_register import XLACustomOp +from brainpy.errors import PackageMissingError ti = import_taichi() @@ -215,808 +211,11 @@ def mv_prob_normal( out: Array, ndarray The output of :math:`y = M @ v`. """ - return mv_prob_uniform_taichi(vector, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, + return mv_prob_uniform_taichi(vector, w_mu, w_sigma, conn_prob, seed, + shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) -### BRAINYPLIB ### - -def mv_prob_homo_brainpylib( - vector: Union[Array, jax.Array], - weight: float, - conn_prob: float, - seed: Optional[int] = None, - *, - shape: Tuple[int, int], - transpose: bool = False, - outdim_parallel: bool = True, -) -> jax.Array: - r"""Perform the :math:`y=M@v` operation, - where :math:`M` is just-in-time randomly generated with a scalar `weight` at each position. - - This operator support ``jit()``, ``vmap()``, ``grad()`` and ``pmap()`` etc. transformations - on CPU and GPU devices. - - .. warning:: - - This API may change in the future. - - In this operation, :math:`M` is the random matrix with a connection probability - `conn_prob`, and at each connection the value is the same scalar `weight`. - - When ``transpose=True``, we perform an operation of :math:`y=M^T@v`. - - .. note:: - - Note that the just-in-time generated :math:`M` (`transpose=False`) is - different from the generated :math:`M^T` (`transpose=True`). - - If you pursue the same :math:`M` and :math:`M^T` when performing the just-in-time - matrix generation, you should set ``outdim_parallel=True``, with the sacrifice of - the speed compared with ``outdim_parallel=False``. - - Parameters - ---------- - vector: Array, ndarray - The vector. - weight: float - The value of the random matrix. - conn_prob: float - The connection probability. - shape: tuple of int - The matrix shape. - seed: int - The random number generation seed. - transpose: bool - Transpose the random matrix or not. - outdim_parallel: bool - Perform the parallel random generations along the out dimension or not. - It can be used to set the just-in-time generated :math:M^T: is the same - as the just-in-time generated :math:`M` when ``transpose=True``. - - Returns - ------- - out: Array, ndarray - The output of :math:`y = M @ v`. - """ - vector = as_jax(vector) - weight = jnp.atleast_1d(as_jax(weight)) - conn_prob = jnp.atleast_1d(as_jax(conn_prob)) - clen = jnp.asarray(jnp.ceil(1 / conn_prob) * 2 - 1, dtype=jnp.int32) - with jax.ensure_compile_time_eval(): - if seed is None: - seed = int(np.random.randint(0, int(1e8))) - seed = jnp.atleast_1d(as_jax(seed, dtype=jnp.int32)) - return mv_prob_homo_p.bind(vector, - weight, - clen, - seed, - shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel, - )[0] - - -def mv_prob_uniform_brainpylib( - vector: jax.Array, - w_low: float, - w_high: float, - conn_prob: float, - seed: Optional[int] = None, - *, - shape: Tuple[int, int], - transpose: bool = False, - outdim_parallel: bool = True, -) -> jax.Array: - r"""Perform the :math:`y=M@v` operation, - where :math:`M` is just-in-time randomly generated with a uniform distribution for its value. - - This operator support ``jit()``, ``vmap()``, ``grad()`` and ``pmap()`` etc. transformations - on CPU and GPU devices. - - .. warning:: - - This API may change in the future. - - In this operation, :math:`M` is the random matrix with a connection probability - `conn_prob`, and at each connection the value is the same scalar `weight`. - - When ``transpose=True``, we perform an operation of :math:`y=M^T@v`. - - .. note:: - - Note that the just-in-time generated :math:`M` (`transpose=False`) is - different from the generated :math:`M^T` (`transpose=True`). - - If you pursue the same :math:`M` and :math:`M^T` when performing the just-in-time - matrix generation, you should set ``outdim_parallel=True``, with the sacrifice of - the speed compared with ``outdim_parallel=False``. - - Parameters - ---------- - vector: Array, ndarray - The vector. - w_low: float - Lower boundary of the output interval. - w_high: float - Upper boundary of the output interval. - conn_prob: float - The connection probability. - shape: tuple of int - The matrix shape. - seed: int - The random number generation seed. - transpose: bool - Transpose the random matrix or not. - outdim_parallel: bool - Perform the parallel random generations along the out dimension or not. - It can be used to set the just-in-time generated :math:M^T: is the same - as the just-in-time generated :math:`M` when ``transpose=True``. - - Returns - ------- - out: Array, ndarray - The output of :math:`y = M @ v`. - """ - vector = as_jax(vector) - w_low = jnp.atleast_1d(as_jax(w_low)) - w_high = jnp.atleast_1d(as_jax(w_high)) - conn_prob = jnp.atleast_1d(as_jax(conn_prob)) - clen = jnp.asarray(jnp.ceil(1 / conn_prob) * 2 - 1, dtype=jnp.int32) - with jax.ensure_compile_time_eval(): - if seed is None: - seed = int(np.random.randint(0, int(1e8))) - seed = jnp.atleast_1d(as_jax(seed, dtype=jnp.int32)) - return mv_prob_uniform_p.bind(vector, - w_low, - w_high, - clen, - seed, - shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel)[0] - - -def mv_prob_normal_brainpylib( - vector: jax.Array, - w_mu: float, - w_sigma: float, - conn_prob: float, - seed: Optional[int] = None, - *, - shape: Tuple[int, int], - transpose: bool = False, - outdim_parallel: bool = True, -) -> jax.Array: - r"""Perform the :math:`y=M@v` operation, - where :math:`M` is just-in-time randomly generated with a normal distribution for its value. - - This operator support ``jit()``, ``vmap()``, ``grad()`` and ``pmap()`` etc. transformations - on CPU and GPU devices. - - .. warning:: - - This API may change in the future. - - In this operation, :math:`M` is the random matrix with a connection probability - `conn_prob`, and at each connection the value is the same scalar `weight`. - - When ``transpose=True``, we perform an operation of :math:`y=M^T@v`. - - .. note:: - - Note that the just-in-time generated :math:`M` (`transpose=False`) is - different from the generated :math:`M^T` (`transpose=True`). - - If you pursue the same :math:`M` and :math:`M^T` when performing the just-in-time - matrix generation, you should set ``outdim_parallel=True``, with the sacrifice of - the speed compared with ``outdim_parallel=False``. - - Parameters - ---------- - vector: Array, ndarray - The vector. - w_mu: float - Mean (centre) of the distribution. - w_sigma: float - Standard deviation (spread or “width”) of the distribution. Must be non-negative. - conn_prob: float - The connection probability. - shape: tuple of int - The matrix shape. - seed: int - The random number generation seed. - transpose: bool - Transpose the random matrix or not. - outdim_parallel: bool - Perform the parallel random generations along the out dimension or not. - It can be used to set the just-in-time generated :math:M^T: is the same - as the just-in-time generated :math:`M` when ``transpose=True``. - - Returns - ------- - out: Array, ndarray - The output of :math:`y = M @ v`. - """ - vector = as_jax(vector) - w_mu = jnp.atleast_1d(as_jax(w_mu)) - w_sigma = jnp.atleast_1d(as_jax(w_sigma)) - conn_prob = jnp.atleast_1d(as_jax(conn_prob)) - clen = jnp.asarray(jnp.ceil(1 / conn_prob) * 2 - 1, dtype=jnp.int32) - with jax.ensure_compile_time_eval(): - if seed is None: - seed = int(np.random.randint(0, int(1e8))) - seed = jnp.atleast_1d(as_jax(seed, dtype=jnp.int32)) - return mv_prob_normal_p.bind(vector, - w_mu, - w_sigma, - clen, - seed, - shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel)[0] - - -def _matvec_prob_homo_abstract( - vector, weight, clen, seed, *, shape, transpose, outdim_parallel -): - assert _get_dtype(vector) in [jnp.float32, jnp.float64] - assert _get_dtype(weight) in [jnp.float32, jnp.float64], '"weight" must be float valued.' - assert _get_dtype(clen) in [jnp.int32, jnp.int64, jnp.uint32, jnp.uint64] - assert _get_dtype(seed) in [jnp.int32, jnp.int64, jnp.uint32, jnp.uint64] - - if vector.ndim != 1: - raise ValueError('vector should be a 1D vector.') - if len(shape) != 2: - raise ValueError('shape should be a length-2 tuple.') - if seed.ndim != 1: - raise ValueError('seed must be a 1D scalar.') - if clen.ndim != 1: - raise ValueError('conn_prob must be a 1D scalar.') - if weight.ndim != 1: - raise ValueError('weight must be a 1D scalar.') - - if len(shape) != 2: - raise ValueError('shape should be a length-2 tuple.') - if not isinstance(outdim_parallel, bool): - raise ValueError('outdim_parallel must be boolean value.') - if not isinstance(transpose, bool): - raise ValueError('transpose must be boolean value.') - if transpose: - if vector.shape[0] != shape[0]: - raise ValueError(f'Shape mismatch, vec ({vector.shape[0]},) @ mat {shape}.') - else: - if vector.shape[0] != shape[1]: - raise ValueError(f'Shape mismatch, mat {shape} @ vec ({vector.shape[0]},).') - out = ShapedArray(dtype=dtypes.canonicalize_dtype(float), - shape=(shape[1] if transpose else shape[0],)) - return [out] - - -def _matvec_prob_homo_cpu_translation( - c, vector, weight, clen, seed, *, shape, transpose, outdim_parallel -): - import_brainpylib_cpu_ops() - n_row, n_col = (shape[1], shape[0]) if transpose else shape - - vec_shape = c.get_shape(vector) - out_dtype = vec_shape.element_type() - if out_dtype == jnp.float32: - out_type = b'_float' - elif out_dtype == jnp.float64: - out_type = b'_double' - else: - raise TypeError - - if outdim_parallel: - fn = b'cpu_matvec_prob_homo' + out_type - else: - fn = b'cpu_matvec_atomic_prob_homo' + out_type - return xla_client.ops.CustomCallWithLayout( - c, - fn, - operands=(vector, - weight, - clen, - seed, - xla_client.ops.ConstantLiteral(c, n_row), - xla_client.ops.ConstantLiteral(c, n_col)), - operand_shapes_with_layout=(c.get_shape(vector), - c.get_shape(weight), - c.get_shape(clen), - c.get_shape(seed), - xla_client.Shape.array_shape(np.dtype(np.uint32), (), ()), - xla_client.Shape.array_shape(np.dtype(np.uint32), (), ())), - shape_with_layout=xla_client.Shape.tuple_shape( - ( - xla_client.Shape.array_shape(out_dtype, (shape[1] if transpose else shape[0],), (0,)), - ) - ), - ) - - -def _matvec_prob_homo_gpu_translation( - c, vector, weight, clen, seed, *, shape, transpose, outdim_parallel -): - gpu_ops = import_brainpylib_gpu_ops() - if gpu_ops is None: - raise GPUOperatorNotFound(mv_prob_homo_p.name) - - vec_shape = c.get_shape(vector) - out_dtype = vec_shape.element_type() - if out_dtype == jnp.float32: - type_name = b'_float' - elif out_dtype == jnp.float64: - type_name = b'_double' - else: - raise TypeError - - opaque = gpu_ops.build_double_size_descriptor(shape[1] if transpose else shape[0], - shape[0] if transpose else shape[1]) - - if outdim_parallel: - fn = b'gpu_jit_csrmv_prob_homo_v2' + type_name - else: - fn = b'gpu_jit_csrmv_atomic_prob_homo_v2' + type_name - return xla_client.ops.CustomCallWithLayout( - c, - fn, - operands=(vector, weight, clen, seed), - operand_shapes_with_layout=(c.get_shape(vector), - c.get_shape(weight), - c.get_shape(clen), - c.get_shape(seed)), - shape_with_layout=xla_client.Shape.tuple_shape( - ( - xla_client.Shape.array_shape(out_dtype, (shape[1] if transpose else shape[0],), (0,)), - ) - ), - opaque=opaque, - ) - - -def _matvec_prob_homo_jvp( - primals, tangents, *, shape, transpose, outdim_parallel -): - vector, weight, clen, seed = primals - vector_dot, weight_dot, clen_dot, seed_dot = tangents - r = mv_prob_homo_p.bind(vector, - weight, - clen, - seed, - shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel) - assert type(clen_dot) is ad.Zero - assert type(seed_dot) is ad.Zero - if type(weight_dot) is ad.Zero: - if type(vector_dot) is ad.Zero: - raise ValueError - r_dot = mv_prob_homo_p.bind(vector_dot, - weight, - clen, - seed, - shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel) - elif type(vector_dot) is ad.Zero: - r_dot = mv_prob_homo_p.bind(vector, - weight_dot, - clen, - seed, - shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel) - else: - r_dot = mv_prob_homo_p.bind(vector_dot, - weight_dot, - clen, - seed, - shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel) - - return r, r_dot - - -def _matvec_prob_homo_transpose( - ct, vector, weight, clen, seed, *, shape, transpose, outdim_parallel -): - assert type(weight) is not ad.UndefinedPrimal - assert type(clen) is not ad.UndefinedPrimal - assert type(seed) is not ad.UndefinedPrimal - assert type(vector) is ad.UndefinedPrimal - r = mv_prob_homo_p.bind(ct[0], - weight, - clen, - seed, - shape=shape, - transpose=not transpose, - outdim_parallel=not outdim_parallel)[0] - return r, weight, clen, seed - - -mv_prob_homo_p = Primitive('matvec_prob_homo') -mv_prob_homo_p.multiple_results = True -mv_prob_homo_p.def_abstract_eval(_matvec_prob_homo_abstract) -mv_prob_homo_p.def_impl(partial(xla.apply_primitive, mv_prob_homo_p)) -# xla.backend_specific_translations['cpu'][mv_prob_homo_p] = _matvec_prob_homo_cpu_translation -# xla.backend_specific_translations['gpu'][mv_prob_homo_p] = _matvec_prob_homo_gpu_translation -register_general_batching(mv_prob_homo_p) -ad.primitive_jvps[mv_prob_homo_p] = _matvec_prob_homo_jvp -ad.primitive_transposes[mv_prob_homo_p] = _matvec_prob_homo_transpose - - -def _matvec_prob_uniform_abstract( - vector, w_low, w_high, clen, seed, *, shape, transpose, outdim_parallel -): - assert _get_dtype(vector) in [jnp.float32, jnp.float64] - _w_low_dtype = _get_dtype(w_low) - _w_high_dtype = _get_dtype(w_low) - assert _w_low_dtype == _w_high_dtype, '"w_low" and "w_high" must be same typed.' - assert _w_low_dtype in [jnp.float32, jnp.float64], '"w_low" must be float valued.' - assert _w_high_dtype in [jnp.float32, jnp.float64], '"w_high" must be float valued.' - assert _get_dtype(clen) in [jnp.int32, jnp.int64, jnp.uint32, jnp.uint64] - assert _get_dtype(seed) in [jnp.int32, jnp.int64, jnp.uint32, jnp.uint64] - - if vector.ndim != 1: - raise ValueError('vector should be a 1D vector.') - if len(shape) != 2: - raise ValueError('shape should be a length-2 tuple.') - if w_low.ndim != 1: - raise ValueError('w_low must be a 1D scalar.') - if w_high.ndim != 1: - raise ValueError('w_high must be a 1D scalar.') - if clen.ndim != 1: - raise ValueError('clen must be a 1D scalar.') - if seed.ndim != 1: - raise ValueError('seed must be a 1D scalar.') - - if not isinstance(transpose, bool): - raise ValueError('transpose must be a boolean value.') - if not isinstance(outdim_parallel, bool): - raise ValueError('outdim_parallel must be a boolean value.') - assert w_low.dtype == w_high.dtype == vector.dtype - - out = ShapedArray(dtype=dtypes.canonicalize_dtype(float), - shape=(shape[1] if transpose else shape[0],)) - return [out] - - -def _matvec_prob_uniform_cpu_translation( - c, vector, w_low, w_high, clen, seed, *, shape, transpose, outdim_parallel -): - import_brainpylib_cpu_ops() - n_row, n_col = (shape[1], shape[0]) if transpose else shape - - vec_shape = c.get_shape(vector) - out_dtype = vec_shape.element_type() - - if out_dtype == jnp.float32: - type_name = b'_float' - elif out_dtype == jnp.float64: - type_name = b'_double' - else: - raise TypeError - - if outdim_parallel: - fn = b'cpu_matvec_prob_uniform' + type_name - else: - fn = b'cpu_matvec_atomic_prob_uniform' + type_name - - return xla_client.ops.CustomCallWithLayout( - c, - fn, - operands=(vector, - w_low, - w_high, - clen, - seed, - xla_client.ops.ConstantLiteral(c, n_row), - xla_client.ops.ConstantLiteral(c, n_col)), - operand_shapes_with_layout=(c.get_shape(vector), - c.get_shape(w_low), - c.get_shape(w_high), - c.get_shape(clen), - c.get_shape(seed), - xla_client.Shape.array_shape(np.dtype(np.uint32), (), ()), - xla_client.Shape.array_shape(np.dtype(np.uint32), (), ())), - shape_with_layout=xla_client.Shape.tuple_shape( - ( - xla_client.Shape.array_shape(out_dtype, (shape[1] if transpose else shape[0],), (0,)), - ) - ), - ) - - -def _matvec_prob_uniform_gpu_translation( - c, vector, w_low, w_high, clen, seed, *, shape, transpose, outdim_parallel -): - gpu_ops = import_brainpylib_gpu_ops() - if gpu_ops is None: - raise GPUOperatorNotFound(mv_prob_homo_p.name) - - vec_shape = c.get_shape(vector) - out_dtype = vec_shape.element_type() - if out_dtype == jnp.float32: - type_name = b'_float' - elif out_dtype == jnp.float64: - type_name = b'_double' - else: - raise TypeError(f'Only support float or double, while got {out_dtype}') - - opaque = gpu_ops.build_double_size_descriptor(shape[1] if transpose else shape[0], - shape[0] if transpose else shape[1]) - - if outdim_parallel: - fn = b'gpu_jit_csrmv_prob_uniform_v2' + type_name - else: - fn = b'gpu_jit_csrmv_atomic_prob_uniform_v2' + type_name - - return xla_client.ops.CustomCallWithLayout( - c, - fn, - operands=(vector, w_low, w_high, clen, seed), - operand_shapes_with_layout=(c.get_shape(vector), - c.get_shape(w_low), - c.get_shape(w_high), - c.get_shape(clen), - c.get_shape(seed),), - shape_with_layout=xla_client.Shape.tuple_shape( - ( - xla_client.Shape.array_shape(out_dtype, (shape[1] if transpose else shape[0],), (0,)), - ) - ), - opaque=opaque, - ) - - -def _matvec_prob_uniform_jvp( - primals, tangents, *, shape, transpose, outdim_parallel -): - vector, w_low, w_high, clen, seed = primals - vector_dot, w_low_dot, w_high_dot, clen_dot, seed_dot = tangents - r = mv_prob_uniform_p.bind(vector, - w_low, - w_high, - clen, - seed, - shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel) - assert type(w_low_dot) is ad.Zero - assert type(w_high_dot) is ad.Zero - assert type(clen_dot) is ad.Zero - assert type(seed_dot) is ad.Zero - r_dot = mv_prob_uniform_p.bind(vector_dot, - w_low, - w_high, - clen, - seed, - shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel) - return r, r_dot - - -def _matvec_prob_uniform_transpose( - ct, vector, w_low, w_high, clen, seed, *, shape, transpose, outdim_parallel -): - assert type(vector) is ad.UndefinedPrimal - assert type(w_low) is not ad.UndefinedPrimal - assert type(w_high) is not ad.UndefinedPrimal - assert type(clen) is not ad.UndefinedPrimal - assert type(seed) is not ad.UndefinedPrimal - - r = mv_prob_uniform_p.bind(ct[0], - w_low, - w_high, - clen, - seed, - shape=shape, - transpose=not transpose, - outdim_parallel=not outdim_parallel)[0] - return r, w_low, w_high, clen, seed - - -mv_prob_uniform_p = Primitive('matvec_prob_uniform') -mv_prob_uniform_p.multiple_results = True -mv_prob_uniform_p.def_abstract_eval(_matvec_prob_uniform_abstract) -mv_prob_uniform_p.def_impl(partial(xla.apply_primitive, mv_prob_uniform_p)) -# xla.backend_specific_translations['cpu'][mv_prob_uniform_p] = _matvec_prob_uniform_cpu_translation -# xla.backend_specific_translations['gpu'][mv_prob_uniform_p] = _matvec_prob_uniform_gpu_translation -register_general_batching(mv_prob_uniform_p) -ad.primitive_jvps[mv_prob_uniform_p] = _matvec_prob_uniform_jvp -ad.primitive_transposes[mv_prob_uniform_p] = _matvec_prob_uniform_transpose - - -def _matvec_prob_normal_abstract( - vector, w_mu, w_sigma, clen, seed, *, shape, transpose, outdim_parallel -): - assert _get_dtype(vector) in [jnp.float32, jnp.float64] - _w_mu_dtype = _get_dtype(w_mu) - _w_sigma_dtype = _get_dtype(w_sigma) - assert _w_mu_dtype == _w_sigma_dtype, '"w_mu" and "w_sigma" must be same typed.' - assert _w_mu_dtype in [jnp.float32, jnp.float64], '"w_mu" must be float valued.' - assert _w_sigma_dtype in [jnp.float32, jnp.float64], '"w_sigma" must be float valued.' - assert _get_dtype(clen) in [jnp.int32, jnp.int64, jnp.uint32, jnp.uint64] - assert _get_dtype(seed) in [jnp.int32, jnp.int64, jnp.uint32, jnp.uint64] - - if w_mu.ndim != 1: - raise ValueError('w_mu should be a 1D scalar.') - if w_sigma.ndim != 1: - raise ValueError('w_sigma should be a 1D scalar.') - if clen.ndim != 1: - raise ValueError('clen should be a 1D scalar.') - if vector.ndim != 1: - raise ValueError('vector should be a 1D vector.') - if seed.ndim != 1: - raise ValueError('seed must be a 1D scalar.') - - if len(shape) != 2: - raise ValueError('shape should be a length-2 tuple.') - if not isinstance(transpose, bool): - raise ValueError('transpose must be a boolean value.') - if not isinstance(outdim_parallel, bool): - raise ValueError('outdim_parallel must be a boolean value.') - - out = ShapedArray(dtype=dtypes.canonicalize_dtype(float), - shape=(shape[1] if transpose else shape[0],)) - return [out] - - -def _matvec_prob_normal_cpu_translation( - c, vector, w_mu, w_sigma, clen, seed, *, shape, transpose, outdim_parallel -): - import_brainpylib_cpu_ops() - n_row, n_col = (shape[1], shape[0]) if transpose else shape - - vec_shape = c.get_shape(vector) - out_dtype = vec_shape.element_type() - - if out_dtype == jnp.float32: - type_name = b'_float' - elif out_dtype == jnp.float64: - type_name = b'_double' - else: - raise TypeError - - if outdim_parallel: - fn = b'cpu_matvec_prob_normal' + type_name - else: - fn = b'cpu_matvec_atomic_prob_normal' + type_name - - return xla_client.ops.CustomCallWithLayout( - c, - fn, - operands=(vector, - w_mu, - w_sigma, - clen, - seed, - xla_client.ops.ConstantLiteral(c, n_row), - xla_client.ops.ConstantLiteral(c, n_col)), - operand_shapes_with_layout=(c.get_shape(vector), - c.get_shape(w_mu), - c.get_shape(w_sigma), - c.get_shape(clen), - c.get_shape(seed), - xla_client.Shape.array_shape(np.dtype(np.uint32), (), ()), - xla_client.Shape.array_shape(np.dtype(np.uint32), (), ())), - shape_with_layout=xla_client.Shape.tuple_shape( - ( - xla_client.Shape.array_shape(out_dtype, (shape[1] if transpose else shape[0],), (0,)), - ) - ), - ) - - -def _matvec_prob_normal_gpu_translation( - c, vector, w_mu, w_sigma, clen, seed, *, shape, transpose, outdim_parallel -): - gpu_ops = import_brainpylib_gpu_ops() - if gpu_ops is None: - raise GPUOperatorNotFound(mv_prob_homo_p.name) - - event_shape = c.get_shape(vector) - out_dtype = event_shape.element_type() - - if out_dtype == jnp.float32: - type_name = b'_float' - elif out_dtype == jnp.float64: - type_name = b'_double' - else: - raise TypeError(f'Only support float or double, while got {out_dtype}') - opaque = gpu_ops.build_double_size_descriptor(shape[1] if transpose else shape[0], - shape[0] if transpose else shape[1]) - - if outdim_parallel: - fn = b'gpu_jit_csrmv_prob_normal_v2' + type_name - else: - fn = b'gpu_jit_csrmv_atomic_prob_normal_v2' + type_name - - return xla_client.ops.CustomCallWithLayout( - c, - fn, - operands=(vector, - w_mu, - w_sigma, - clen, - seed,), - operand_shapes_with_layout=(c.get_shape(vector), - c.get_shape(w_mu), - c.get_shape(w_sigma), - c.get_shape(clen), - c.get_shape(seed),), - shape_with_layout=xla_client.Shape.tuple_shape( - ( - xla_client.Shape.array_shape(out_dtype, (shape[1] if transpose else shape[0],), (0,)), - ) - ), - opaque=opaque, - ) - - -def _matvec_prob_normal_jvp( - primals, tangents, *, shape, transpose, outdim_parallel -): - vector, w_mu, w_sigma, clen, seed = primals - vector_dot, w_mu_dot, w_sigma_dot, clen_dot, seed_dot = tangents - r = mv_prob_normal_p.bind(vector, - w_mu, - w_sigma, - clen, - seed, - shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel) - assert type(w_mu_dot) is ad.Zero - assert type(w_sigma_dot) is ad.Zero - assert type(clen_dot) is ad.Zero - assert type(seed_dot) is ad.Zero - r_dot = mv_prob_normal_p.bind(vector_dot, - w_mu, - w_sigma, - clen, - seed, - shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel) - return r, r_dot - - -def _matvec_prob_normal_transpose( - ct, vector, w_mu, w_sigma, clen, seed, *, shape, transpose, outdim_parallel -): - assert type(vector) is ad.UndefinedPrimal - assert type(w_mu) is not ad.UndefinedPrimal - assert type(w_sigma) is not ad.UndefinedPrimal - assert type(clen) is not ad.UndefinedPrimal - assert type(seed) is not ad.UndefinedPrimal - - r = mv_prob_normal_p.bind(ct[0], - w_mu, - w_sigma, - clen, - seed, - shape=shape, - transpose=not transpose, - outdim_parallel=not outdim_parallel)[0] - return r, w_mu, w_sigma, clen, seed - - -mv_prob_normal_p = Primitive('matvec_prob_normal') -mv_prob_normal_p.multiple_results = True -mv_prob_normal_p.def_abstract_eval(_matvec_prob_normal_abstract) -mv_prob_normal_p.def_impl(partial(xla.apply_primitive, mv_prob_normal_p)) -# xla.backend_specific_translations['cpu'][mv_prob_normal_p] = _matvec_prob_normal_cpu_translation -# xla.backend_specific_translations['gpu'][mv_prob_normal_p] = _matvec_prob_normal_gpu_translation -register_general_batching(mv_prob_normal_p) -ad.primitive_jvps[mv_prob_normal_p] = _matvec_prob_normal_jvp -ad.primitive_transposes[mv_prob_normal_p] = _matvec_prob_normal_transpose - - -### TAICHI ### def mv_prob_homo_taichi( vector: Union[Array, jax.Array], weight: float, @@ -1081,6 +280,9 @@ def mv_prob_homo_taichi( out: Array, ndarray The output of :math:`y = M @ v`. """ + if ti is None: + raise PackageMissingError(name='taichi==1.7.0', purpose='customized operators') + vector = as_jax(vector) if isinstance(weight, float): weight = as_jax(weight, dtype=vector.dtype) @@ -1157,6 +359,9 @@ def mv_prob_uniform_taichi( out: Array, ndarray The output of :math:`y = M @ v`. """ + if ti is None: + raise PackageMissingError(name='taichi==1.7.0', purpose='customized operators') + vector = as_jax(vector) if isinstance(w_low, float): w_low = as_jax(w_low, dtype=vector.dtype) if isinstance(w_high, float): w_high = as_jax(w_high, dtype=vector.dtype) @@ -1233,6 +438,9 @@ def mv_prob_normal_taichi( out: Array, ndarray The output of :math:`y = M @ v`. """ + if ti is None: + raise PackageMissingError(name='taichi==1.7.0', purpose='customized operators') + vector = as_jax(vector) if isinstance(w_mu, float): w_mu = as_jax(w_mu, dtype=vector.dtype) if isinstance(w_sigma, float): w_sigma = as_jax(w_sigma, dtype=vector.dtype) @@ -1252,654 +460,657 @@ def _reverse(shape): return shape[::-1] -@ti.kernel -def _mv_prob_homo_cpu( - vector: ti.types.ndarray(ndim=1), - weight: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = vector.shape[0] - weight0 = weight[0] - clen0 = clen[0] - seed0 = seed[0] - - for i_col in range(num_col): - key = lfsr88_key(seed0 + i_col) - key, i_row = lfsr88_random_integers(key, 0, clen0 - 1) - v = vector[i_col] * weight0 - while i_row < num_row: - out[i_row] += v - key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc - - -@ti.kernel -def _mv_prob_homo_outdim_parallel_cpu( - vector: ti.types.ndarray(ndim=1), - weight: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = vector.shape[0] - weight0 = weight[0] - clen0 = clen[0] - seed0 = seed[0] - - for i_row in range(num_row): - r = 0. - key = lfsr88_key(seed0 + i_row) - key, i_col = lfsr88_random_integers(key, 0, clen0 - 1) - while i_col < num_col: - r += vector[i_col] - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - out[i_row] = r * weight0 - - -@ti.kernel -def _mv_prob_homo_gpu( - vector: ti.types.ndarray(ndim=1), - weight: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = vector.shape[0] - weight0 = weight[0] - clen0 = clen[0] - seed0 = seed[0] - step = ti.uint32(ti.max((num_row + 1) >> 5, 1)) - - for i in range(num_col * 32): - i_col = i >> 5 - index = i & 31 - col_v = vector[i_col] - i_row = step * index - 1 - end = ti.min(i_row + step, num_row) - key = lfsr88_key(seed0 + i) - key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc - while i_row < end: - out[i_row] += weight0 * col_v +if ti is not None: + from brainpy._src.math.tifunc import (lfsr88_key, lfsr88_random_integers, lfsr88_uniform, lfsr88_normal) + + @ti.kernel + def _mv_prob_homo_cpu( + vector: ti.types.ndarray(ndim=1), + weight: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) + ): + num_row = out.shape[0] + num_col = vector.shape[0] + weight0 = weight[0] + clen0 = clen[0] + seed0 = seed[0] + + for i_col in range(num_col): + key = lfsr88_key(seed0 + i_col) + key, i_row = lfsr88_random_integers(key, 0, clen0 - 1) + v = vector[i_col] * weight0 + while i_row < num_row: + out[i_row] += v + key, inc = lfsr88_random_integers(key, 1, clen0) + i_row += inc + + + @ti.kernel + def _mv_prob_homo_outdim_parallel_cpu( + vector: ti.types.ndarray(ndim=1), + weight: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) + ): + num_row = out.shape[0] + num_col = vector.shape[0] + weight0 = weight[0] + clen0 = clen[0] + seed0 = seed[0] + + for i_row in range(num_row): + r = 0. + key = lfsr88_key(seed0 + i_row) + key, i_col = lfsr88_random_integers(key, 0, clen0 - 1) + while i_col < num_col: + r += vector[i_col] + key, inc = lfsr88_random_integers(key, 1, clen0) + i_col += inc + out[i_row] = r * weight0 + + + @ti.kernel + def _mv_prob_homo_gpu( + vector: ti.types.ndarray(ndim=1), + weight: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) + ): + num_row = out.shape[0] + num_col = vector.shape[0] + weight0 = weight[0] + clen0 = clen[0] + seed0 = seed[0] + step = ti.uint32(ti.max((num_row + 1) >> 5, 1)) + + for i in range(num_col * 32): + i_col = i >> 5 + index = i & 31 + col_v = vector[i_col] + i_row = step * index - 1 + end = ti.min(i_row + step, num_row) + key = lfsr88_key(seed0 + i) key, inc = lfsr88_random_integers(key, 1, clen0) i_row += inc - - -@ti.kernel -def _mv_prob_homo_outdim_parallel_gpu( - vector: ti.types.ndarray(ndim=1), - weight: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = vector.shape[0] - weight0 = weight[0] - clen0 = clen[0] - seed0 = seed[0] - step = ti.u32(ti.max((num_row + 1) >> 5, 1)) - - for i in range(num_row * 32): - i_row = i >> 5 - i_thread = i & 31 - i_col = step * i_thread - 1 - end_col = ti.min(i_col + step, num_col) - r = 0. - key = lfsr88_key(seed0 + i) - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - while i_col < end_col: - r += vector[i_col] + while i_row < end: + out[i_row] += weight0 * col_v + key, inc = lfsr88_random_integers(key, 1, clen0) + i_row += inc + + + @ti.kernel + def _mv_prob_homo_outdim_parallel_gpu( + vector: ti.types.ndarray(ndim=1), + weight: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) + ): + num_row = out.shape[0] + num_col = vector.shape[0] + weight0 = weight[0] + clen0 = clen[0] + seed0 = seed[0] + step = ti.u32(ti.max((num_row + 1) >> 5, 1)) + + for i in range(num_row * 32): + i_row = i >> 5 + i_thread = i & 31 + i_col = step * i_thread - 1 + end_col = ti.min(i_col + step, num_col) + r = 0. + key = lfsr88_key(seed0 + i) key, inc = lfsr88_random_integers(key, 1, clen0) i_col += inc - out[i_row] += weight0 * r # TODO: warp-level reduction + while i_col < end_col: + r += vector[i_col] + key, inc = lfsr88_random_integers(key, 1, clen0) + i_col += inc + out[i_row] += weight0 * r # TODO: warp-level reduction -def _mv_prob_homo_jvp_vector(v_dot, vector, weight, clen, seed, *, outs, shape, transpose, outdim_parallel): - shape = _reverse(shape) if transpose else shape - return raw_mv_prob_homo(v_dot, weight, clen, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) + def _mv_prob_homo_jvp_vector(v_dot, vector, weight, clen, seed, *, outs, shape, transpose, outdim_parallel): + shape = _reverse(shape) if transpose else shape + return raw_mv_prob_homo(v_dot, weight, clen, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) -def _mv_prob_homo_jvp_weight(w_dot, vector, weight, clen, seed, *, outs, shape, transpose, outdim_parallel): - shape = _reverse(shape) if transpose else shape - return raw_mv_prob_homo(vector, w_dot, clen, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) + def _mv_prob_homo_jvp_weight(w_dot, vector, weight, clen, seed, *, outs, shape, transpose, outdim_parallel): + shape = _reverse(shape) if transpose else shape + return raw_mv_prob_homo(vector, w_dot, clen, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) -def _mv_prob_homo_transpose( - ct, vector, weight, clen, seed, *, outs, shape, transpose, outdim_parallel -): - shape = _reverse(shape) if transpose else shape - if ad.is_undefined_primal(vector): - if type(ct) is ad.Zero: - return ad.Zero(vector), weight, clen, seed + def _mv_prob_homo_transpose( + ct, vector, weight, clen, seed, *, outs, shape, transpose, outdim_parallel + ): + shape = _reverse(shape) if transpose else shape + if ad.is_undefined_primal(vector): + if type(ct) is ad.Zero: + return ad.Zero(vector), weight, clen, seed + else: + dv = raw_mv_prob_homo(ct[0], weight, clen, seed, shape=shape, + transpose=not transpose, outdim_parallel=not outdim_parallel)[0] + return dv, weight, clen, seed + elif ad.is_undefined_primal(weight): + if type(ct) is ad.Zero: + return vector, ad.Zero(weight), clen, seed + else: + row = raw_mv_prob_homo(ct[0], jnp.ones(1, dtype=ct[0].dtype), clen, seed, + shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)[0] + dw = jnp.sum(row * vector, keepdims=True) + return vector, dw, clen, seed else: - dv = raw_mv_prob_homo(ct[0], weight, clen, seed, shape=shape, - transpose=not transpose, outdim_parallel=not outdim_parallel)[0] - return dv, weight, clen, seed - elif ad.is_undefined_primal(weight): - if type(ct) is ad.Zero: - return vector, ad.Zero(weight), clen, seed + assert type(clen) is not ad.UndefinedPrimal, 'Cannot differentiate through clen.' + assert type(seed) is not ad.UndefinedPrimal, 'Cannot differentiate through seed.' + + + def _general_checking(vector, clen, seed, shape, outdim_parallel, transpose, *weights): + if vector.ndim != 1: + raise ValueError('vector should be a 1D vector.') + if len(shape) != 2: + raise ValueError('shape should be a length-2 tuple.') + if seed.ndim != 1: + raise ValueError('seed must be a 1D scalar.') + if clen.ndim != 1: + raise ValueError('conn_prob must be a 1D scalar.') + + assert _get_dtype(clen) in [jnp.int16, jnp.int32, jnp.int64, jnp.uint16, jnp.uint32, jnp.uint64] + assert _get_dtype(seed) in [jnp.int16, jnp.int32, jnp.int64, jnp.uint16, jnp.uint32, jnp.uint64] + + for weight in weights: + if weight.ndim != 1: + raise ValueError('weight must be a 1D scalar.') + assert _get_dtype(weight) in [jnp.float16, jnp.float32, jnp.float64], '"weight" must be float valued.' + + if not isinstance(outdim_parallel, bool): + raise ValueError('outdim_parallel must be boolean value.') + if not isinstance(transpose, bool): + raise ValueError('transpose must be boolean value.') + + if transpose: + out_shape = (shape[1],) + if vector.shape[0] != shape[0]: + raise ValueError(f'Shape mismatch, vec {vector.shape} @ mat {shape}.') + shape = _reverse(shape) else: - row = raw_mv_prob_homo(ct[0], jnp.ones(1, dtype=ct[0].dtype), clen, seed, - shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)[0] - dw = jnp.sum(row * vector, keepdims=True) - return vector, dw, clen, seed - else: - assert type(clen) is not ad.UndefinedPrimal, 'Cannot differentiate through clen.' - assert type(seed) is not ad.UndefinedPrimal, 'Cannot differentiate through seed.' - - -def _general_checking(vector, clen, seed, shape, outdim_parallel, transpose, *weights): - if vector.ndim != 1: - raise ValueError('vector should be a 1D vector.') - if len(shape) != 2: - raise ValueError('shape should be a length-2 tuple.') - if seed.ndim != 1: - raise ValueError('seed must be a 1D scalar.') - if clen.ndim != 1: - raise ValueError('conn_prob must be a 1D scalar.') - - assert _get_dtype(clen) in [jnp.int16, jnp.int32, jnp.int64, jnp.uint16, jnp.uint32, jnp.uint64] - assert _get_dtype(seed) in [jnp.int16, jnp.int32, jnp.int64, jnp.uint16, jnp.uint32, jnp.uint64] - - for weight in weights: - if weight.ndim != 1: - raise ValueError('weight must be a 1D scalar.') - assert _get_dtype(weight) in [jnp.float16, jnp.float32, jnp.float64], '"weight" must be float valued.' - - if not isinstance(outdim_parallel, bool): - raise ValueError('outdim_parallel must be boolean value.') - if not isinstance(transpose, bool): - raise ValueError('transpose must be boolean value.') - - if transpose: - out_shape = (shape[1],) - if vector.shape[0] != shape[0]: - raise ValueError(f'Shape mismatch, vec {vector.shape} @ mat {shape}.') - shape = _reverse(shape) - else: - if vector.shape[0] != shape[1]: - raise ValueError(f'Shape mismatch, mat {shape} @ vec ({vector.shape[0]},).') - out_shape = (shape[0],) - - return shape, out_shape - - -def _non_event_checking(vector, clen, seed, shape, outdim_parallel, transpose, *weights): - assert _get_dtype(vector) in [jnp.float16, jnp.float32, jnp.float64] - return _general_checking(vector, clen, seed, shape, outdim_parallel, transpose, *weights) - - -def raw_mv_prob_homo( - vector: jax.Array, - weight: jax.Array, # vector with size 1 - clen: jax.Array, # vector with size 1 - seed: jax.Array, # vector with size 1 - *, - shape: Tuple[int, int], - transpose: bool = False, - outdim_parallel: bool = True, -) -> jax.Array: - mat_shape, out_shape = _non_event_checking(vector, clen, seed, shape, outdim_parallel, transpose, weight) - - if outdim_parallel: - prim = _mv_prob_homo_outdim_parallel_p - else: - prim = _mv_prob_homo_p - - return prim(vector, - weight, - clen, - seed, - outs=[jax.ShapeDtypeStruct(shape=out_shape, dtype=vector.dtype)], - shape=mat_shape, - transpose=transpose, - outdim_parallel=outdim_parallel) - - -def _define_mv_prob_homo_prim(cpu_kernel, gpu_kernel): - prim = XLACustomOp(cpu_kernel=cpu_kernel, gpu_kernel=gpu_kernel) - prim.defjvp(_mv_prob_homo_jvp_vector, _mv_prob_homo_jvp_weight, None, None) - prim.def_transpose_rule(_mv_prob_homo_transpose) - return prim - - -# outdim_parallel = True -_mv_prob_homo_outdim_parallel_p = _define_mv_prob_homo_prim(cpu_kernel=_mv_prob_homo_outdim_parallel_cpu, - gpu_kernel=_mv_prob_homo_outdim_parallel_gpu) - -# outdim_parallel = False -_mv_prob_homo_p = _define_mv_prob_homo_prim(cpu_kernel=_mv_prob_homo_cpu, - gpu_kernel=_mv_prob_homo_gpu) - - -@ti.kernel -def _mv_prob_uniform_cpu( - vector: ti.types.ndarray(ndim=1), - w_min: ti.types.ndarray(ndim=1), - w_max: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = vector.shape[0] - w_min0 = w_min[0] - w_max0 = w_max[0] - clen0 = clen[0] - seed0 = seed[0] - - for i_col in range(num_col): - col_v = vector[i_col] - key = lfsr88_key(seed0 + i_col) - key, i_row = lfsr88_random_integers(key, 0, clen0 - 1) - while i_row < num_row: - key, raw_v = lfsr88_uniform(key, w_min0, w_max0) - out[i_row] += col_v * raw_v - key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc + if vector.shape[0] != shape[1]: + raise ValueError(f'Shape mismatch, mat {shape} @ vec ({vector.shape[0]},).') + out_shape = (shape[0],) + return shape, out_shape -@ti.kernel -def _mv_prob_uniform_outdim_parallel_cpu( - vector: ti.types.ndarray(ndim=1), - w_min: ti.types.ndarray(ndim=1), - w_max: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = vector.shape[0] - w_min0 = w_min[0] - w_max0 = w_max[0] - clen0 = clen[0] - seed0 = seed[0] - - for i_row in range(num_row): - r = 0. - key = lfsr88_key(seed0 + i_row) - key, i_col = lfsr88_random_integers(key, 0, clen0 - 1) - while i_col < num_col: - key, raw_v = lfsr88_uniform(key, w_min0, w_max0) - r += vector[i_col] * raw_v - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - out[i_row] = r - - -@ti.kernel -def _mv_prob_uniform_gpu( - vector: ti.types.ndarray(ndim=1), - w_min: ti.types.ndarray(ndim=1), - w_max: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = vector.shape[0] - w_min0 = w_min[0] - w_max0 = w_max[0] - clen0 = clen[0] - seed0 = seed[0] - step = ti.uint32(ti.max((num_row + 1) >> 5, 1)) - - for i in range(num_col * 32): - i_col = i >> 5 - index = i & 31 - col_v = vector[i_col] - i_row = step * index - 1 - end = ti.min(i_row + step, num_row) - key = lfsr88_key(seed0 + i) - key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc - while i_row < end: - key, row_v = lfsr88_uniform(key, w_min0, w_max0) - out[i_row] += row_v * col_v - key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc + def _non_event_checking(vector, clen, seed, shape, outdim_parallel, transpose, *weights): + assert _get_dtype(vector) in [jnp.float16, jnp.float32, jnp.float64] + return _general_checking(vector, clen, seed, shape, outdim_parallel, transpose, *weights) -@ti.kernel -def _mv_prob_uniform_outdim_parallel_gpu( - vector: ti.types.ndarray(ndim=1), - w_min: ti.types.ndarray(ndim=1), - w_max: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = vector.shape[0] - w_min0 = w_min[0] - w_max0 = w_max[0] - clen0 = clen[0] - seed0 = seed[0] - step = ti.u32(ti.max((num_row + 1) >> 5, 1)) - - for i in range(num_row * 32): - i_row = i >> 5 - i_thread = i & 31 - i_col = step * i_thread - 1 - end_col = ti.min(i_col + step, num_col) - r = 0. - key = lfsr88_key(seed0 + i) - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - while i_col < end_col: - key, row_v = lfsr88_uniform(key, w_min0, w_max0) - r += vector[i_col] * row_v - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - out[i_row] += r # TODO: warp-level reduction + def raw_mv_prob_homo( + vector: jax.Array, + weight: jax.Array, # vector with size 1 + clen: jax.Array, # vector with size 1 + seed: jax.Array, # vector with size 1 + *, + shape: Tuple[int, int], + transpose: bool = False, + outdim_parallel: bool = True, + ) -> jax.Array: + mat_shape, out_shape = _non_event_checking(vector, clen, seed, shape, outdim_parallel, transpose, weight) -def _mv_prob_uniform_jvp_vector(v_dot, vector, w_low, w_high, clen, seed, *, - outs, shape, transpose, outdim_parallel): - shape = _reverse(shape) if transpose else shape - return raw_mv_prob_uniform(v_dot, w_low, w_high, clen, seed, shape=shape, - transpose=transpose, outdim_parallel=outdim_parallel) - - -def _mv_prob_uniform_jvp_wlow(w_dot, vector, w_low, w_high, clen, seed, *, - outs, shape, transpose, outdim_parallel): - shape = _reverse(shape) if transpose else shape - return raw_mv_prob_uniform(vector, w_dot, w_high, clen, seed, shape=shape, - transpose=transpose, outdim_parallel=outdim_parallel) - - -def _mv_prob_uniform_jvp_whigh(w_dot, vector, w_low, w_high, clen, seed, *, - outs, shape, transpose, outdim_parallel): - shape = _reverse(shape) if transpose else shape - return raw_mv_prob_uniform(vector, w_low, w_dot, clen, seed, shape=shape, - transpose=transpose, outdim_parallel=outdim_parallel) - - -def _mv_prob_uniform_transpose( - ct, vector, w_low, w_high, clen, seed, *, outs, shape, transpose, outdim_parallel -): - shape = _reverse(shape) if transpose else shape - if ad.is_undefined_primal(vector): - if type(ct) is ad.Zero: - return ad.Zero(vector), w_low, w_high, clen, seed + if outdim_parallel: + prim = _mv_prob_homo_outdim_parallel_p else: - dv = raw_mv_prob_uniform(ct[0], w_low, w_high, clen, seed, shape=shape, - transpose=not transpose, outdim_parallel=not outdim_parallel)[0] - return dv, w_low, w_high, clen, seed - else: - assert type(w_low) is not ad.UndefinedPrimal, 'Cannot differentiate through w_low.' - assert type(w_high) is not ad.UndefinedPrimal, 'Cannot differentiate through w_high.' - assert type(clen) is not ad.UndefinedPrimal, 'Cannot differentiate through clen.' - assert type(seed) is not ad.UndefinedPrimal, 'Cannot differentiate through seed.' - - -def raw_mv_prob_uniform( - vector: jax.Array, - w_low: jax.Array, - w_high: jax.Array, - conn_len: jax.Array, - seed: jax.Array, - *, - shape: Tuple[int, int], - transpose: bool = False, - outdim_parallel: bool = True, -) -> jax.Array: - mat_shape, out_shape = _non_event_checking(vector, conn_len, seed, shape, outdim_parallel, transpose, w_low, w_high) - - if outdim_parallel: - prim = _mv_prob_uniform_outdim_parallel_p - else: - prim = _mv_prob_uniform_p - - return prim(vector, - w_low, - w_high, - conn_len, - seed, - outs=[jax.ShapeDtypeStruct(shape=out_shape, dtype=vector.dtype)], - shape=mat_shape, - transpose=transpose, - outdim_parallel=outdim_parallel) - - -def _define_mv_prob_uniform_prim(cpu_kernel, gpu_kernel): - prim = XLACustomOp(cpu_kernel=cpu_kernel, gpu_kernel=gpu_kernel) - prim.defjvp(_mv_prob_uniform_jvp_vector, - _mv_prob_uniform_jvp_wlow, - _mv_prob_uniform_jvp_whigh, - None, - None) - prim.def_transpose_rule(_mv_prob_uniform_transpose) - return prim - - -# outdim_parallel = True -_mv_prob_uniform_outdim_parallel_p = _define_mv_prob_uniform_prim( - cpu_kernel=_mv_prob_uniform_outdim_parallel_cpu, - gpu_kernel=_mv_prob_uniform_outdim_parallel_gpu -) - -# outdim_parallel = False -_mv_prob_uniform_p = _define_mv_prob_uniform_prim( - cpu_kernel=_mv_prob_uniform_cpu, - gpu_kernel=_mv_prob_uniform_gpu -) - - -@ti.kernel -def _mv_prob_normal_cpu( - vector: ti.types.ndarray(ndim=1), - w_mu: ti.types.ndarray(ndim=1), - w_sigma: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = vector.shape[0] - w_mu0 = w_mu[0] - w_sigma0 = w_sigma[0] - clen0 = clen[0] - seed0 = seed[0] - - for i_col in range(num_col): - col_v = vector[i_col] - key = lfsr88_key(seed0 + i_col) - key, i_row = lfsr88_random_integers(key, 0, clen0 - 1) - while i_row < num_row: - key, raw_v = lfsr88_normal(key, w_mu0, w_sigma0) - out[i_row] += col_v * raw_v - key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc - - -@ti.kernel -def _mv_prob_normal_outdim_parallel_cpu( - vector: ti.types.ndarray(ndim=1), - w_mu: ti.types.ndarray(ndim=1), - w_sigma: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = vector.shape[0] - w_mu0 = w_mu[0] - w_sigma0 = w_sigma[0] - clen0 = clen[0] - seed0 = seed[0] - - for i_row in range(num_row): - r = 0. - key = lfsr88_key(seed0 + i_row) - key, i_col = lfsr88_random_integers(key, 0, clen0 - 1) - while i_col < num_col: - key, raw_v = lfsr88_normal(key, w_mu0, w_sigma0) - r += vector[i_col] * raw_v - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - out[i_row] = r - - -@ti.kernel -def _mv_prob_normal_gpu( - vector: ti.types.ndarray(ndim=1), - w_mu: ti.types.ndarray(ndim=1), - w_sigma: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = vector.shape[0] - w_mu0 = w_mu[0] - w_sigma0 = w_sigma[0] - clen0 = clen[0] - seed0 = seed[0] - step = ti.uint32(ti.max((num_row + 1) >> 5, 1)) - - for i in range(num_col * 32): - i_col = i >> 5 - index = i & 31 - col_v = vector[i_col] - i_row = step * index - 1 - end = ti.min(i_row + step, num_row) - key = lfsr88_key(seed0 + i) - key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc - while i_row < end: - key, row_v = lfsr88_normal(key, w_mu0, w_sigma0) - out[i_row] += row_v * col_v + prim = _mv_prob_homo_p + + return prim(vector, + weight, + clen, + seed, + outs=[jax.ShapeDtypeStruct(shape=out_shape, dtype=vector.dtype)], + shape=mat_shape, + transpose=transpose, + outdim_parallel=outdim_parallel) + + + def _define_mv_prob_homo_prim(cpu_kernel, gpu_kernel): + prim = XLACustomOp(cpu_kernel=cpu_kernel, gpu_kernel=gpu_kernel) + prim.defjvp(_mv_prob_homo_jvp_vector, _mv_prob_homo_jvp_weight, None, None) + prim.def_transpose_rule(_mv_prob_homo_transpose) + return prim + + + # outdim_parallel = True + _mv_prob_homo_outdim_parallel_p = _define_mv_prob_homo_prim(cpu_kernel=_mv_prob_homo_outdim_parallel_cpu, + gpu_kernel=_mv_prob_homo_outdim_parallel_gpu) + + # outdim_parallel = False + _mv_prob_homo_p = _define_mv_prob_homo_prim(cpu_kernel=_mv_prob_homo_cpu, + gpu_kernel=_mv_prob_homo_gpu) + + + @ti.kernel + def _mv_prob_uniform_cpu( + vector: ti.types.ndarray(ndim=1), + w_min: ti.types.ndarray(ndim=1), + w_max: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) + ): + num_row = out.shape[0] + num_col = vector.shape[0] + w_min0 = w_min[0] + w_max0 = w_max[0] + clen0 = clen[0] + seed0 = seed[0] + + for i_col in range(num_col): + col_v = vector[i_col] + key = lfsr88_key(seed0 + i_col) + key, i_row = lfsr88_random_integers(key, 0, clen0 - 1) + while i_row < num_row: + key, raw_v = lfsr88_uniform(key, w_min0, w_max0) + out[i_row] += col_v * raw_v + key, inc = lfsr88_random_integers(key, 1, clen0) + i_row += inc + + + @ti.kernel + def _mv_prob_uniform_outdim_parallel_cpu( + vector: ti.types.ndarray(ndim=1), + w_min: ti.types.ndarray(ndim=1), + w_max: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) + ): + num_row = out.shape[0] + num_col = vector.shape[0] + w_min0 = w_min[0] + w_max0 = w_max[0] + clen0 = clen[0] + seed0 = seed[0] + + for i_row in range(num_row): + r = 0. + key = lfsr88_key(seed0 + i_row) + key, i_col = lfsr88_random_integers(key, 0, clen0 - 1) + while i_col < num_col: + key, raw_v = lfsr88_uniform(key, w_min0, w_max0) + r += vector[i_col] * raw_v + key, inc = lfsr88_random_integers(key, 1, clen0) + i_col += inc + out[i_row] = r + + + @ti.kernel + def _mv_prob_uniform_gpu( + vector: ti.types.ndarray(ndim=1), + w_min: ti.types.ndarray(ndim=1), + w_max: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) + ): + num_row = out.shape[0] + num_col = vector.shape[0] + w_min0 = w_min[0] + w_max0 = w_max[0] + clen0 = clen[0] + seed0 = seed[0] + step = ti.uint32(ti.max((num_row + 1) >> 5, 1)) + + for i in range(num_col * 32): + i_col = i >> 5 + index = i & 31 + col_v = vector[i_col] + i_row = step * index - 1 + end = ti.min(i_row + step, num_row) + key = lfsr88_key(seed0 + i) key, inc = lfsr88_random_integers(key, 1, clen0) i_row += inc - - -@ti.kernel -def _mv_prob_normal_outdim_parallel_gpu( - vector: ti.types.ndarray(ndim=1), - w_mu: ti.types.ndarray(ndim=1), - w_sigma: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = vector.shape[0] - w_mu0 = w_mu[0] - w_sigma0 = w_sigma[0] - clen0 = clen[0] - seed0 = seed[0] - step = ti.u32(ti.max((num_row + 1) >> 5, 1)) - - for i in range(num_row * 32): - i_row = i >> 5 - i_thread = i & 31 - i_col = step * i_thread - 1 - end_col = ti.min(i_col + step, num_col) - r = 0. - key = lfsr88_key(seed0 + i) - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - while i_col < end_col: - key, row_v = lfsr88_normal(key, w_mu0, w_sigma0) - r += vector[i_col] * row_v + while i_row < end: + key, row_v = lfsr88_uniform(key, w_min0, w_max0) + out[i_row] += row_v * col_v + key, inc = lfsr88_random_integers(key, 1, clen0) + i_row += inc + + + @ti.kernel + def _mv_prob_uniform_outdim_parallel_gpu( + vector: ti.types.ndarray(ndim=1), + w_min: ti.types.ndarray(ndim=1), + w_max: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) + ): + num_row = out.shape[0] + num_col = vector.shape[0] + w_min0 = w_min[0] + w_max0 = w_max[0] + clen0 = clen[0] + seed0 = seed[0] + step = ti.u32(ti.max((num_row + 1) >> 5, 1)) + + for i in range(num_row * 32): + i_row = i >> 5 + i_thread = i & 31 + i_col = step * i_thread - 1 + end_col = ti.min(i_col + step, num_col) + r = 0. + key = lfsr88_key(seed0 + i) key, inc = lfsr88_random_integers(key, 1, clen0) i_col += inc - out[i_row] += r # TODO: warp-level reduction + while i_col < end_col: + key, row_v = lfsr88_uniform(key, w_min0, w_max0) + r += vector[i_col] * row_v + key, inc = lfsr88_random_integers(key, 1, clen0) + i_col += inc + out[i_row] += r # TODO: warp-level reduction -def _mv_prob_normal_jvp_vector(v_dot, vector, w_mu, w_sigma, clen, seed, *, outs, shape, transpose, outdim_parallel): - shape = _reverse(shape) if transpose else shape - return raw_mv_prob_normal(v_dot, w_mu, w_sigma, clen, seed, shape=shape, - transpose=transpose, outdim_parallel=outdim_parallel) + def _mv_prob_uniform_jvp_vector(v_dot, vector, w_low, w_high, clen, seed, *, + outs, shape, transpose, outdim_parallel): + shape = _reverse(shape) if transpose else shape + return raw_mv_prob_uniform(v_dot, w_low, w_high, clen, seed, shape=shape, + transpose=transpose, outdim_parallel=outdim_parallel) -def _mv_prob_normal_jvp_w_mu(w_dot, vector, w_mu, w_sigma, clen, seed, *, outs, shape, transpose, outdim_parallel): - shape = _reverse(shape) if transpose else shape - return raw_mv_prob_normal(vector, w_dot, w_sigma, clen, seed, shape=shape, - transpose=transpose, outdim_parallel=outdim_parallel) - + def _mv_prob_uniform_jvp_wlow(w_dot, vector, w_low, w_high, clen, seed, *, + outs, shape, transpose, outdim_parallel): + shape = _reverse(shape) if transpose else shape + return raw_mv_prob_uniform(vector, w_dot, w_high, clen, seed, shape=shape, + transpose=transpose, outdim_parallel=outdim_parallel) + + + def _mv_prob_uniform_jvp_whigh(w_dot, vector, w_low, w_high, clen, seed, *, + outs, shape, transpose, outdim_parallel): + shape = _reverse(shape) if transpose else shape + return raw_mv_prob_uniform(vector, w_low, w_dot, clen, seed, shape=shape, + transpose=transpose, outdim_parallel=outdim_parallel) + + + def _mv_prob_uniform_transpose( + ct, vector, w_low, w_high, clen, seed, *, outs, shape, transpose, outdim_parallel + ): + shape = _reverse(shape) if transpose else shape + if ad.is_undefined_primal(vector): + if type(ct) is ad.Zero: + return ad.Zero(vector), w_low, w_high, clen, seed + else: + dv = raw_mv_prob_uniform(ct[0], w_low, w_high, clen, seed, shape=shape, + transpose=not transpose, outdim_parallel=not outdim_parallel)[0] + return dv, w_low, w_high, clen, seed + else: + assert type(w_low) is not ad.UndefinedPrimal, 'Cannot differentiate through w_low.' + assert type(w_high) is not ad.UndefinedPrimal, 'Cannot differentiate through w_high.' + assert type(clen) is not ad.UndefinedPrimal, 'Cannot differentiate through clen.' + assert type(seed) is not ad.UndefinedPrimal, 'Cannot differentiate through seed.' + + + def raw_mv_prob_uniform( + vector: jax.Array, + w_low: jax.Array, + w_high: jax.Array, + conn_len: jax.Array, + seed: jax.Array, + *, + shape: Tuple[int, int], + transpose: bool = False, + outdim_parallel: bool = True, + ) -> jax.Array: + mat_shape, out_shape = _non_event_checking(vector, conn_len, seed, shape, outdim_parallel, transpose, w_low, w_high) + + if outdim_parallel: + prim = _mv_prob_uniform_outdim_parallel_p + else: + prim = _mv_prob_uniform_p + + return prim(vector, + w_low, + w_high, + conn_len, + seed, + outs=[jax.ShapeDtypeStruct(shape=out_shape, dtype=vector.dtype)], + shape=mat_shape, + transpose=transpose, + outdim_parallel=outdim_parallel) + + + def _define_mv_prob_uniform_prim(cpu_kernel, gpu_kernel): + prim = XLACustomOp(cpu_kernel=cpu_kernel, gpu_kernel=gpu_kernel) + prim.defjvp(_mv_prob_uniform_jvp_vector, + _mv_prob_uniform_jvp_wlow, + _mv_prob_uniform_jvp_whigh, + None, + None) + prim.def_transpose_rule(_mv_prob_uniform_transpose) + return prim + + + # outdim_parallel = True + _mv_prob_uniform_outdim_parallel_p = _define_mv_prob_uniform_prim( + cpu_kernel=_mv_prob_uniform_outdim_parallel_cpu, + gpu_kernel=_mv_prob_uniform_outdim_parallel_gpu + ) -def _mv_prob_normal_jvp_w_sigma(w_dot, vector, w_mu, w_sigma, clen, seed, *, outs, shape, transpose, outdim_parallel): - shape = _reverse(shape) if transpose else shape - return raw_mv_prob_normal(vector, w_mu, w_dot, clen, seed, shape=shape, - transpose=transpose, outdim_parallel=outdim_parallel) + # outdim_parallel = False + _mv_prob_uniform_p = _define_mv_prob_uniform_prim( + cpu_kernel=_mv_prob_uniform_cpu, + gpu_kernel=_mv_prob_uniform_gpu + ) -def _mv_prob_normal_transpose( - ct, vector, w_mu, w_sigma, clen, seed, *, outs, shape, transpose, outdim_parallel -): - shape = _reverse(shape) if transpose else shape - if ad.is_undefined_primal(vector): - if type(ct) is ad.Zero: - return ad.Zero(vector), w_mu, w_sigma, clen, seed + @ti.kernel + def _mv_prob_normal_cpu( + vector: ti.types.ndarray(ndim=1), + w_mu: ti.types.ndarray(ndim=1), + w_sigma: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) + ): + num_row = out.shape[0] + num_col = vector.shape[0] + w_mu0 = w_mu[0] + w_sigma0 = w_sigma[0] + clen0 = clen[0] + seed0 = seed[0] + + for i_col in range(num_col): + col_v = vector[i_col] + key = lfsr88_key(seed0 + i_col) + key, i_row = lfsr88_random_integers(key, 0, clen0 - 1) + while i_row < num_row: + key, raw_v = lfsr88_normal(key, w_mu0, w_sigma0) + out[i_row] += col_v * raw_v + key, inc = lfsr88_random_integers(key, 1, clen0) + i_row += inc + + + @ti.kernel + def _mv_prob_normal_outdim_parallel_cpu( + vector: ti.types.ndarray(ndim=1), + w_mu: ti.types.ndarray(ndim=1), + w_sigma: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) + ): + num_row = out.shape[0] + num_col = vector.shape[0] + w_mu0 = w_mu[0] + w_sigma0 = w_sigma[0] + clen0 = clen[0] + seed0 = seed[0] + + for i_row in range(num_row): + r = 0. + key = lfsr88_key(seed0 + i_row) + key, i_col = lfsr88_random_integers(key, 0, clen0 - 1) + while i_col < num_col: + key, raw_v = lfsr88_normal(key, w_mu0, w_sigma0) + r += vector[i_col] * raw_v + key, inc = lfsr88_random_integers(key, 1, clen0) + i_col += inc + out[i_row] = r + + + @ti.kernel + def _mv_prob_normal_gpu( + vector: ti.types.ndarray(ndim=1), + w_mu: ti.types.ndarray(ndim=1), + w_sigma: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) + ): + num_row = out.shape[0] + num_col = vector.shape[0] + w_mu0 = w_mu[0] + w_sigma0 = w_sigma[0] + clen0 = clen[0] + seed0 = seed[0] + step = ti.uint32(ti.max((num_row + 1) >> 5, 1)) + + for i in range(num_col * 32): + i_col = i >> 5 + index = i & 31 + col_v = vector[i_col] + i_row = step * index - 1 + end = ti.min(i_row + step, num_row) + key = lfsr88_key(seed0 + i) + key, inc = lfsr88_random_integers(key, 1, clen0) + i_row += inc + while i_row < end: + key, row_v = lfsr88_normal(key, w_mu0, w_sigma0) + out[i_row] += row_v * col_v + key, inc = lfsr88_random_integers(key, 1, clen0) + i_row += inc + + + @ti.kernel + def _mv_prob_normal_outdim_parallel_gpu( + vector: ti.types.ndarray(ndim=1), + w_mu: ti.types.ndarray(ndim=1), + w_sigma: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) + ): + num_row = out.shape[0] + num_col = vector.shape[0] + w_mu0 = w_mu[0] + w_sigma0 = w_sigma[0] + clen0 = clen[0] + seed0 = seed[0] + step = ti.u32(ti.max((num_row + 1) >> 5, 1)) + + for i in range(num_row * 32): + i_row = i >> 5 + i_thread = i & 31 + i_col = step * i_thread - 1 + end_col = ti.min(i_col + step, num_col) + r = 0. + key = lfsr88_key(seed0 + i) + key, inc = lfsr88_random_integers(key, 1, clen0) + i_col += inc + while i_col < end_col: + key, row_v = lfsr88_normal(key, w_mu0, w_sigma0) + r += vector[i_col] * row_v + key, inc = lfsr88_random_integers(key, 1, clen0) + i_col += inc + out[i_row] += r # TODO: warp-level reduction + + + def _mv_prob_normal_jvp_vector(v_dot, vector, w_mu, w_sigma, clen, seed, *, outs, shape, transpose, outdim_parallel): + shape = _reverse(shape) if transpose else shape + return raw_mv_prob_normal(v_dot, w_mu, w_sigma, clen, seed, shape=shape, + transpose=transpose, outdim_parallel=outdim_parallel) + + + def _mv_prob_normal_jvp_w_mu(w_dot, vector, w_mu, w_sigma, clen, seed, *, outs, shape, transpose, outdim_parallel): + shape = _reverse(shape) if transpose else shape + return raw_mv_prob_normal(vector, w_dot, w_sigma, clen, seed, shape=shape, + transpose=transpose, outdim_parallel=outdim_parallel) + + + def _mv_prob_normal_jvp_w_sigma(w_dot, vector, w_mu, w_sigma, clen, seed, *, outs, shape, transpose, outdim_parallel): + shape = _reverse(shape) if transpose else shape + return raw_mv_prob_normal(vector, w_mu, w_dot, clen, seed, shape=shape, + transpose=transpose, outdim_parallel=outdim_parallel) + + + def _mv_prob_normal_transpose( + ct, vector, w_mu, w_sigma, clen, seed, *, outs, shape, transpose, outdim_parallel + ): + shape = _reverse(shape) if transpose else shape + if ad.is_undefined_primal(vector): + if type(ct) is ad.Zero: + return ad.Zero(vector), w_mu, w_sigma, clen, seed + else: + dv = raw_mv_prob_normal(ct[0], w_mu, w_sigma, clen, seed, shape=shape, + transpose=not transpose, outdim_parallel=not outdim_parallel)[0] + return dv, w_mu, w_sigma, clen, seed else: - dv = raw_mv_prob_normal(ct[0], w_mu, w_sigma, clen, seed, shape=shape, - transpose=not transpose, outdim_parallel=not outdim_parallel)[0] - return dv, w_mu, w_sigma, clen, seed - else: - assert type(w_mu) is not ad.UndefinedPrimal, 'Cannot differentiate through w_mu.' - assert type(w_sigma) is not ad.UndefinedPrimal, 'Cannot differentiate through w_sigma.' - assert type(clen) is not ad.UndefinedPrimal, 'Cannot differentiate through clen.' - assert type(seed) is not ad.UndefinedPrimal, 'Cannot differentiate through seed.' - + assert type(w_mu) is not ad.UndefinedPrimal, 'Cannot differentiate through w_mu.' + assert type(w_sigma) is not ad.UndefinedPrimal, 'Cannot differentiate through w_sigma.' + assert type(clen) is not ad.UndefinedPrimal, 'Cannot differentiate through clen.' + assert type(seed) is not ad.UndefinedPrimal, 'Cannot differentiate through seed.' + + + def raw_mv_prob_normal( + vector: jax.Array, + w_mu: jax.Array, + w_sigma: jax.Array, + conn_len: jax.Array, + seed: jax.Array, + *, + shape: Tuple[int, int], + transpose: bool = False, + outdim_parallel: bool = True, + ) -> jax.Array: + mat_shape, out_shape = _non_event_checking(vector, conn_len, seed, shape, outdim_parallel, transpose, w_mu, w_sigma) + + if outdim_parallel: + prim = _mv_prob_normal_outdim_parallel_p + else: + prim = _mv_prob_normal_p + + return prim(vector, + w_mu, + w_sigma, + conn_len, + seed, + outs=[jax.ShapeDtypeStruct(shape=out_shape, dtype=vector.dtype)], + shape=mat_shape, + transpose=transpose, + outdim_parallel=outdim_parallel) + + + def _define_mv_prob_normal_prim(cpu_kernel, gpu_kernel): + prim = XLACustomOp(cpu_kernel=cpu_kernel, gpu_kernel=gpu_kernel) + prim.defjvp(_mv_prob_normal_jvp_vector, + _mv_prob_normal_jvp_w_mu, + _mv_prob_normal_jvp_w_sigma, + None, + None) + prim.def_transpose_rule(_mv_prob_normal_transpose) + return prim + + + # outdim_parallel = True + _mv_prob_normal_outdim_parallel_p = _define_mv_prob_normal_prim( + cpu_kernel=_mv_prob_normal_outdim_parallel_cpu, + gpu_kernel=_mv_prob_normal_outdim_parallel_gpu + ) -def raw_mv_prob_normal( - vector: jax.Array, - w_mu: jax.Array, - w_sigma: jax.Array, - conn_len: jax.Array, - seed: jax.Array, - *, - shape: Tuple[int, int], - transpose: bool = False, - outdim_parallel: bool = True, -) -> jax.Array: - mat_shape, out_shape = _non_event_checking(vector, conn_len, seed, shape, outdim_parallel, transpose, w_mu, w_sigma) - - if outdim_parallel: - prim = _mv_prob_normal_outdim_parallel_p - else: - prim = _mv_prob_normal_p - - return prim(vector, - w_mu, - w_sigma, - conn_len, - seed, - outs=[jax.ShapeDtypeStruct(shape=out_shape, dtype=vector.dtype)], - shape=mat_shape, - transpose=transpose, - outdim_parallel=outdim_parallel) - - -def _define_mv_prob_normal_prim(cpu_kernel, gpu_kernel): - prim = XLACustomOp(cpu_kernel=cpu_kernel, gpu_kernel=gpu_kernel) - prim.defjvp(_mv_prob_normal_jvp_vector, - _mv_prob_normal_jvp_w_mu, - _mv_prob_normal_jvp_w_sigma, - None, - None) - prim.def_transpose_rule(_mv_prob_normal_transpose) - return prim - - -# outdim_parallel = True -_mv_prob_normal_outdim_parallel_p = _define_mv_prob_normal_prim( - cpu_kernel=_mv_prob_normal_outdim_parallel_cpu, - gpu_kernel=_mv_prob_normal_outdim_parallel_gpu -) - -# outdim_parallel = False -_mv_prob_normal_p = _define_mv_prob_normal_prim( - cpu_kernel=_mv_prob_normal_cpu, - gpu_kernel=_mv_prob_normal_gpu -) + # outdim_parallel = False + _mv_prob_normal_p = _define_mv_prob_normal_prim( + cpu_kernel=_mv_prob_normal_cpu, + gpu_kernel=_mv_prob_normal_gpu + ) diff --git a/brainpy/_src/math/jitconn/tests/test_event_matvec.py b/brainpy/_src/math/jitconn/tests/test_event_matvec.py index b10d55d21..034885ae9 100644 --- a/brainpy/_src/math/jitconn/tests/test_event_matvec.py +++ b/brainpy/_src/math/jitconn/tests/test_event_matvec.py @@ -4,8 +4,14 @@ import jax import jax.numpy as jnp from absl.testing import parameterized +import pytest import brainpy.math as bm +from brainpy._src.dependency_check import import_taichi + +if import_taichi() is None: + pytest.skip('no taichi', allow_module_level=True) + shapes = [(100, 200), (10, 1000), (2, 1000), (1000, 10), (1000, 2)] shapes = [(100, 200), (2, 1000), (1000, 2)] diff --git a/brainpy/_src/math/jitconn/tests/test_matvec.py b/brainpy/_src/math/jitconn/tests/test_matvec.py index 2e6e406cf..caee4efbe 100644 --- a/brainpy/_src/math/jitconn/tests/test_matvec.py +++ b/brainpy/_src/math/jitconn/tests/test_matvec.py @@ -4,8 +4,13 @@ import jax import jax.numpy as jnp from absl.testing import parameterized +import pytest import brainpy.math as bm +from brainpy._src.dependency_check import import_taichi + +if import_taichi() is None: + pytest.skip('no taichi', allow_module_level=True) shapes = [(100, 200), (10, 1000), (2, 1000), (1000, 10), (1000, 2)] shapes = [(100, 200), (2, 1000), (1000, 2)] diff --git a/brainpy/_src/math/op_register/numba_based.py b/brainpy/_src/math/op_register/numba_based.py index fb76aed24..fd7a289ed 100644 --- a/brainpy/_src/math/op_register/numba_based.py +++ b/brainpy/_src/math/op_register/numba_based.py @@ -10,7 +10,6 @@ from .utils import _shape_to_layout - __all__ = [ 'register_numba_xla_cpu_translation_rule', 'register_numba_mlir_cpu_translation_rule', diff --git a/brainpy/_src/math/sparse/__init__.py b/brainpy/_src/math/sparse/__init__.py index d45f2c80b..6c13ac19a 100644 --- a/brainpy/_src/math/sparse/__init__.py +++ b/brainpy/_src/math/sparse/__init__.py @@ -1,8 +1,8 @@ -from ._coo_mv import * +# from ._coo_mv import * +# from ._bsr_mv import * from ._csr_mv import * from ._utils import * -from ._bsr_mv import * from ._bsr_mm import * from ._jax_prim import * diff --git a/brainpy/_src/math/sparse/tests/test_csrmv.py b/brainpy/_src/math/sparse/tests/test_csrmv.py index 2c75f0901..418a52d35 100644 --- a/brainpy/_src/math/sparse/tests/test_csrmv.py +++ b/brainpy/_src/math/sparse/tests/test_csrmv.py @@ -5,10 +5,14 @@ import jax from absl.testing import parameterized +import pytest import brainpy as bp import brainpy.math as bm -# bm.set_platform('gpu') +from brainpy._src.dependency_check import import_taichi + +if import_taichi() is None: + pytest.skip('no taichi', allow_module_level=True) seed = 1234 diff --git a/brainpy/_src/math/sparse/tests/test_csrmv_old.py b/brainpy/_src/math/sparse/tests/test_csrmv_old.py index b73217496..23a3de93a 100644 --- a/brainpy/_src/math/sparse/tests/test_csrmv_old.py +++ b/brainpy/_src/math/sparse/tests/test_csrmv_old.py @@ -4,16 +4,12 @@ import jax import pytest -from absl.testing import parameterized -import platform + import brainpy as bp import brainpy.math as bm pytest.skip('Old implementation.', allow_module_level=True) -is_manual_test = False -# if platform.system() == 'Windows' and not is_manual_test: -# pytest.skip('brainpy.math package may need manual tests.', allow_module_level=True) cusparse_csr_matvec = partial(bm.sparse.csrmv, method='cusparse') scalar_csr_matvec = partial(bm.sparse.csrmv, method='scalar') diff --git a/brainpy/_src/math/tifunc.py b/brainpy/_src/math/tifunc.py index a9ee39f4a..928cb345a 100644 --- a/brainpy/_src/math/tifunc.py +++ b/brainpy/_src/math/tifunc.py @@ -3,362 +3,368 @@ ti = import_taichi() -__all__ = [ - # taichi function for other utilities - 'warp_reduce_sum', +if ti is not None: - # taichi functions for random number generator with LFSR88 algorithm - 'lfsr88_key', 'lfsr88_next_key', 'lfsr88_normal', 'lfsr88_randn', - 'lfsr88_random_integers', 'lfsr88_randint', 'lfsr88_uniform', 'lfsr88_rand', + __all__ = [ + # taichi function for other utilities + 'warp_reduce_sum', - # taichi functions for random number generator with LFSR113 algorithm - 'lfsr113_key', 'lfsr113_next_key', 'lfsr113_normal', 'lfsr113_randn', - 'lfsr113_random_integers', 'lfsr113_randint', 'lfsr113_uniform', 'lfsr113_rand', -] + # taichi functions for random number generator with LFSR88 algorithm + 'lfsr88_key', 'lfsr88_next_key', 'lfsr88_normal', 'lfsr88_randn', + 'lfsr88_random_integers', 'lfsr88_randint', 'lfsr88_uniform', 'lfsr88_rand', + # taichi functions for random number generator with LFSR113 algorithm + 'lfsr113_key', 'lfsr113_next_key', 'lfsr113_normal', 'lfsr113_randn', + 'lfsr113_random_integers', 'lfsr113_randint', 'lfsr113_uniform', 'lfsr113_rand', + ] -@ti.func -def _lcg_rand(state: ti.types.ndarray(ndim=1)): - # LCG constants - state[0] = ti.u32(1664525) * state[0] + ti.u32(1013904223) - return state[0] + @ti.func + def _lcg_rand(state: ti.types.ndarray(ndim=1)): + # LCG constants + state[0] = ti.u32(1664525) * state[0] + ti.u32(1013904223) + return state[0] -@ti.func -def taichi_lcg_rand(seed: ti.types.ndarray(ndim=1)): - """ - Generate a random number using the Taichi LCG algorithm. - Parameters: - seed (ti.types.ndarray): The seed value for the random number generator. + @ti.func + def taichi_lcg_rand(seed: ti.types.ndarray(ndim=1)): + """ + Generate a random number using the Taichi LCG algorithm. - Returns: - float: A random number between 0 and 1. - """ + Parameters: + seed (ti.types.ndarray): The seed value for the random number generator. - return float(_lcg_rand(seed)) / ti.u32(2 ** 32 - 1) + Returns: + float: A random number between 0 and 1. + """ + return float(_lcg_rand(seed)) / ti.u32(2 ** 32 - 1) -############################################# -# Random Number Generator: LFSR88 algorithm # -############################################# + ############################################# + # Random Number Generator: LFSR88 algorithm # + ############################################# -@ti.func -def lfsr88_key(seed: ti.u32) -> ti.types.vector(4, ti.u32): - """Initialize the random key of LFSR88 algorithm (Combined LFSR random number generator by L'Ecuyer). - This key is used in LFSR88 based random number generator functions, like ``lfsr88_rand()``. + @ti.func + def lfsr88_key(seed: ti.u32) -> ti.types.vector(4, ti.u32): + """Initialize the random key of LFSR88 algorithm (Combined LFSR random number generator by L'Ecuyer). - Source: - https://github.com/cmcqueen/simplerandom/blob/main/c/lecuyer/lfsr88.c + This key is used in LFSR88 based random number generator functions, like ``lfsr88_rand()``. - /**** VERY IMPORTANT **** : - The initial seeds s1, s2, s3 MUST be larger than - 1, 7, and 15 respectively. - */ + Source: + https://github.com/cmcqueen/simplerandom/blob/main/c/lecuyer/lfsr88.c - Args: - seed: int. The seed value for the random number generator. + /**** VERY IMPORTANT **** : + The initial seeds s1, s2, s3 MUST be larger than + 1, 7, and 15 respectively. + */ - Returns: - ti.math.uvec4: The random key for the LFSR88 random number generator. - """ - return ti.math.uvec4(ti.u32(seed + 1), ti.u32(seed + 7), ti.u32(seed + 15), ti.u32(0)) + Args: + seed: int. The seed value for the random number generator. + Returns: + ti.math.uvec4: The random key for the LFSR88 random number generator. + """ + return ti.math.uvec4(ti.u32(seed + 1), ti.u32(seed + 7), ti.u32(seed + 15), ti.u32(0)) -@ti.func -def lfsr88_next_key(key: ti.types.vector(4, ti.u32)) -> ti.types.vector(4, ti.u32): - """Next random key of LFSR88 algorithm (Combined LFSR random number generator by L'Ecuyer). - Args: - key: The state value for the random number generator. + @ti.func + def lfsr88_next_key(key: ti.types.vector(4, ti.u32)) -> ti.types.vector(4, ti.u32): + """Next random key of LFSR88 algorithm (Combined LFSR random number generator by L'Ecuyer). - Returns: - ti.math.uvec4: The next random key. - """ - b = ti.u32(((key[0] << 13) ^ key[0]) >> 19) - s1 = ((key[0] & ti.u32(4294967294)) << 12) ^ b - b = ((key[1] << 2) ^ key[1]) >> 25 - s2 = ((key[1] & ti.u32(4294967288)) << 4) ^ b - b = ((key[2] << 3) ^ key[2]) >> 11 - s3 = ((key[2] & ti.u32(4294967280)) << 17) ^ b - return ti.math.uvec4(s1, s2, s3, b) + Args: + key: The state value for the random number generator. + Returns: + ti.math.uvec4: The next random key. + """ + b = ti.u32(((key[0] << 13) ^ key[0]) >> 19) + s1 = ((key[0] & ti.u32(4294967294)) << 12) ^ b + b = ((key[1] << 2) ^ key[1]) >> 25 + s2 = ((key[1] & ti.u32(4294967288)) << 4) ^ b + b = ((key[2] << 3) ^ key[2]) >> 11 + s3 = ((key[2] & ti.u32(4294967280)) << 17) ^ b + return ti.math.uvec4(s1, s2, s3, b) -@ti.func -def lfsr88_normal(key: ti.types.vector(4, ti.u32), mu, sigma, epsilon=1e-10): - """ - Generate a random number of the normal distribution ``N(mu, sigma)`` using the LFSR88 algorithm. - Args: - key: The state value for the random number generator. - mu: The mean of the normal distribution. - sigma: The standard deviation of the normal distribution. - epsilon: The epsilon value to avoid log(0). - """ + @ti.func + def lfsr88_normal(key: ti.types.vector(4, ti.u32), mu, sigma, epsilon=1e-10): + """ + Generate a random number of the normal distribution ``N(mu, sigma)`` using the LFSR88 algorithm. - key, r = lfsr88_randn(key, epsilon) - return key, mu + sigma * r + Args: + key: The state value for the random number generator. + mu: The mean of the normal distribution. + sigma: The standard deviation of the normal distribution. + epsilon: The epsilon value to avoid log(0). + """ + key, r = lfsr88_randn(key, epsilon) + return key, mu + sigma * r -@ti.func -def lfsr88_randn(key: ti.types.vector(4, ti.u32), epsilon=1e-10): - """ - Generate a random number with the standard normal distribution using the LFSR88 algorithm. - Args: - key: The state value for the random number generator. - epsilon: The epsilon value to avoid log(0). + @ti.func + def lfsr88_randn(key: ti.types.vector(4, ti.u32), epsilon=1e-10): + """ + Generate a random number with the standard normal distribution using the LFSR88 algorithm. - References: - Box–Muller transform. https://en.wikipedia.org/wiki/Box%E2%80%93Muller_transform - Marsaglia polar method. https://en.wikipedia.org/wiki/Marsaglia_polar_method + Args: + key: The state value for the random number generator. + epsilon: The epsilon value to avoid log(0). - """ + References: + Box–Muller transform. https://en.wikipedia.org/wiki/Box%E2%80%93Muller_transform + Marsaglia polar method. https://en.wikipedia.org/wiki/Marsaglia_polar_method - key, u1 = lfsr88_rand(key) - key, u2 = lfsr88_rand(key) + """ - # Ensure state1 is not zero to avoid log(0) - u1 = ti.cast(ti.max(u1, epsilon), defaults.ti_float) + key, u1 = lfsr88_rand(key) + key, u2 = lfsr88_rand(key) - # Normalize the uniform samples - mag = ti.cast(ti.sqrt(-2.0 * ti.log(u1)), defaults.ti_float) + # Ensure state1 is not zero to avoid log(0) + u1 = ti.cast(ti.max(u1, epsilon), defaults.ti_float) - # Box-Muller transform - # z1 = mag * ti.cos(2 * ti.math.pi * u2) - z2 = ti.cast(mag * ti.sin(2 * ti.math.pi * u2), defaults.ti_float) + # Normalize the uniform samples + mag = ti.cast(ti.sqrt(-2.0 * ti.log(u1)), defaults.ti_float) - return key, z2 + # Box-Muller transform + # z1 = mag * ti.cos(2 * ti.math.pi * u2) + z2 = ti.cast(mag * ti.sin(2 * ti.math.pi * u2), defaults.ti_float) + return key, z2 -@ti.func -def lfsr88_random_integers(key: ti.types.vector(4, ti.u32), low, high): - """ - Generates a uniformly distributed random integer between `low` and `high` (inclusive) using the LFSR88 algorithm. - Parameters: - key: The state value used for random number generation. - low: The lower bound of the range. - high: The upper bound of the range. - """ - key = lfsr88_next_key(key) - return key, ti.cast((key[0] ^ key[1] ^ key[2]) % (high + 1 - low) + low, defaults.ti_int) + @ti.func + def lfsr88_random_integers(key: ti.types.vector(4, ti.u32), low, high): + """ + Generates a uniformly distributed random integer between `low` and `high` (inclusive) using the LFSR88 algorithm. + Parameters: + key: The state value used for random number generation. + low: The lower bound of the range. + high: The upper bound of the range. + """ + key = lfsr88_next_key(key) + return key, ti.cast((key[0] ^ key[1] ^ key[2]) % (high + 1 - low) + low, defaults.ti_int) -@ti.func -def lfsr88_randint(key: ti.types.vector(4, ti.u32), dtype=ti.u32): - key = lfsr88_next_key(key) - return key, dtype(key[0] ^ key[1] ^ key[2]) + @ti.func + def lfsr88_randint(key: ti.types.vector(4, ti.u32), dtype=ti.u32): + key = lfsr88_next_key(key) + return key, dtype(key[0] ^ key[1] ^ key[2]) -@ti.func -def lfsr88_uniform(key: ti.types.vector(4, ti.u32), low, high): - """ - Generates a uniformly distributed random float between `low` and `high` (inclusive) using the LFSR88 algorithm. - Args: - key: The state value used for random number generation. - low: The lower bound of the range. - high: The upper bound of the range. - """ - key = lfsr88_next_key(key) - r = (key[0] ^ key[1] ^ key[2]) * ti.cast(2.3283064365386963e-10, defaults.ti_float) - return key, ti.cast(r * (high - low) + low, defaults.ti_float) + @ti.func + def lfsr88_uniform(key: ti.types.vector(4, ti.u32), low, high): + """ + Generates a uniformly distributed random float between `low` and `high` (inclusive) using the LFSR88 algorithm. + Args: + key: The state value used for random number generation. + low: The lower bound of the range. + high: The upper bound of the range. + """ + key = lfsr88_next_key(key) + r = (key[0] ^ key[1] ^ key[2]) * ti.cast(2.3283064365386963e-10, defaults.ti_float) + return key, ti.cast(r * (high - low) + low, defaults.ti_float) -@ti.func -def lfsr88_rand(key: ti.types.vector(4, ti.u32)): - """ - Generates a uniformly distributed random float between 0 and 1 using the LFSR88 algorithm. - Args: - key: The state value used for random number generation. - """ - key = lfsr88_next_key(key) - return key, (key[0] ^ key[1] ^ key[2]) * ti.cast(2.3283064365386963e-10, defaults.ti_float) + @ti.func + def lfsr88_rand(key: ti.types.vector(4, ti.u32)): + """ + Generates a uniformly distributed random float between 0 and 1 using the LFSR88 algorithm. + Args: + key: The state value used for random number generation. + """ + key = lfsr88_next_key(key) + return key, (key[0] ^ key[1] ^ key[2]) * ti.cast(2.3283064365386963e-10, defaults.ti_float) -############################################## -# Random Number Generator: LFSR113 algorithm # -############################################## + ############################################## + # Random Number Generator: LFSR113 algorithm # + ############################################## -@ti.func -def lfsr113_key(seed: ti.u32) -> ti.types.vector(4, ti.u32): - """Initialize the random key of LFSR113 algorithm (Combined LFSR random number generator by L'Ecuyer). - This key is used in LFSR113 based random number generator functions, like ``lfsr113_rand()``. + @ti.func + def lfsr113_key(seed: ti.u32) -> ti.types.vector(4, ti.u32): + """Initialize the random key of LFSR113 algorithm (Combined LFSR random number generator by L'Ecuyer). - Source: - https://github.com/cmcqueen/simplerandom/blob/main/c/lecuyer/lfsr113.c + This key is used in LFSR113 based random number generator functions, like ``lfsr113_rand()``. - /**** VERY IMPORTANT **** : - The initial seeds s1, s2, s3, s4 MUST be larger than - 1, 7, 15, and 127 respectively. - */ + Source: + https://github.com/cmcqueen/simplerandom/blob/main/c/lecuyer/lfsr113.c - Args: - seed: int. The seed value for the random number generator. + /**** VERY IMPORTANT **** : + The initial seeds s1, s2, s3, s4 MUST be larger than + 1, 7, 15, and 127 respectively. + */ - Returns: - ti.math.uvec4: The random key for the LFSR113 random number generator. - """ - return ti.math.uvec4(ti.u32(seed + 1), ti.u32(seed + 7), ti.u32(seed + 15), ti.u32(seed + 127)) + Args: + seed: int. The seed value for the random number generator. + Returns: + ti.math.uvec4: The random key for the LFSR113 random number generator. + """ + return ti.math.uvec4(ti.u32(seed + 1), ti.u32(seed + 7), ti.u32(seed + 15), ti.u32(seed + 127)) -@ti.func -def lfsr113_next_key(key: ti.types.vector(4, ti.u32)) -> ti.types.vector(4, ti.u32): - """Next random key of LFSR113 algorithm (Combined LFSR random number generator by L'Ecuyer). - Args: - key: The state value for the random number generator. + @ti.func + def lfsr113_next_key(key: ti.types.vector(4, ti.u32)) -> ti.types.vector(4, ti.u32): + """Next random key of LFSR113 algorithm (Combined LFSR random number generator by L'Ecuyer). - Returns: - ti.math.uvec4: The next random key. - """ - z1 = key[0] - z2 = key[1] - z3 = key[2] - z4 = key[3] - b = ((z1 << 6) ^ z1) >> 13 - z1 = ti.u32(((z1 & ti.u64(4294967294)) << 18) ^ b) - b = ((z2 << 2) ^ z2) >> 27 - z2 = ti.u32(((z2 & ti.u64(4294967288)) << 2) ^ b) - b = ((z3 << 13) ^ z3) >> 21 - z3 = ti.u32(((z3 & ti.u64(4294967280)) << 7) ^ b) - b = ((z4 << 3) ^ z4) >> 12 - z4 = ti.u32(((z4 & ti.u64(4294967168)) << 13) ^ b) - return ti.math.uvec4(z1, z2, z3, z4) + Args: + key: The state value for the random number generator. + Returns: + ti.math.uvec4: The next random key. + """ + z1 = key[0] + z2 = key[1] + z3 = key[2] + z4 = key[3] + b = ((z1 << 6) ^ z1) >> 13 + z1 = ti.u32(((z1 & ti.u64(4294967294)) << 18) ^ b) + b = ((z2 << 2) ^ z2) >> 27 + z2 = ti.u32(((z2 & ti.u64(4294967288)) << 2) ^ b) + b = ((z3 << 13) ^ z3) >> 21 + z3 = ti.u32(((z3 & ti.u64(4294967280)) << 7) ^ b) + b = ((z4 << 3) ^ z4) >> 12 + z4 = ti.u32(((z4 & ti.u64(4294967168)) << 13) ^ b) + return ti.math.uvec4(z1, z2, z3, z4) -@ti.func -def lfsr113_normal(key: ti.types.vector(4, ti.u32), mu, sigma, epsilon=1e-10): - """ - Generate a random number of the normal distribution ``N(mu, sigma)`` using the LFSR113 algorithm. - Args: - key: The state value for the random number generator. - mu: The mean of the normal distribution. - sigma: The standard deviation of the normal distribution. - epsilon: The epsilon value to avoid log(0). - """ + @ti.func + def lfsr113_normal(key: ti.types.vector(4, ti.u32), mu, sigma, epsilon=1e-10): + """ + Generate a random number of the normal distribution ``N(mu, sigma)`` using the LFSR113 algorithm. - key, r = lfsr113_randn(key, epsilon) - return key, ti.cast(mu + sigma * r, defaults.ti_float) + Args: + key: The state value for the random number generator. + mu: The mean of the normal distribution. + sigma: The standard deviation of the normal distribution. + epsilon: The epsilon value to avoid log(0). + """ + key, r = lfsr113_randn(key, epsilon) + return key, ti.cast(mu + sigma * r, defaults.ti_float) -@ti.func -def lfsr113_randn(key: ti.types.vector(4, ti.u32), epsilon=1e-10): - """ - Generate a random number with standard normal distribution using the LFSR113 algorithm. - Args: - key: The state value for the random number generator. - epsilon: The epsilon value to avoid log(0). + @ti.func + def lfsr113_randn(key: ti.types.vector(4, ti.u32), epsilon=1e-10): + """ + Generate a random number with standard normal distribution using the LFSR113 algorithm. - References: - Box–Muller transform. https://en.wikipedia.org/wiki/Box%E2%80%93Muller_transform - Marsaglia polar method. https://en.wikipedia.org/wiki/Marsaglia_polar_method + Args: + key: The state value for the random number generator. + epsilon: The epsilon value to avoid log(0). - """ + References: + Box–Muller transform. https://en.wikipedia.org/wiki/Box%E2%80%93Muller_transform + Marsaglia polar method. https://en.wikipedia.org/wiki/Marsaglia_polar_method - key, u1 = lfsr113_rand(key) - key, u2 = lfsr113_rand(key) + """ - # Ensure state1 is not zero to avoid log(0) - u1 = ti.cast(ti.max(u1, epsilon), defaults.ti_float) + key, u1 = lfsr113_rand(key) + key, u2 = lfsr113_rand(key) - # Normalize the uniform samples - mag = ti.cast(ti.sqrt(-2.0 * ti.log(u1)), defaults.ti_float) + # Ensure state1 is not zero to avoid log(0) + u1 = ti.cast(ti.max(u1, epsilon), defaults.ti_float) - # Box-Muller transform - # z1 = mag * ti.cos(2 * ti.math.pi * u2) - z2 = ti.cast(mag * ti.sin(2 * ti.math.pi * u2), defaults.ti_float) + # Normalize the uniform samples + mag = ti.cast(ti.sqrt(-2.0 * ti.log(u1)), defaults.ti_float) - return key, z2 + # Box-Muller transform + # z1 = mag * ti.cos(2 * ti.math.pi * u2) + z2 = ti.cast(mag * ti.sin(2 * ti.math.pi * u2), defaults.ti_float) + return key, z2 -@ti.func -def lfsr113_random_integers(key: ti.types.vector(4, ti.u32), low, high): - """ - Generates a uniformly distributed random integer between `low` and `high` (inclusive) using the LFSR113 algorithm. - Parameters: - key: The state value used for random number generation. - low: The lower bound of the range. - high: The upper bound of the range. - """ - key = lfsr113_next_key(key) - return key, ti.cast((key[0] ^ key[1] ^ key[2] ^ key[3]) % (high + 1 - low) + low, defaults.ti_int) + @ti.func + def lfsr113_random_integers(key: ti.types.vector(4, ti.u32), low, high): + """ + Generates a uniformly distributed random integer between `low` and `high` (inclusive) using the LFSR113 algorithm. + Parameters: + key: The state value used for random number generation. + low: The lower bound of the range. + high: The upper bound of the range. + """ + key = lfsr113_next_key(key) + return key, ti.cast((key[0] ^ key[1] ^ key[2] ^ key[3]) % (high + 1 - low) + low, defaults.ti_int) -@ti.func -def lfsr113_randint(key: ti.types.vector(4, ti.u32)): - key = lfsr113_next_key(key) - return key, ti.cast(key[0] ^ key[1] ^ key[2] ^ key[3], defaults.ti_int) + @ti.func + def lfsr113_randint(key: ti.types.vector(4, ti.u32)): + key = lfsr113_next_key(key) + return key, ti.cast(key[0] ^ key[1] ^ key[2] ^ key[3], defaults.ti_int) -@ti.func -def lfsr113_uniform(key: ti.types.vector(4, ti.u32), low, high): - """ - Generates a uniformly distributed random float between `low` and `high` (inclusive) using the LFSR113 algorithm. - Args: - key: The state value used for random number generation. - low: The lower bound of the range. - high: The upper bound of the range. - """ - key = lfsr88_next_key(key) - r = (key[0] ^ key[1] ^ key[2] ^ key[3]) * ti.cast(2.3283064365386963e-10, defaults.ti_float) - return key, ti.cast(r * (high - low) + low, defaults.ti_float) - - -@ti.func -def lfsr113_rand(key: ti.types.vector(4, ti.u32)): - """ - Generates a uniformly distributed random float between 0 and 1 using the LFSR113 algorithm. + @ti.func + def lfsr113_uniform(key: ti.types.vector(4, ti.u32), low, high): + """ + Generates a uniformly distributed random float between `low` and `high` (inclusive) using the LFSR113 algorithm. - Args: - key: The state value used for random number generation. - """ - key = lfsr113_next_key(key) - return key, (key[0] ^ key[1] ^ key[2] ^ key[3]) * ti.cast(2.3283064365386963e-10, defaults.ti_float) + Args: + key: The state value used for random number generation. + low: The lower bound of the range. + high: The upper bound of the range. + """ + key = lfsr88_next_key(key) + r = (key[0] ^ key[1] ^ key[2] ^ key[3]) * ti.cast(2.3283064365386963e-10, defaults.ti_float) + return key, ti.cast(r * (high - low) + low, defaults.ti_float) -########################### -# Reductions: warp reduce # -########################### + @ti.func + def lfsr113_rand(key: ti.types.vector(4, ti.u32)): + """ + Generates a uniformly distributed random float between 0 and 1 using the LFSR113 algorithm. + Args: + key: The state value used for random number generation. + """ + key = lfsr113_next_key(key) + return key, (key[0] ^ key[1] ^ key[2] ^ key[3]) * ti.cast(2.3283064365386963e-10, defaults.ti_float) -@ti.func -def warp_reduce_sum_all(val): - """ - Warp reduce sum. - Args: - val (float): The value to be reduced. + ########################### + # Reductions: warp reduce # + ########################### - Returns: - float: The reduced value. - """ - for i in ti.static(range(1, 32)): - val += ti.static(ti.simt.warp.shfl_xor(val, i)) - return val + @ti.func + def warp_reduce_sum_all(val): + """ + Warp reduce sum. -@ti.func -def warp_reduce_sum(val): - """ - Warp reduce sum. + Args: + val (float): The value to be reduced. - Args: - val (float): The value to be reduced. + Returns: + float: The reduced value. + """ + for i in ti.static(range(1, 32)): + val += ti.static(ti.simt.warp.shfl_xor(val, i)) + return val - Returns: - float: The reduced value. - """ - for offset in ti.static((16, 8, 4, 2, 1)): - val += ti.simt.warp.shfl_down_f32(ti.u32(0xFFFFFFFF), val, offset) - return val + + @ti.func + def warp_reduce_sum(val): + """ + Warp reduce sum. + + Args: + val (float): The value to be reduced. + + Returns: + float: The reduced value. + """ + for offset in ti.static((16, 8, 4, 2, 1)): + val += ti.simt.warp.shfl_down_f32(ti.u32(0xFFFFFFFF), val, offset) + return val + + +else: + __all__ = [] diff --git a/brainpy/errors.py b/brainpy/errors.py index e59bb326c..37d4b9488 100644 --- a/brainpy/errors.py +++ b/brainpy/errors.py @@ -38,7 +38,16 @@ class AnalyzerError(BrainPyError): class PackageMissingError(BrainPyError): """The package missing error. """ - pass + + def __init__(self, name: str = None, purpose: str = None): + + if name is None: + super().__init__() + else: + assert purpose, '"purpose" cannot be None when "name" is provided.' + msg = (f'"{name}" must be installed when the user wants to use {purpose}. \n' + f'Please install through "pip install {name}".') + super().__init__(msg) class BackendNotInstalled(BrainPyError): @@ -236,9 +245,5 @@ def __init__(self, name): ''') - - class SharedArgError(BrainPyError): pass - - diff --git a/brainpy/math/event.py b/brainpy/math/event.py index 0a17cae7c..43d89c1b2 100644 --- a/brainpy/math/event.py +++ b/brainpy/math/event.py @@ -1,5 +1,4 @@ from brainpy._src.math.event import ( csrmv as csrmv, - info as info, ) diff --git a/brainpy/math/sparse.py b/brainpy/math/sparse.py index 1380a9e9c..fbe0acbf2 100644 --- a/brainpy/math/sparse.py +++ b/brainpy/math/sparse.py @@ -1,6 +1,5 @@ from brainpy._src.math.sparse import ( csrmv, - coomv, seg_matmul, diff --git a/requirements.txt b/requirements.txt index 02fdebe83..ab5665e73 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,3 @@ numpy jax tqdm -numba -taichi==1.7.0 diff --git a/setup.py b/setup.py index d7fd45e38..766cd8c75 100644 --- a/setup.py +++ b/setup.py @@ -57,7 +57,7 @@ author_email='chao.brain@qq.com', packages=packages, python_requires='>=3.8', - install_requires=['numpy>=1.15', 'jax>=0.4.13', 'tqdm', 'numba', 'taichi==1.7.0'], + install_requires=['numpy>=1.15', 'jax>=0.4.13', 'tqdm'], url='https://github.com/brainpy/BrainPy', project_urls={ "Bug Tracker": "https://github.com/brainpy/BrainPy/issues", @@ -68,11 +68,10 @@ 'https://storage.googleapis.com/jax-releases/jax_cuda_releases.html', ], extras_require={ - 'cpu': ['jaxlib>=0.4.13', 'brainpylib'], - 'cuda': ['jax[cuda]', 'brainpylib-cu12x'], - 'cuda11': ['jax[cuda11_local]', 'brainpylib-cu11x'], - 'cuda12': ['jax[cuda12_local]', 'brainpylib-cu12x'], - 'tpu': ['jax[tpu]'], + 'cpu': ['jaxlib>=0.4.13', 'brainpylib', 'numba', 'taichi==1.7.0'], + 'cuda11': ['jax[cuda11_pip]', 'brainpylib-cu11x', 'numba', 'taichi==1.7.0'], + 'cuda12': ['jax[cuda12_pip]', 'brainpylib-cu12x', 'numba', 'taichi==1.7.0'], + 'tpu': ['jax[tpu]', 'numba', 'taichi==1.7.0'], }, keywords=('computational neuroscience, ' 'brain-inspired computation, ' From 947a380575839ed2f1a2ae847040505e2e59607a Mon Sep 17 00:00:00 2001 From: He Sichao <1310722434@qq.com> Date: Sat, 17 Feb 2024 22:49:53 +0800 Subject: [PATCH 02/16] [math] Update operator selection strategy for csr matvec --- brainpy/_src/dnn/linear.py | 6 +- .../_src/dyn/projections/tests/test_STDP.py | 240 ++--- brainpy/_src/math/sparse/_csr_mv.py | 843 +++++------------- brainpy/math/__init__.py | 200 ++--- 4 files changed, 421 insertions(+), 868 deletions(-) diff --git a/brainpy/_src/dnn/linear.py b/brainpy/_src/dnn/linear.py index 539214d3b..6a37bdcba 100644 --- a/brainpy/_src/dnn/linear.py +++ b/brainpy/_src/dnn/linear.py @@ -682,8 +682,7 @@ def __init__( def update(self, x): if x.ndim == 1: return bm.sparse.csrmv(self.weight, self.indices, self.indptr, x, - shape=(self.conn.pre_num, self.conn.post_num), - method=self.method, transpose=self.transpose) + shape=(self.conn.pre_num, self.conn.post_num), transpose=self.transpose) elif x.ndim > 1: shapes = x.shape[:-1] x = bm.flatten(x, end_dim=-2) @@ -694,8 +693,7 @@ def update(self, x): def _batch_csrmv(self, x): return bm.sparse.csrmv(self.weight, self.indices, self.indptr, x, - shape=(self.conn.pre_num, self.conn.post_num), - method=self.method, transpose=self.transpose) + shape=(self.conn.pre_num, self.conn.post_num), transpose=self.transpose) class EventCSRLinear(_CSRLayer): r"""Synaptic matrix multiplication with event CSR sparse computation. diff --git a/brainpy/_src/dyn/projections/tests/test_STDP.py b/brainpy/_src/dyn/projections/tests/test_STDP.py index b8884f327..e78ae5048 100644 --- a/brainpy/_src/dyn/projections/tests/test_STDP.py +++ b/brainpy/_src/dyn/projections/tests/test_STDP.py @@ -1,120 +1,120 @@ -# -*- coding: utf-8 -*- - - -import numpy as np -from absl.testing import parameterized - -import brainpy as bp -import brainpy.math as bm - - -class Test_STDP(parameterized.TestCase): - - @parameterized.product( - comm_method=['dense', 'csr', 'masked_linear', 'all2all', 'one2one'], - delay=[None, 0., 2.], - syn_model=['exp', 'dual_exp', 'ampa'], - out_model=['cuba', 'coba', 'mg'] - ) - def test_STDP(self, comm_method, delay, syn_model, out_model): - bm.random.seed() - - class STDPNet(bp.DynamicalSystem): - def __init__(self, num_pre, num_post): - super().__init__() - self.pre = bp.dyn.LifRef(num_pre) - self.post = bp.dyn.LifRef(num_post) - - if comm_method == 'all2all': - comm = bp.dnn.AllToAll(self.pre.num, self.post.num, weight=bp.init.Uniform(.1, 0.1)) - elif comm_method == 'csr': - if syn_model == 'exp': - comm = bp.dnn.EventCSRLinear(bp.conn.FixedProb(1, pre=self.pre.num, post=self.post.num), - weight=bp.init.Uniform(0., 0.1)) - else: - comm = bp.dnn.CSRLinear(bp.conn.FixedProb(1, pre=self.pre.num, post=self.post.num), - weight=bp.init.Uniform(0., 0.1)) - elif comm_method == 'masked_linear': - comm = bp.dnn.MaskedLinear(bp.conn.FixedProb(1, pre=self.pre.num, post=self.post.num), - weight=bp.init.Uniform(0., 0.1)) - elif comm_method == 'dense': - comm = bp.dnn.Dense(self.pre.num, self.post.num, W_initializer=bp.init.Uniform(.1, 0.1)) - elif comm_method == 'one2one': - comm = bp.dnn.OneToOne(self.pre.num, weight=bp.init.Uniform(.1, 0.1)) - else: - raise ValueError - - if syn_model == 'exp': - syn = bp.dyn.Expon.desc(self.post.varshape, tau=5.) - elif syn_model == 'dual_exp': - syn = bp.dyn.DualExpon.desc(self.post.varshape) - elif syn_model == 'dual_exp_v2': - syn = bp.dyn.DualExponV2.desc(self.post.varshape) - elif syn_model == 'ampa': - syn = bp.dyn.AMPA.desc(self.post.varshape) - else: - raise ValueError - - if out_model == 'cuba': - out = bp.dyn.CUBA.desc() - elif out_model == 'coba': - out = bp.dyn.COBA.desc(E=0.) - elif out_model == 'mg': - out = bp.dyn.MgBlock.desc(E=0.) - else: - raise ValueError - - self.syn = bp.dyn.STDP_Song2000( - pre=self.pre, - delay=delay, - comm=comm, - syn=syn, - out=out, - post=self.post, - tau_s=16.8, - tau_t=33.7, - A1=0.96, - A2=0.53, - W_min=0., - W_max=1. - ) - - def update(self, I_pre, I_post): - self.syn() - self.pre(I_pre) - self.post(I_post) - conductance = self.syn.refs['syn'].g - Apre = self.syn.refs['pre_trace'].g - Apost = self.syn.refs['post_trace'].g - current = self.post.sum_current_inputs(self.post.V) - if comm_method == 'dense': - w = self.syn.comm.W.flatten() - else: - w = self.syn.comm.weight.flatten() - return self.pre.spike, self.post.spike, conductance, Apre, Apost, current, w - - duration = 300. - I_pre = bp.inputs.section_input([0, 30, 0, 30, 0, 30, 0, 30, 0, 30, 0, 30, 0], - [5, 15, 15, 15, 15, 15, 100, 15, 15, 15, 15, 15, duration - 255]) - I_post = bp.inputs.section_input([0, 30, 0, 30, 0, 30, 0, 30, 0, 30, 0, 30, 0], - [10, 15, 15, 15, 15, 15, 90, 15, 15, 15, 15, 15, duration - 250]) - - net = STDPNet(1, 1) - - def run(i, I_pre, I_post): - pre_spike, post_spike, g, Apre, Apost, current, W = net.step_run(i, I_pre, I_post) - return pre_spike, post_spike, g, Apre, Apost, current, W - - indices = np.arange(int(duration / bm.dt)) - pre_spike, post_spike, g, Apre, Apost, current, W = bm.for_loop(run, [indices, I_pre, I_post]) - - # import matplotlib.pyplot as plt - # fig, gs = bp.visualize.get_figure(4, 1, 3, 10) - # bp.visualize.line_plot(indices, g, ax=fig.add_subplot(gs[0, 0])) - # bp.visualize.line_plot(indices, Apre, ax=fig.add_subplot(gs[1, 0])) - # bp.visualize.line_plot(indices, Apost, ax=fig.add_subplot(gs[2, 0])) - # bp.visualize.line_plot(indices, W, ax=fig.add_subplot(gs[3, 0])) - # plt.show() - - bm.clear_buffer_memory() - +# -*- coding: utf-8 -*- + + +import numpy as np +from absl.testing import parameterized + +import brainpy as bp +import brainpy.math as bm + +bm.set_platform('cpu') +class Test_STDP(parameterized.TestCase): + + @parameterized.product( + comm_method=['dense', 'csr', 'masked_linear', 'all2all', 'one2one'], + delay=[None, 0., 2.], + syn_model=['exp', 'dual_exp', 'ampa'], + out_model=['cuba', 'coba', 'mg'] + ) + def test_STDP(self, comm_method, delay, syn_model, out_model): + bm.random.seed() + + class STDPNet(bp.DynamicalSystem): + def __init__(self, num_pre, num_post): + super().__init__() + self.pre = bp.dyn.LifRef(num_pre) + self.post = bp.dyn.LifRef(num_post) + + if comm_method == 'all2all': + comm = bp.dnn.AllToAll(self.pre.num, self.post.num, weight=bp.init.Uniform(.1, 0.1)) + elif comm_method == 'csr': + if syn_model == 'exp': + comm = bp.dnn.EventCSRLinear(bp.conn.FixedProb(1, pre=self.pre.num, post=self.post.num), + weight=bp.init.Uniform(0., 0.1)) + else: + comm = bp.dnn.CSRLinear(bp.conn.FixedProb(1, pre=self.pre.num, post=self.post.num), + weight=bp.init.Uniform(0., 0.1)) + elif comm_method == 'masked_linear': + comm = bp.dnn.MaskedLinear(bp.conn.FixedProb(1, pre=self.pre.num, post=self.post.num), + weight=bp.init.Uniform(0., 0.1)) + elif comm_method == 'dense': + comm = bp.dnn.Dense(self.pre.num, self.post.num, W_initializer=bp.init.Uniform(.1, 0.1)) + elif comm_method == 'one2one': + comm = bp.dnn.OneToOne(self.pre.num, weight=bp.init.Uniform(.1, 0.1)) + else: + raise ValueError + + if syn_model == 'exp': + syn = bp.dyn.Expon.desc(self.post.varshape, tau=5.) + elif syn_model == 'dual_exp': + syn = bp.dyn.DualExpon.desc(self.post.varshape) + elif syn_model == 'dual_exp_v2': + syn = bp.dyn.DualExponV2.desc(self.post.varshape) + elif syn_model == 'ampa': + syn = bp.dyn.AMPA.desc(self.post.varshape) + else: + raise ValueError + + if out_model == 'cuba': + out = bp.dyn.CUBA.desc() + elif out_model == 'coba': + out = bp.dyn.COBA.desc(E=0.) + elif out_model == 'mg': + out = bp.dyn.MgBlock.desc(E=0.) + else: + raise ValueError + + self.syn = bp.dyn.STDP_Song2000( + pre=self.pre, + delay=delay, + comm=comm, + syn=syn, + out=out, + post=self.post, + tau_s=16.8, + tau_t=33.7, + A1=0.96, + A2=0.53, + W_min=0., + W_max=1. + ) + + def update(self, I_pre, I_post): + self.syn() + self.pre(I_pre) + self.post(I_post) + conductance = self.syn.refs['syn'].g + Apre = self.syn.refs['pre_trace'].g + Apost = self.syn.refs['post_trace'].g + current = self.post.sum_current_inputs(self.post.V) + if comm_method == 'dense': + w = self.syn.comm.W.flatten() + else: + w = self.syn.comm.weight.flatten() + return self.pre.spike, self.post.spike, conductance, Apre, Apost, current, w + + duration = 300. + I_pre = bp.inputs.section_input([0, 30, 0, 30, 0, 30, 0, 30, 0, 30, 0, 30, 0], + [5, 15, 15, 15, 15, 15, 100, 15, 15, 15, 15, 15, duration - 255]) + I_post = bp.inputs.section_input([0, 30, 0, 30, 0, 30, 0, 30, 0, 30, 0, 30, 0], + [10, 15, 15, 15, 15, 15, 90, 15, 15, 15, 15, 15, duration - 250]) + + net = STDPNet(1, 1) + + def run(i, I_pre, I_post): + pre_spike, post_spike, g, Apre, Apost, current, W = net.step_run(i, I_pre, I_post) + return pre_spike, post_spike, g, Apre, Apost, current, W + + indices = np.arange(int(duration / bm.dt)) + pre_spike, post_spike, g, Apre, Apost, current, W = bm.for_loop(run, [indices, I_pre, I_post]) + + # import matplotlib.pyplot as plt + # fig, gs = bp.visualize.get_figure(4, 1, 3, 10) + # bp.visualize.line_plot(indices, g, ax=fig.add_subplot(gs[0, 0])) + # bp.visualize.line_plot(indices, Apre, ax=fig.add_subplot(gs[1, 0])) + # bp.visualize.line_plot(indices, Apost, ax=fig.add_subplot(gs[2, 0])) + # bp.visualize.line_plot(indices, W, ax=fig.add_subplot(gs[3, 0])) + # plt.show() + + bm.clear_buffer_memory() + diff --git a/brainpy/_src/math/sparse/_csr_mv.py b/brainpy/_src/math/sparse/_csr_mv.py index 377597579..27f10f4b9 100644 --- a/brainpy/_src/math/sparse/_csr_mv.py +++ b/brainpy/_src/math/sparse/_csr_mv.py @@ -7,10 +7,12 @@ import jax import numba import numpy as np +import brainpy.math as bm from jax import core, dtypes from jax import numpy as jnp from jax.interpreters import ad, mlir, xla from jax.lib import xla_client +from jax.experimental.sparse import csr from jaxlib import gpu_sparse from brainpy._src.dependency_check import import_brainpylib_gpu_ops, import_taichi @@ -20,7 +22,7 @@ register_general_batching, XLACustomOp) from brainpy._src.math.sparse._utils import csr_to_coo -from brainpy.errors import GPUOperatorNotFound +from brainpy.errors import PackageMissingError ti = import_taichi() @@ -37,7 +39,6 @@ def csrmv( *, shape: Tuple[int, int], transpose: bool = False, - method: str = None, ): """Product of CSR sparse matrix and a dense vector using cuSPARSE algorithm. @@ -76,455 +77,7 @@ def csrmv( The array of shape ``(shape[1] if transpose else shape[0],)`` representing the matrix vector product. """ - if method is None: - return csrmv_taichi(data, indices, indptr, vector, shape=shape, transpose=transpose) - else: - return csrmv_brainpylib(data, indices, indptr, vector, shape=shape, transpose=transpose, method=method) - - -### BRAINPYLIB ### - -def csrmv_brainpylib( - data: Union[float, jnp.ndarray, Array], - indices: Union[jnp.ndarray, Array], - indptr: Union[jnp.ndarray, Array], - vector: Union[jnp.ndarray, Array], - *, - shape: Tuple[int, int], - transpose: bool = False, - method: str = 'cusparse', -): - """Product of CSR sparse matrix and a dense vector using cuSPARSE algorithm. - - This function supports JAX transformations, including `jit()`, `grad()`, - `vmap()` and `pmap()`. - - Parameters - ---------- - data: ndarray, float - An array of shape ``(nse,)``. - indices: ndarray - An array of shape ``(nse,)``. - indptr: ndarray - An array of shape ``(shape[0] + 1,)`` and dtype ``indices.dtype``. - vector: ndarray - An array of shape ``(shape[0] if transpose else shape[1],)`` - and dtype ``data.dtype``. - shape: tuple of int - A length-2 tuple representing the matrix shape. - transpose: bool - A boolean specifying whether to transpose the sparse matrix - before computing. - method: str - The method used to compute Matrix-Vector Multiplication. The candidate methods are: - - - ``cusparse``: using cuSPARSE library. - - ``scalar``: - - ``vector``: - - ``adaptive``: - - Returns - ------- - y : ndarry - The array of shape ``(shape[1] if transpose else shape[0],)`` representing - the matrix vector product. - """ - - data = jnp.atleast_1d(as_jax(data)) - indices = as_jax(indices) - indptr = as_jax(indptr) - vector = as_jax(vector) - - if vector.dtype == jnp.bool_: - vector = as_jax(vector, dtype=data.dtype) - - if method == 'cusparse': - if jax.default_backend() == 'gpu': - if data.shape[0] == 1: - data = jnp.ones(indices.shape, dtype=data.dtype) * data - if indices.dtype in [jnp.uint32, jnp.uint64]: - indices = jnp.asarray(indices, dtype=dtypes.canonicalize_dtype(jnp.int64)) - if indptr.dtype in [jnp.uint32, jnp.uint64]: - indptr = jnp.asarray(indptr, dtype=dtypes.canonicalize_dtype(jnp.int64)) - return _csrmv_cusparse_p.bind(data, - indices, - indptr, - vector, - shape=shape, - transpose=transpose) - - elif method == 'adaptive': - return _csrmv_adaptive_p.bind(data, indices, indptr, vector, shape=shape, transpose=transpose) - - elif method == 'scalar': - return _csrmv_scalar_p.bind(data, indices, indptr, vector, shape=shape, transpose=transpose) - - elif method == 'vector': - return _csrmv_vector_p.bind(data, indices, indptr, vector, shape=shape, transpose=transpose) - - else: - raise ValueError(f'Only support methods: cusparse, scalar, vector, and adaptive. But we got {method}.') - - -def _csrmv_abstract(data, indices, indptr, vector, *, shape, transpose): - if data.dtype not in [jnp.float32, jnp.float64]: - raise TypeError(f'Only support float32 and float64. But we got {data.dtype}.') - if data.dtype != vector.dtype: - raise TypeError('The types of data and vector should be the same. ' - f'But we got {data.dtype} != {vector.dtype}.') - assert data.ndim == indices.ndim == indptr.ndim == vector.ndim == 1 - if not jnp.issubdtype(indices.dtype, jnp.integer): - raise ValueError('indices should be a 1D vector with integer type.') - if not jnp.issubdtype(indptr.dtype, jnp.integer): - raise ValueError('indptr should be a 1D vector with integer type.') - out_shape = shape[1] if transpose else shape[0] - return core.ShapedArray((out_shape,), data.dtype) - - -@numba.njit(fastmath=True) -def _csr_matvec_transpose_numba_imp(outs, ins): - res_val = outs - res_val.fill(0) - values, col_indices, row_ptr, vector, shape, _ = ins - # (csr mat).T @ vec - - if values.shape[0] == 1: - values = values[0] - for row_i in range(shape[0]): - v = vector[row_i] - for j in range(row_ptr[row_i], row_ptr[row_i + 1]): - res_val[col_indices[j]] += values * v - else: - for row_i in range(shape[0]): - v = vector[row_i] - for j in range(row_ptr[row_i], row_ptr[row_i + 1]): - res_val[col_indices[j]] += v * values[j] - - -@numba.njit(fastmath=True, parallel=True, nogil=True) -def _csr_matvec_numba_imp(outs, ins): - res_val = outs - res_val.fill(0) - values, col_indices, row_ptr, vector, shape, _ = ins - # csr mat @ vec - if values.shape[0] == 1: - values = values[0] - for row_i in numba.prange(shape[0]): - r = 0. - for j in range(row_ptr[row_i], row_ptr[row_i + 1]): - r += values * vector[col_indices[j]] - res_val[row_i] = r - else: - for row_i in numba.prange(shape[0]): - r = 0. - for j in range(row_ptr[row_i], row_ptr[row_i + 1]): - r += values[j] * vector[col_indices[j]] - res_val[row_i] = r - - -def _csrmv_cpu_translation(c, data, indices, indptr, vector, *, shape, transpose): - inputs = (data, indices, indptr, vector) - description = dict(shape=shape, transpose=transpose) - if transpose: - target_name, inputs, input_layouts, output_layouts = compile_cpu_signature_with_numba( - c, - _csr_matvec_transpose_numba_imp, - _csrmv_abstract, - multiple_results=False, - inputs=inputs, - description=description - ) - else: - target_name, inputs, input_layouts, output_layouts = compile_cpu_signature_with_numba( - c, - _csr_matvec_numba_imp, - _csrmv_abstract, - multiple_results=False, - inputs=inputs, - description=description - ) - return xla_client.ops.CustomCallWithLayout( - c, - target_name, - operands=inputs, - operand_shapes_with_layout=input_layouts, - shape_with_layout=output_layouts, - ) - - -def _csrmv_cusparse_gpu_lowering(ctx, data, indices, indptr, vector, *, shape, transpose): - data_aval, indices_aval, _, v_aval = ctx.avals_in - dtype = data_aval.dtype - if dtype not in [np.float32, np.float64, np.complex64, np.complex128]: - raise TypeError(f"cusparse_csr_matvec cusparse/hipsparse lowering not available for dtype={dtype}. " - "Falling back to default implementation.") - return [gpu_sparse.cuda_csr_matvec(data, indices, indptr, vector, - shape=shape, - transpose=transpose, - data_dtype=dtype, - x_dtype=v_aval.dtype, - index_dtype=indices_aval.dtype)] - - -def _csrmv_jvp_mat(csr_prim, data_dot, data, indices, indptr, v, *, shape, transpose): - return csr_prim.bind(data_dot, indices, indptr, v, shape=shape, transpose=transpose) - - -def _csrmv_jvp_vec(prim, v_dot, data, indices, indptr, v, *, shape, transpose): - return prim.bind(data, indices, indptr, v_dot, shape=shape, transpose=transpose) - - -def _csrmv_cusparse_transpose(ct, data, indices, indptr, vector, *, shape, transpose): - if ad.is_undefined_primal(indices) or ad.is_undefined_primal(indptr): - raise ValueError("Cannot transpose with respect to sparse indices.") - - if ad.is_undefined_primal(vector): - if type(ct) is ad.Zero: - return data, indices, indptr, ad.Zero(vector) - else: - ct_vector = _csrmv_cusparse_p.bind(data, indices, indptr, ct, shape=shape, transpose=not transpose) - return data, indices, indptr, ct_vector - - else: - if type(ct) is ad.Zero: - ct_data = ad.Zero(data) - else: - if data.aval.shape[0] == 1: # scalar - ct_data = _csrmv_cusparse_p.bind(jnp.ones(1), indices, indptr, vector, shape=shape, transpose=transpose) - ct_data = jnp.inner(ct, ct_data) - else: # heterogeneous values - row, col = csr_to_coo(indices, indptr) - ct_data = vector[row] * ct[col] if transpose else vector[col] * ct[row] - return ct_data, indices, indptr, vector - - -_csrmv_cusparse_p = core.Primitive('cusparse_csr_matvec') -_csrmv_cusparse_p.def_abstract_eval(_csrmv_abstract) -_csrmv_cusparse_p.def_impl(partial(xla.apply_primitive, _csrmv_cusparse_p)) -# xla.backend_specific_translations['cpu'][_csrmv_cusparse_p] = _csrmv_cpu_translation -ad.defjvp(_csrmv_cusparse_p, - partial(_csrmv_jvp_mat, _csrmv_cusparse_p), - None, - None, - partial(_csrmv_jvp_vec, _csrmv_cusparse_p)) -ad.primitive_transposes[_csrmv_cusparse_p] = _csrmv_cusparse_transpose -register_general_batching(_csrmv_cusparse_p) -mlir.register_lowering(_csrmv_cusparse_p, _csrmv_cusparse_gpu_lowering, platform='cuda') - - -def _csr_matvec_scalar_gpu_translation(c, data, indices, indptr, vector, *, shape, transpose): - gpu_ops = import_brainpylib_gpu_ops() - if gpu_ops is None: - raise GPUOperatorNotFound(_csrmv_scalar_p.name) - if transpose: - raise NotImplementedError - - data_shape = c.get_shape(data) - if data_shape.element_type() == np.float32: - ftype = b'_float' - elif data_shape.element_type() == np.float64: - ftype = b'_double' - else: - raise ValueError - indices_shape = c.get_shape(indices) - if indices_shape.element_type() == np.int32: - itype = b'_int' - elif indices_shape.element_type() == np.int64: - itype = b'_long' - else: - raise ValueError - data_name = b'homo' if data_shape.dimensions() == (1,) else b'heter' - opaque = gpu_ops.build_double_size_descriptor(shape[0], shape[1]) - return xla_client.ops.CustomCallWithLayout( - c, - b'csrmv_' + data_name + b'_scalar' + ftype + itype, - operands=(data, indices, indptr, vector), - operand_shapes_with_layout=(c.get_shape(data), - c.get_shape(indices), - c.get_shape(indptr), - c.get_shape(vector)), - shape_with_layout=xla_client.Shape.array_shape(data_shape.element_type(), (shape[0],), (0,)), - opaque=opaque, - ) - - -def _csrmv_scalar_transpose(ct, data, indices, indptr, vector, *, shape, transpose): - if ad.is_undefined_primal(indices) or ad.is_undefined_primal(indptr): - raise ValueError("Cannot transpose with respect to sparse indices.") - - if ad.is_undefined_primal(vector): - ct_vector = _csrmv_scalar_p.bind(data, indices, indptr, ct, shape=shape, transpose=not transpose) - return data, indices, indptr, (ad.Zero(vector) if type(ct) is ad.Zero else ct_vector) - - else: - if type(ct) is ad.Zero: - ct_data = ad.Zero(data) - else: - if data.aval.shape[0] == 1: # scalar - ct_data = _csrmv_scalar_p.bind(jnp.ones(1), indices, indptr, vector, shape=shape, transpose=transpose) - ct_data = jnp.inner(ct, ct_data) - else: # heterogeneous values - row, col = csr_to_coo(indices, indptr) - ct_data = vector[row] * ct[col] if transpose else vector[col] * ct[row] - return ct_data, indices, indptr, vector - - -_csrmv_scalar_p = core.Primitive('csr_matvec_scalar') -_csrmv_scalar_p.def_abstract_eval(_csrmv_abstract) -_csrmv_scalar_p.def_impl(partial(xla.apply_primitive, _csrmv_scalar_p)) -# xla.backend_specific_translations['cpu'][_csrmv_scalar_p] = _csrmv_cpu_translation -# xla.backend_specific_translations['gpu'][_csrmv_scalar_p] = _csr_matvec_scalar_gpu_translation -ad.defjvp(_csrmv_scalar_p, - partial(_csrmv_jvp_mat, _csrmv_scalar_p), - None, - None, - partial(_csrmv_jvp_vec, _csrmv_scalar_p), ) -ad.primitive_transposes[_csrmv_scalar_p] = _csrmv_scalar_transpose -register_general_batching(_csrmv_scalar_p) - - -def _csr_matvec_vector_gpu_translation(c, data, indices, indptr, vector, *, shape, transpose): - gpu_ops = import_brainpylib_gpu_ops() - if gpu_ops is None: - raise GPUOperatorNotFound(_csrmv_vector_p.name) - if transpose: - raise NotImplementedError - - data_shape = c.get_shape(data) - if data_shape.element_type() == np.float32: - ftype = b'_float' - elif data_shape.element_type() == np.float64: - ftype = b'_double' - else: - raise ValueError - indices_shape = c.get_shape(indices) - if indices_shape.element_type() == np.int32: - itype = b'_int' - elif indices_shape.element_type() == np.int64: - itype = b'_long' - else: - raise ValueError - data_name = b'homo' if data_shape.dimensions() == (1,) else b'heter' - opaque = gpu_ops.build_double_size_descriptor(shape[0], shape[1]) - return xla_client.ops.CustomCallWithLayout( - c, - b'csrmv_' + data_name + b'_vector' + ftype + itype, - operands=(data, indices, indptr, vector), - operand_shapes_with_layout=(c.get_shape(data), - c.get_shape(indices), - c.get_shape(indptr), - c.get_shape(vector)), - shape_with_layout=xla_client.Shape.array_shape(data_shape.element_type(), (shape[0],), (0,)), - opaque=opaque, - ) - - -def _csrmv_vector_transpose(ct, data, indices, indptr, vector, *, shape, transpose): - if ad.is_undefined_primal(indices) or ad.is_undefined_primal(indptr): - raise ValueError("Cannot transpose with respect to sparse indices.") - - if ad.is_undefined_primal(vector): - ct_vector = _csrmv_vector_p.bind(data, indices, indptr, ct, shape=shape, transpose=not transpose) - return data, indices, indptr, (ad.Zero(vector) if type(ct) is ad.Zero else ct_vector) - - else: - if type(ct) is ad.Zero: - ct_data = ad.Zero(data) - else: - if data.aval.shape[0] == 1: # scalar - ct_data = _csrmv_vector_p.bind(jnp.ones(1), indices, indptr, vector, shape=shape, transpose=transpose) - ct_data = jnp.inner(ct, ct_data) - else: # heterogeneous values - row, col = csr_to_coo(indices, indptr) - ct_data = vector[row] * ct[col] if transpose else vector[col] * ct[row] - return ct_data, indices, indptr, vector - - -_csrmv_vector_p = core.Primitive('csr_matvec_vector') -_csrmv_vector_p.def_abstract_eval(_csrmv_abstract) -_csrmv_vector_p.def_impl(partial(xla.apply_primitive, _csrmv_vector_p)) -# xla.backend_specific_translations['cpu'][_csrmv_vector_p] = _csrmv_cpu_translation -# xla.backend_specific_translations['gpu'][_csrmv_vector_p] = _csr_matvec_vector_gpu_translation -ad.defjvp(_csrmv_vector_p, - partial(_csrmv_jvp_mat, _csrmv_vector_p), - None, - None, - partial(_csrmv_jvp_vec, _csrmv_vector_p), ) -ad.primitive_transposes[_csrmv_vector_p] = _csrmv_vector_transpose -register_general_batching(_csrmv_vector_p) - - -def _csr_matvec_adaptive_gpu_translation(c, data, indices, indptr, row_blocks, vector, *, shape, transpose): - gpu_ops = import_brainpylib_gpu_ops() - if gpu_ops is None: - raise GPUOperatorNotFound(_csrmv_adaptive_p.name) - if transpose: - raise NotImplementedError - - data_shape = c.get_shape(data) - if data_shape.element_type() == np.float32: - ftype = b'_float' - elif data_shape.element_type() == np.float64: - ftype = b'_double' - else: - raise ValueError - indices_shape = c.get_shape(indices) - if indices_shape.element_type() == np.int32: - itype = b'_int' - elif indices_shape.element_type() == np.int64: - itype = b'_long' - else: - raise ValueError - data_name = b'homo' if data_shape.dimensions() == (1,) else b'heter' - opaque = gpu_ops.build_double_size_descriptor(shape[0], shape[1]) - return xla_client.ops.CustomCallWithLayout( - c, - b'csrmv_' + data_name + b'_vector' + ftype + itype, - operands=(data, indices, indptr, row_blocks, vector), - operand_shapes_with_layout=(c.get_shape(data), - c.get_shape(indices), - c.get_shape(indptr), - c.get_shape(row_blocks), - c.get_shape(vector)), - shape_with_layout=xla_client.Shape.array_shape(data_shape.element_type(), (shape[0],), (0,)), - opaque=opaque, - ) - - -def _csrmv_adaptive_transpose(ct, data, indices, indptr, vector, *, shape, transpose): - if ad.is_undefined_primal(indices) or ad.is_undefined_primal(indptr): - raise ValueError("Cannot transpose with respect to sparse indices.") - - if ad.is_undefined_primal(vector): - ct_vector = _csrmv_adaptive_p.bind(data, indices, indptr, ct, shape=shape, transpose=not transpose) - return data, indices, indptr, (ad.Zero(vector) if type(ct) is ad.Zero else ct_vector) - - else: - if type(ct) is ad.Zero: - ct_data = ad.Zero(data) - else: - if data.aval.shape[0] == 1: # scalar - ct_data = _csrmv_adaptive_p.bind(jnp.ones(1), indices, indptr, vector, shape=shape, transpose=transpose) - ct_data = jnp.inner(ct, ct_data) - else: # heterogeneous values - row, col = csr_to_coo(indices, indptr) - ct_data = vector[row] * ct[col] if transpose else vector[col] * ct[row] - return ct_data, indices, indptr, vector - - -_csrmv_adaptive_p = core.Primitive('csr_matvec_adaptive') -_csrmv_adaptive_p.def_abstract_eval(_csrmv_abstract) -_csrmv_adaptive_p.def_impl(partial(xla.apply_primitive, _csrmv_adaptive_p)) -# xla.backend_specific_translations['cpu'][_csrmv_adaptive_p] = _csrmv_cpu_translation -# xla.backend_specific_translations['gpu'][_csrmv_adaptive_p] = _csr_matvec_adaptive_gpu_translation -ad.defjvp(_csrmv_adaptive_p, - partial(_csrmv_jvp_mat, _csrmv_adaptive_p), - None, - None, - partial(_csrmv_jvp_vec, _csrmv_adaptive_p), ) -ad.primitive_transposes[_csrmv_adaptive_p] = _csrmv_adaptive_transpose -register_general_batching(_csrmv_adaptive_p) - + return csrmv_taichi(data, indices, indptr, vector, shape=shape, transpose=transpose) ### TAICHI ### @@ -592,172 +145,6 @@ def csrmv_taichi( return raw_csrmv_taichi(data, indices, indptr, vector, shape=shape, transpose=transpose)[0] - -# ------------- -# CPU operators -# ------------- - - -@ti.kernel -def _sparse_csr_matvec_transpose_homo_cpu(values: ti.types.ndarray(ndim=1), - col_indices: ti.types.ndarray(ndim=1), - row_ptr: ti.types.ndarray(ndim=1), - vector: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - value = values[0] - ti.loop_config(serialize=True) - for row_i in range(row_ptr.shape[0] - 1): - for j in range(row_ptr[row_i], row_ptr[row_i + 1]): - out[col_indices[j]] += value * vector[row_i] - - -@ti.kernel -def _sparse_csr_matvec_transpose_heter_cpu(values: ti.types.ndarray(ndim=1), - col_indices: ti.types.ndarray(ndim=1), - row_ptr: ti.types.ndarray(ndim=1), - vector: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - ti.loop_config(serialize=True) - for row_i in range(row_ptr.shape[0] - 1): - for j in range(row_ptr[row_i], row_ptr[row_i + 1]): - out[col_indices[j]] += vector[row_i] * values[j] - - -@ti.kernel -def _sparse_csr_matvec_homo_cpu(values: ti.types.ndarray(ndim=1), - col_indices: ti.types.ndarray(ndim=1), - row_ptr: ti.types.ndarray(ndim=1), - vector: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - value = values[0] - # ti.loop_config(serialize=True) - for row_i in range(row_ptr.shape[0] - 1): - r = 0. - for j in range(row_ptr[row_i], row_ptr[row_i + 1]): - r += vector[col_indices[j]] - out[row_i] = r * value - - -@ti.kernel -def _sparse_csr_matvec_heter_cpu(values: ti.types.ndarray(ndim=1), - col_indices: ti.types.ndarray(ndim=1), - row_ptr: ti.types.ndarray(ndim=1), - vector: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - # ti.loop_config(serialize=True) - for row_i in range(row_ptr.shape[0] - 1): - r = 0. - for j in range(row_ptr[row_i], row_ptr[row_i + 1]): - r += values[j] * vector[col_indices[j]] - out[row_i] = r - - -# ------------- -# GPU operators -# ------------- - - -@ti.kernel -def _sparse_csr_matvec_transpose_homo_gpu(values: ti.types.ndarray(ndim=1), - col_indices: ti.types.ndarray(ndim=1), - row_ptr: ti.types.ndarray(ndim=1), - vector: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - value = values[0] - for i in range((row_ptr.shape[0] - 1) * 32): - row_i = i >> 5 - index = i & 31 - j = row_ptr[row_i] + index - end_index = row_ptr[row_i + 1] - while j < end_index: - out[col_indices[j]] += value * vector[row_i] - j += 32 - - -@ti.kernel -def _sparse_csr_matvec_homo_gpu(values: ti.types.ndarray(ndim=1), - col_indices: ti.types.ndarray(ndim=1), - row_ptr: ti.types.ndarray(ndim=1), - vector: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - value = values[0] - for i in range((row_ptr.shape[0] - 1) * 32): - row_i = i >> 5 - index = i & 31 - r = 0. - j = row_ptr[row_i] + index - end_index = row_ptr[row_i + 1] - while j < end_index: - r += vector[col_indices[j]] - j += 32 - out[row_i] += value * r - - -@ti.kernel -def _sparse_csr_matvec_transpose_heter_gpu(values: ti.types.ndarray(ndim=1), - col_indices: ti.types.ndarray(ndim=1), - row_ptr: ti.types.ndarray(ndim=1), - vector: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - for i in range((row_ptr.shape[0] - 1) * 32): - row_i = i >> 5 - index = i & 31 - j = row_ptr[row_i] + index - end_index = row_ptr[row_i + 1] - while j < end_index: - out[col_indices[j]] += values[j] * vector[row_i] - j += 32 - - -@ti.kernel -def _sparse_csr_matvec_heter_gpu(values: ti.types.ndarray(ndim=1), - col_indices: ti.types.ndarray(ndim=1), - row_ptr: ti.types.ndarray(ndim=1), - vector: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - for i in range((row_ptr.shape[0] - 1) * 32): - row_i = i >> 5 - index = i & 31 - r = 0. - j = row_ptr[row_i] + index - end_index = row_ptr[row_i + 1] - while j < end_index: - r += values[j] * vector[col_indices[j]] - j += 32 - out[row_i] += r # TODO: warp-level primitive - - -def _sparse_csr_matvec_jvp_values(val_dot, values, col_indices, row_ptr, vector, *, outs, transpose, shape): - return raw_csrmv_taichi(val_dot, col_indices, row_ptr, vector, shape=shape, transpose=transpose) - - -def _sparse_csr_matvec_jvp_vector(vec_dot, values, col_indices, row_ptr, vector, *, outs, transpose, shape): - return raw_csrmv_taichi(values, col_indices, row_ptr, vec_dot, shape=shape, transpose=transpose) - - -def _sparse_csr_matvec_transpose( - ct, data, indices, indptr, vector, *, outs, transpose, shape, -): - if ad.is_undefined_primal(indices) or ad.is_undefined_primal(indptr): - raise ValueError("Cannot transpose with respect to sparse indices.") - if ad.is_undefined_primal(vector): - ct_vector = raw_csrmv_taichi(data, indices, indptr, ct[0], shape=shape, transpose=not transpose)[0] - return data, indices, indptr, (ad.Zero(vector) if type(ct[0]) is ad.Zero else ct_vector) - - else: - if type(ct[0]) is ad.Zero: - ct_data = ad.Zero(data) - else: - if data.aval.shape[0] == 1: # scalar - ct_data = raw_csrmv_taichi(jnp.ones(1), indices, indptr, vector, shape=shape, transpose=transpose)[0] - ct_data = jnp.inner(ct[0], ct_data) - else: - row, col = csr_to_coo(indices, indptr) - ct_data = vector[row] * ct[0][col] if transpose else vector[col] * ct[0][row] - - return ct_data, indices, indptr, vector - - def raw_csrmv_taichi( data: Union[float, jnp.ndarray, Array], indices: Union[jnp.ndarray, Array], @@ -767,17 +154,22 @@ def raw_csrmv_taichi( shape: Tuple[int, int], transpose: bool = False, ): + if ti is None: + raise PackageMissingError(name='taichi==1.7.0', purpose='customized operators') out_shape = shape[1] if transpose else shape[0] - if transpose: - if data.shape[0] == 1: - prim = _csr_matvec_transpose_homo_p + if data.shape[0] != 1: + if bm.get_platform() == 'gpu': + return [_csr_matvec_cusparse_p.bind(data, indices, indptr, vector, shape=shape, transpose=transpose), ] else: - prim = _csr_matvec_transpose_heter_p + if transpose: + prim = _csr_matvec_transpose_heter_p + else: + prim = _csr_matvec_heter_p else: - if data.shape[0] == 1: - prim = _csr_matvec_homo_p + if transpose: + prim = _csr_matvec_transpose_homo_p else: - prim = _csr_matvec_heter_p + prim = _csr_matvec_homo_p return prim(data, indices, @@ -788,25 +180,194 @@ def raw_csrmv_taichi( shape=shape) -def _define_op(cpu_kernel, gpu_kernel): - prim = XLACustomOp(cpu_kernel=cpu_kernel, gpu_kernel=gpu_kernel) - prim.defjvp(_sparse_csr_matvec_jvp_values, None, None, _sparse_csr_matvec_jvp_vector) - prim.def_transpose_rule(_sparse_csr_matvec_transpose) - return prim +if ti is not None: + + # ------------- + # CPU operators + # ------------- + @ti.kernel + def _sparse_csr_matvec_transpose_homo_cpu(values: ti.types.ndarray(ndim=1), + col_indices: ti.types.ndarray(ndim=1), + row_ptr: ti.types.ndarray(ndim=1), + vector: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1)): + value = values[0] + ti.loop_config(serialize=True) + for row_i in range(row_ptr.shape[0] - 1): + for j in range(row_ptr[row_i], row_ptr[row_i + 1]): + out[col_indices[j]] += value * vector[row_i] + + + @ti.kernel + def _sparse_csr_matvec_transpose_heter_cpu(values: ti.types.ndarray(ndim=1), + col_indices: ti.types.ndarray(ndim=1), + row_ptr: ti.types.ndarray(ndim=1), + vector: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1)): + ti.loop_config(serialize=True) + for row_i in range(row_ptr.shape[0] - 1): + for j in range(row_ptr[row_i], row_ptr[row_i + 1]): + out[col_indices[j]] += vector[row_i] * values[j] + + + @ti.kernel + def _sparse_csr_matvec_homo_cpu(values: ti.types.ndarray(ndim=1), + col_indices: ti.types.ndarray(ndim=1), + row_ptr: ti.types.ndarray(ndim=1), + vector: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1)): + value = values[0] + # ti.loop_config(serialize=True) + for row_i in range(row_ptr.shape[0] - 1): + r = 0. + for j in range(row_ptr[row_i], row_ptr[row_i + 1]): + r += vector[col_indices[j]] + out[row_i] = r * value + + + @ti.kernel + def _sparse_csr_matvec_heter_cpu(values: ti.types.ndarray(ndim=1), + col_indices: ti.types.ndarray(ndim=1), + row_ptr: ti.types.ndarray(ndim=1), + vector: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1)): + # ti.loop_config(serialize=True) + for row_i in range(row_ptr.shape[0] - 1): + r = 0. + for j in range(row_ptr[row_i], row_ptr[row_i + 1]): + r += values[j] * vector[col_indices[j]] + out[row_i] = r + + + # ------------- + # GPU operators + # ------------- + + + @ti.kernel + def _sparse_csr_matvec_transpose_homo_gpu(values: ti.types.ndarray(ndim=1), + col_indices: ti.types.ndarray(ndim=1), + row_ptr: ti.types.ndarray(ndim=1), + vector: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1)): + value = values[0] + for i in range((row_ptr.shape[0] - 1) * 32): + row_i = i >> 5 + index = i & 31 + j = row_ptr[row_i] + index + end_index = row_ptr[row_i + 1] + while j < end_index: + out[col_indices[j]] += value * vector[row_i] + j += 32 + + + @ti.kernel + def _sparse_csr_matvec_homo_gpu(values: ti.types.ndarray(ndim=1), + col_indices: ti.types.ndarray(ndim=1), + row_ptr: ti.types.ndarray(ndim=1), + vector: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1)): + value = values[0] + for i in range((row_ptr.shape[0] - 1) * 32): + row_i = i >> 5 + index = i & 31 + r = 0. + j = row_ptr[row_i] + index + end_index = row_ptr[row_i + 1] + while j < end_index: + r += vector[col_indices[j]] + j += 32 + out[row_i] += value * r + + + @ti.kernel + def _sparse_csr_matvec_transpose_heter_gpu(values: ti.types.ndarray(ndim=1), + col_indices: ti.types.ndarray(ndim=1), + row_ptr: ti.types.ndarray(ndim=1), + vector: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1)): + for i in range((row_ptr.shape[0] - 1) * 32): + row_i = i >> 5 + index = i & 31 + j = row_ptr[row_i] + index + end_index = row_ptr[row_i + 1] + while j < end_index: + out[col_indices[j]] += values[j] * vector[row_i] + j += 32 + + + @ti.kernel + def _sparse_csr_matvec_heter_gpu(values: ti.types.ndarray(ndim=1), + col_indices: ti.types.ndarray(ndim=1), + row_ptr: ti.types.ndarray(ndim=1), + vector: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1)): + for i in range((row_ptr.shape[0] - 1) * 32): + row_i = i >> 5 + index = i & 31 + r = 0. + j = row_ptr[row_i] + index + end_index = row_ptr[row_i + 1] + while j < end_index: + r += values[j] * vector[col_indices[j]] + j += 32 + out[row_i] += r # TODO: warp-level primitive + + + def _sparse_csr_matvec_jvp_values(val_dot, values, col_indices, row_ptr, vector, *, outs, transpose, shape): + return raw_csrmv_taichi(val_dot, col_indices, row_ptr, vector, shape=shape, transpose=transpose) + + + def _sparse_csr_matvec_jvp_vector(vec_dot, values, col_indices, row_ptr, vector, *, outs, transpose, shape): + return raw_csrmv_taichi(values, col_indices, row_ptr, vec_dot, shape=shape, transpose=transpose) + + + def _sparse_csr_matvec_transpose( + ct, data, indices, indptr, vector, *, outs, transpose, shape, + ): + if ad.is_undefined_primal(indices) or ad.is_undefined_primal(indptr): + raise ValueError("Cannot transpose with respect to sparse indices.") + if ad.is_undefined_primal(vector): + ct_vector = raw_csrmv_taichi(data, indices, indptr, ct[0], shape=shape, transpose=not transpose)[0] + return data, indices, indptr, (ad.Zero(vector) if type(ct[0]) is ad.Zero else ct_vector) + + else: + if type(ct[0]) is ad.Zero: + ct_data = ad.Zero(data) + else: + if data.aval.shape[0] == 1: # scalar + ct_data = raw_csrmv_taichi(jnp.ones(1), indices, indptr, vector, shape=shape, transpose=transpose)[0] + ct_data = jnp.inner(ct[0], ct_data) + else: + row, col = csr_to_coo(indices, indptr) + ct_data = vector[row] * ct[0][col] if transpose else vector[col] * ct[0][row] + + return ct_data, indices, indptr, vector + + + def _define_op(cpu_kernel, gpu_kernel): + prim = XLACustomOp(cpu_kernel=cpu_kernel, gpu_kernel=gpu_kernel) + prim.defjvp(_sparse_csr_matvec_jvp_values, None, None, _sparse_csr_matvec_jvp_vector) + prim.def_transpose_rule(_sparse_csr_matvec_transpose) + return prim + + # transpose homo + _csr_matvec_transpose_homo_p = _define_op(cpu_kernel=_sparse_csr_matvec_transpose_homo_cpu, + gpu_kernel=_sparse_csr_matvec_transpose_homo_gpu) -# transpose homo -_csr_matvec_transpose_homo_p = _define_op(cpu_kernel=_sparse_csr_matvec_transpose_homo_cpu, - gpu_kernel=_sparse_csr_matvec_transpose_homo_gpu) + # no transpose homo + _csr_matvec_homo_p = _define_op(cpu_kernel=_sparse_csr_matvec_homo_cpu, + gpu_kernel=_sparse_csr_matvec_homo_gpu) -# no transpose homo -_csr_matvec_homo_p = _define_op(cpu_kernel=_sparse_csr_matvec_homo_cpu, - gpu_kernel=_sparse_csr_matvec_homo_gpu) + # transpose heter + _csr_matvec_transpose_heter_p = _define_op(cpu_kernel=_sparse_csr_matvec_transpose_heter_cpu, + gpu_kernel=_sparse_csr_matvec_transpose_heter_gpu) -# transpose heter -_csr_matvec_transpose_heter_p = _define_op(cpu_kernel=_sparse_csr_matvec_transpose_heter_cpu, - gpu_kernel=_sparse_csr_matvec_transpose_heter_gpu) + # no transpose heter + _csr_matvec_heter_p = _define_op(cpu_kernel=_sparse_csr_matvec_heter_cpu, + gpu_kernel=_sparse_csr_matvec_heter_gpu) -# no transpose heter -_csr_matvec_heter_p = _define_op(cpu_kernel=_sparse_csr_matvec_heter_cpu, - gpu_kernel=_sparse_csr_matvec_heter_gpu) + # heter cusparse + _csr_matvec_cusparse_p = csr.csr_matvec_p + register_general_batching(_csr_matvec_cusparse_p) diff --git a/brainpy/math/__init__.py b/brainpy/math/__init__.py index 02f671345..8bec65599 100644 --- a/brainpy/math/__init__.py +++ b/brainpy/math/__init__.py @@ -1,103 +1,97 @@ -# -*- coding: utf-8 -*- - -# data structure -from .ndarray import * -from .delayvars import * -from .interoperability import * -from .datatypes import * -from .compat_numpy import * -from .compat_tensorflow import * -from .compat_pytorch import * -from .einops import * - -# functions -from .activations import * -from . import activations - -# operators -from .pre_syn_post import * -from .op_register import * -from . import surrogate, event, sparse, jitconn - -# Variable and Objects for object-oriented JAX transformations -from .oo_transform import * - -# environment settings -from .modes import * -from .environment import * -from .scales import * -from .others import * - -# high-level numpy operations -from . import fft -from . import linalg -from . import random - -# taichi operations -from . import tifunc - -# others -from . import sharding - -import jax.numpy as jnp -from jax import config - -del jnp, config - -from brainpy._src.math.surrogate._compt import ( - spike_with_sigmoid_grad as spike_with_sigmoid_grad, - spike_with_linear_grad as spike_with_linear_grad, - spike_with_gaussian_grad as spike_with_gaussian_grad, - spike_with_mg_grad as spike_with_mg_grad, -) - -from brainpy._src.math import defaults -from brainpy._src.deprecations import deprecation_getattr -__deprecations = { - "sparse_matmul": ("brainpy.math.sparse_matmul is deprecated. Use brainpy.math.sparse.seg_matmul instead.", - sparse.seg_matmul), - 'csr_matvec': ("brainpy.math.csr_matvec is deprecated. Use brainpy.math.sparse.csrmv instead.", - sparse.csrmv), - 'event_matvec_prob_conn_homo_weight': ("brainpy.math.event_matvec_prob_conn_homo_weight is deprecated. " - "Use brainpy.math.jitconn.event_mv_prob_homo instead.", - jitconn.event_mv_prob_homo), - 'event_matvec_prob_conn_uniform_weight': ("brainpy.math.event_matvec_prob_conn_uniform_weight is deprecated. " - "Use brainpy.math.jitconn.event_mv_prob_uniform instead.", - jitconn.event_mv_prob_uniform), - 'event_matvec_prob_conn_normal_weight': ("brainpy.math.event_matvec_prob_conn_normal_weight is deprecated. " - "Use brainpy.math.jitconn.event_mv_prob_normal instead.", - jitconn.event_mv_prob_normal), - 'matvec_prob_conn_homo_weight': ("brainpy.math.matvec_prob_conn_homo_weight is deprecated. " - "Use brainpy.math.jitconn.mv_prob_homo instead.", - jitconn.mv_prob_homo), - 'matvec_prob_conn_uniform_weight': ("brainpy.math.matvec_prob_conn_uniform_weight is deprecated. " - "Use brainpy.math.jitconn.mv_prob_uniform instead.", - jitconn.mv_prob_uniform), - 'matvec_prob_conn_normal_weight': ("brainpy.math.matvec_prob_conn_normal_weight is deprecated. " - "Use brainpy.math.jitconn.mv_prob_normal instead.", - jitconn.mv_prob_normal), - 'cusparse_csr_matvec': ("brainpy.math.cusparse_csr_matvec is deprecated. " - "Use brainpy.math.sparse.csrmv instead.", - sparse.csrmv), - 'cusparse_coo_matvec': ("brainpy.math.cusparse_coo_matvec is deprecated. " - "Use brainpy.math.sparse.coomv instead.", - sparse.coomv), - 'coo_to_csr': ("brainpy.math.coo_to_csr is deprecated. " - "Use brainpy.math.sparse.coo_to_csr instead.", - sparse.coo_to_csr), - 'csr_to_coo': ("brainpy.math.csr_to_coo is deprecated. " - "Use brainpy.math.sparse.csr_to_coo instead.", - sparse.csr_to_coo), - 'csr_to_dense': ("brainpy.math.csr_to_dense is deprecated. " - "Use brainpy.math.sparse.csr_to_dense instead.", - sparse.csr_to_dense), - 'event_csr_matvec': ("brainpy.math.event_csr_matvec is deprecated. " - "Use brainpy.math.event.csr_to_dense instead.", - event.csrmv), - 'event_info': ("brainpy.math.event_info is deprecated. " - "Use brainpy.math.event.info instead.", - event.info), -} - -__getattr__ = deprecation_getattr(__name__, __deprecations, redirects=defaults.__all__, redirect_module=defaults) -del deprecation_getattr, defaults +# -*- coding: utf-8 -*- + +# data structure +from .ndarray import * +from .delayvars import * +from .interoperability import * +from .datatypes import * +from .compat_numpy import * +from .compat_tensorflow import * +from .compat_pytorch import * +from .einops import * + +# functions +from .activations import * +from . import activations + +# operators +from .pre_syn_post import * +from .op_register import * +from . import surrogate, event, sparse, jitconn + +# Variable and Objects for object-oriented JAX transformations +from .oo_transform import * + +# environment settings +from .modes import * +from .environment import * +from .scales import * +from .others import * + +# high-level numpy operations +from . import fft +from . import linalg +from . import random + +# taichi operations +from . import tifunc + +# others +from . import sharding + +import jax.numpy as jnp +from jax import config + +del jnp, config + +from brainpy._src.math.surrogate._compt import ( + spike_with_sigmoid_grad as spike_with_sigmoid_grad, + spike_with_linear_grad as spike_with_linear_grad, + spike_with_gaussian_grad as spike_with_gaussian_grad, + spike_with_mg_grad as spike_with_mg_grad, +) + +from brainpy._src.math import defaults +from brainpy._src.deprecations import deprecation_getattr +__deprecations = { + "sparse_matmul": ("brainpy.math.sparse_matmul is deprecated. Use brainpy.math.sparse.seg_matmul instead.", + sparse.seg_matmul), + 'csr_matvec': ("brainpy.math.csr_matvec is deprecated. Use brainpy.math.sparse.csrmv instead.", + sparse.csrmv), + 'event_matvec_prob_conn_homo_weight': ("brainpy.math.event_matvec_prob_conn_homo_weight is deprecated. " + "Use brainpy.math.jitconn.event_mv_prob_homo instead.", + jitconn.event_mv_prob_homo), + 'event_matvec_prob_conn_uniform_weight': ("brainpy.math.event_matvec_prob_conn_uniform_weight is deprecated. " + "Use brainpy.math.jitconn.event_mv_prob_uniform instead.", + jitconn.event_mv_prob_uniform), + 'event_matvec_prob_conn_normal_weight': ("brainpy.math.event_matvec_prob_conn_normal_weight is deprecated. " + "Use brainpy.math.jitconn.event_mv_prob_normal instead.", + jitconn.event_mv_prob_normal), + 'matvec_prob_conn_homo_weight': ("brainpy.math.matvec_prob_conn_homo_weight is deprecated. " + "Use brainpy.math.jitconn.mv_prob_homo instead.", + jitconn.mv_prob_homo), + 'matvec_prob_conn_uniform_weight': ("brainpy.math.matvec_prob_conn_uniform_weight is deprecated. " + "Use brainpy.math.jitconn.mv_prob_uniform instead.", + jitconn.mv_prob_uniform), + 'matvec_prob_conn_normal_weight': ("brainpy.math.matvec_prob_conn_normal_weight is deprecated. " + "Use brainpy.math.jitconn.mv_prob_normal instead.", + jitconn.mv_prob_normal), + 'cusparse_csr_matvec': ("brainpy.math.cusparse_csr_matvec is deprecated. " + "Use brainpy.math.sparse.csrmv instead.", + sparse.csrmv), + 'coo_to_csr': ("brainpy.math.coo_to_csr is deprecated. " + "Use brainpy.math.sparse.coo_to_csr instead.", + sparse.coo_to_csr), + 'csr_to_coo': ("brainpy.math.csr_to_coo is deprecated. " + "Use brainpy.math.sparse.csr_to_coo instead.", + sparse.csr_to_coo), + 'csr_to_dense': ("brainpy.math.csr_to_dense is deprecated. " + "Use brainpy.math.sparse.csr_to_dense instead.", + sparse.csr_to_dense), + 'event_csr_matvec': ("brainpy.math.event_csr_matvec is deprecated. " + "Use brainpy.math.event.csr_to_dense instead.", + event.csrmv), +} + +__getattr__ = deprecation_getattr(__name__, __deprecations, redirects=defaults.__all__, redirect_module=defaults) +del deprecation_getattr, defaults From ecc6f9573fbadb10c29a8778f5f33319447bfa9b Mon Sep 17 00:00:00 2001 From: He Sichao <1310722434@qq.com> Date: Sat, 17 Feb 2024 22:56:19 +0800 Subject: [PATCH 03/16] [math] Remove old test case of event csr matvec and csr matvec --- .../math/event/tests/test_event_csrmv_old.py | 318 ---------------- .../_src/math/sparse/tests/test_csrmv_old.py | 348 ------------------ 2 files changed, 666 deletions(-) delete mode 100644 brainpy/_src/math/event/tests/test_event_csrmv_old.py delete mode 100644 brainpy/_src/math/sparse/tests/test_csrmv_old.py diff --git a/brainpy/_src/math/event/tests/test_event_csrmv_old.py b/brainpy/_src/math/event/tests/test_event_csrmv_old.py deleted file mode 100644 index fcb25a89c..000000000 --- a/brainpy/_src/math/event/tests/test_event_csrmv_old.py +++ /dev/null @@ -1,318 +0,0 @@ -# -*- coding: utf-8 -*- - - -from functools import partial - -import jax -import pytest - -import brainpy as bp -import brainpy.math as bm - -pytest.skip('Old implementation.', allow_module_level=True) - -brainpylib_csr_matvec = partial(bm.event.csrmv, method='brainpylib') -taichi_csr_matvec = partial(bm.event.csrmv, method='taichi') - -def sum_op(op): - def func(*args, **kwargs): - r = op(*args, **kwargs) - return r.sum() - - return func - - -class Test_event_csr_matvec(parameterized.TestCase): - def __init__(self, *args, platform='cpu', **kwargs): - super(Test_event_csr_matvec, self).__init__(*args, **kwargs) - bm.set_platform(platform) - print() - - @parameterized.named_parameters( - dict( - testcase_name=f'transpose={transpose}, shape={shape}, homo_data={homo_data}', - transpose=transpose, - shape=shape, - homo_data=homo_data, - ) - for transpose in [True, False] - for shape in [(100, 200), - (200, 200), - (200, 100), - (10, 1000), - (2, 10000), - (1000, 10), - (10000, 2)] - for homo_data in [-1., 0., 1.] - ) - def test_homo(self, shape, transpose, homo_data): - print(f'test_homo: shape = {shape}, transpose = {transpose}, homo_data = {homo_data}') - - rng = bm.random.RandomState() - indices, indptr = bp.conn.FixedProb(0.4)(*shape).require('pre2post') - events = rng.random(shape[0] if transpose else shape[1]) < 0.1 - heter_data = bm.ones(indices.shape) * homo_data - - r1 = brainpylib_csr_matvec(homo_data, indices, indptr, events, shape=shape, transpose=transpose) - r2 = brainpylib_csr_matvec(heter_data, indices, indptr, events, shape=shape, transpose=transpose) - self.assertTrue(bm.allclose(r1, r2)) - - r3 = brainpylib_csr_matvec(homo_data, indices, indptr, events.astype(float), shape=shape, transpose=transpose) - self.assertTrue(bm.allclose(r1, r3)) - - dense = bm.sparse.csr_to_dense(heter_data, indices, indptr, shape=shape) - r4 = (events @ dense) if transpose else (dense @ events) - self.assertTrue(bm.allclose(r1, r4)) - - r5 = brainpylib_csr_matvec(heter_data, indices, indptr, events.astype(float), shape=shape, transpose=transpose) - self.assertTrue(bm.allclose(r1, r5)) - - bm.clear_buffer_memory() - - @parameterized.named_parameters( - dict( - testcase_name=f'transpose={transpose}, shape={shape}, homo_data={homo_data}', - transpose=transpose, - shape=shape, - homo_data=homo_data, - ) - for transpose in [True, False] - for shape in [(100, 200), - (200, 200), - (200, 100), - (10, 1000), - (2, 10000), - (1000, 10), - (100000, 2)] - for homo_data in [-1., 0., 1.] - ) - def test_homo_vmap(self, shape, transpose, homo_data): - print(f'test_homo_vamp: shape = {shape}, transpose = {transpose}, homo_data = {homo_data}') - - rng = bm.random.RandomState() - indices, indptr = bp.conn.FixedProb(0.4)(*shape).require('pre2post') - - # vmap 'data' - events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) < 0.1 - f1 = jax.vmap(partial(brainpylib_csr_matvec, indices=indices, indptr=indptr, events=events, - shape=shape, transpose=transpose)) - f2 = jax.vmap( - partial(partial(bm.sparse.csrmv, method='cusparse'), indices=indices, indptr=indptr, vector=events.astype(float), - shape=shape, transpose=transpose)) - vmap_data = bm.as_jax([homo_data] * 10) - self.assertTrue(bm.allclose(f1(vmap_data), f2(vmap_data))) - - # vmap 'events' - f3 = jax.vmap(partial(brainpylib_csr_matvec, homo_data, indices, indptr, - shape=shape, transpose=transpose)) - f4 = jax.vmap(partial(partial(bm.sparse.csrmv, method='cusparse'), homo_data, indices, indptr, - shape=shape, transpose=transpose)) - vmap_data = bm.as_jax(rng.random((10, shape[0] if transpose else shape[1]))) < 0.1 - self.assertTrue(bm.allclose(f3(vmap_data), f4(vmap_data.astype(float)))) - - # vmap 'data' and 'events' - f5 = jax.vmap(lambda dd, ee: brainpylib_csr_matvec(dd, indices, indptr, ee, shape=shape, transpose=transpose)) - f6 = jax.vmap(lambda dd, ee: bm.sparse.csrmv(dd, indices, indptr, ee, shape=shape, transpose=transpose, - method='cusparse')) - vmap_data1 = bm.as_jax([homo_data] * 10) - vmap_data2 = bm.as_jax(rng.random((10, shape[0] if transpose else shape[1]))) < 0.2 - self.assertTrue(bm.allclose(f5(vmap_data1, vmap_data2), - f6(vmap_data1, vmap_data2.astype(float)))) - - bm.clear_buffer_memory() - - @parameterized.named_parameters( - dict( - testcase_name=f'transpose={transpose},shape={shape},homo_data={homo_data}', - homo_data=homo_data, - shape=shape, - transpose=transpose, - ) - for transpose in [True, False] - for shape in [(100, 200), - (200, 200), - (200, 100), - (10, 1000), - (2, 10000), - (1000, 10), - (100000, 2)] - for homo_data in [-1., 0., 1.] - ) - def test_homo_grad(self, shape, transpose, homo_data): - print(f'test_homo_grad: shape = {shape}, transpose = {transpose}, homo_data = {homo_data}') - - rng = bm.random.RandomState() - indices, indptr = bp.conn.FixedProb(0.4)(*shape).require('pre2post') - indices = bm.as_jax(indices) - indptr = bm.as_jax(indptr) - events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) < 0.1 - dense_conn = bm.sparse.csr_to_dense(bm.ones(indices.shape).value, indices, indptr, shape=shape) - - # grad 'data' - r1 = jax.grad(sum_op(brainpylib_csr_matvec))( - homo_data, indices, indptr, events, shape=shape, transpose=transpose) - r2 = jax.grad(sum_op(partial(bm.sparse.csrmv, method='cusparse')))( - homo_data, indices, indptr, events.astype(float), shape=shape, transpose=transpose) - self.assertTrue(bm.allclose(r1, r2)) - r3 = jax.grad(sum_op(lambda a: (events @ (dense_conn * a) if transpose else - ((dense_conn * a) @ events))))(homo_data) - self.assertTrue(bm.allclose(r1, r3)) - - # grad 'events' - r4 = jax.grad(sum_op(brainpylib_csr_matvec), argnums=3)( - homo_data, indices, indptr, events.astype(float), shape=shape, transpose=transpose) - r5 = jax.grad(sum_op(partial(bm.sparse.csrmv, method='cusparse')), argnums=3)( - homo_data, indices, indptr, events.astype(float), shape=shape, transpose=transpose) - r6 = jax.grad(sum_op(lambda e: (e @ (dense_conn * homo_data) if transpose else - ((dense_conn * homo_data) @ e))))(events.astype(float)) - self.assertTrue(bm.allclose(r4, r5)) - self.assertTrue(bm.allclose(r4, r6)) - - bm.clear_buffer_memory() - - @parameterized.named_parameters( - dict( - testcase_name=f'transpose={transpose}, shape={shape}', - shape=shape, - transpose=transpose, - ) - for transpose in [True, False] - for shape in [(100, 200), - (200, 200), - (200, 100), - (10, 1000), - (2, 10000), - (1000, 10), - (10000, 2)] - ) - def test_heter(self, shape, transpose): - print(f'test_heter: shape = {shape}, transpose = {transpose}') - - rng = bm.random.RandomState() - indices, indptr = bp.conn.FixedProb(0.4)(*shape).require('pre2post') - indices = bm.as_jax(indices) - indptr = bm.as_jax(indptr) - events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) < 0.1 - heter_data = bm.as_jax(rng.random(indices.shape)) - - r1 = brainpylib_csr_matvec(heter_data, indices, indptr, events, - shape=shape, transpose=transpose) - r2 = partial(bm.sparse.csrmv, method='cusparse')(heter_data, indices, indptr, events.astype(float), - shape=shape, transpose=transpose) - self.assertTrue(bm.allclose(r1, r2)) - - dense = bm.sparse.csr_to_dense(heter_data, indices, indptr, shape=shape) - r3 = (events @ dense) if transpose else (dense @ events) - self.assertTrue(bm.allclose(r1, r3)) - - r4 = brainpylib_csr_matvec(heter_data, indices, indptr, events.astype(float), - shape=shape, transpose=transpose) - self.assertTrue(bm.allclose(r1, r4)) - - bm.clear_buffer_memory() - - @parameterized.named_parameters( - dict( - testcase_name=f"transpose={transpose}, shape={shape}", - shape=shape, - transpose=transpose, - ) - for transpose in [True, False] - for shape in [(100, 200), - (200, 200), - (200, 100), - (10, 1000), - (2, 10000), - (1000, 10), - (100000, 2)] - ) - def test_heter_vmap(self, shape, transpose): - print(f'test_heter_vamp: shape = {shape}, transpose = {transpose}') - - rng = bm.random.RandomState() - indices, indptr = bp.conn.FixedProb(0.4)(*shape).require('pre2post') - indices = bm.as_jax(indices) - indptr = bm.as_jax(indptr) - - # vmap 'data' - events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) < 0.1 - f1 = jax.vmap(partial(brainpylib_csr_matvec, indices=indices, indptr=indptr, events=events, - shape=shape, transpose=transpose)) - f2 = jax.vmap( - partial(partial(bm.sparse.csrmv, method='cusparse'), indices=indices, indptr=indptr, vector=events.astype(float), - shape=shape, transpose=transpose)) - vmap_data = bm.as_jax(rng.random((10, indices.shape[0]))) - self.assertTrue(bm.allclose(f1(vmap_data), f2(vmap_data))) - - # vmap 'events' - data = bm.as_jax(rng.random(indices.shape)) - f3 = jax.vmap(partial(brainpylib_csr_matvec, data, indices, indptr, - shape=shape, transpose=transpose)) - f4 = jax.vmap(partial(partial(bm.sparse.csrmv, method='cusparse'), data, indices, indptr, - shape=shape, transpose=transpose)) - vmap_data = bm.as_jax(rng.random((10, shape[0] if transpose else shape[1]))) < 0.1 - self.assertTrue(bm.allclose(f3(vmap_data), f4(vmap_data.astype(float)))) - - # vmap 'data' and 'events' - f5 = jax.vmap(lambda dd, ee: brainpylib_csr_matvec(dd, indices, indptr, ee, - shape=shape, transpose=transpose)) - f6 = jax.vmap(lambda dd, ee: partial(bm.sparse.csrmv, method='cusparse')(dd, indices, indptr, ee, - shape=shape, transpose=transpose)) - vmap_data1 = bm.as_jax(rng.random((10, indices.shape[0]))) - vmap_data2 = bm.as_jax(rng.random((10, shape[0] if transpose else shape[1]))) < 0.2 - self.assertTrue(bm.allclose(f5(vmap_data1, vmap_data2), - f6(vmap_data1, vmap_data2.astype(float)))) - - bm.clear_buffer_memory() - - @parameterized.named_parameters( - dict(testcase_name=f'transpose={transpose},shape={shape}', - shape=shape, - transpose=transpose, - ) - for transpose in [True, False] - for shape in [(100, 200), - (200, 200), - (200, 100), - (10, 1000), - (2, 10000), - (1000, 10), - (100000, 2)] - ) - def test_heter_grad(self, shape, transpose): - print(f'test_heter_grad: shape = {shape}, transpose = {transpose}') - - rng = bm.random.RandomState() - indices, indptr = bp.conn.FixedProb(0.4)(*shape).require('pre2post') - indices = bm.as_jax(indices) - indptr = bm.as_jax(indptr) - events = rng.random(shape[0] if transpose else shape[1]) < 0.1 - events = bm.as_jax(events) - dense_conn = bm.sparse.csr_to_dense(bm.ones(indices.shape).value, indices, indptr, shape=shape) - - # grad 'data' - data = bm.as_jax(rng.random(indices.shape)) - r1 = jax.grad(sum_op(brainpylib_csr_matvec))( - data, indices, indptr, events, shape=shape, transpose=transpose) - r2 = jax.grad(sum_op(partial(bm.sparse.csrmv, method='cusparse')))( - data, indices, indptr, events.astype(float), shape=shape, transpose=transpose) - self.assertTrue(bm.allclose(r1, r2)) - - dense_data = bm.sparse.csr_to_dense(data, indices, indptr, shape=shape) - r3 = jax.grad(sum_op(lambda a: ((events @ a) if transpose else - (a @ events))))(dense_data) - rows, cols = bm.sparse.csr_to_coo(indices, indptr) - r3 = r3[rows, cols] - self.assertTrue(bm.allclose(r1, r3)) - - # grad 'events' - r4 = jax.grad(sum_op(brainpylib_csr_matvec), argnums=3)( - data, indices, indptr, events.astype(float), shape=shape, transpose=transpose) - r5 = jax.grad(sum_op(partial(bm.sparse.csrmv, method='cusparse')), argnums=3)( - data, indices, indptr, events.astype(float), shape=shape, transpose=transpose) - r6 = jax.grad(sum_op(lambda e: ((e @ dense_data) if transpose else - (dense_data @ e))))(events.astype(float)) - self.assertTrue(bm.allclose(r4, r5)) - self.assertTrue(bm.allclose(r4, r6)) - - bm.clear_buffer_memory() diff --git a/brainpy/_src/math/sparse/tests/test_csrmv_old.py b/brainpy/_src/math/sparse/tests/test_csrmv_old.py deleted file mode 100644 index 23a3de93a..000000000 --- a/brainpy/_src/math/sparse/tests/test_csrmv_old.py +++ /dev/null @@ -1,348 +0,0 @@ -# -*- coding: utf-8 -*- - -from functools import partial - -import jax -import pytest - -import brainpy as bp -import brainpy.math as bm - -pytest.skip('Old implementation.', allow_module_level=True) - - -cusparse_csr_matvec = partial(bm.sparse.csrmv, method='cusparse') -scalar_csr_matvec = partial(bm.sparse.csrmv, method='scalar') -vector_csr_matvec = partial(bm.sparse.csrmv, method='vector') - - -class Test_cusparse_csrmv(parameterized.TestCase): - def __init__(self, *args, platform='cpu', **kwargs): - super(Test_cusparse_csrmv, self).__init__(*args, **kwargs) - - print() - bm.set_platform(platform) - - @parameterized.product( - transpose=[True, False], - shape=[(200, 200), (200, 100), (10, 1000), (2, 2000)], - homo_data=[-1., 0., 1.] - ) - def test_homo(self, transpose, shape, homo_data): - rng = bm.random.RandomState() - conn = bp.conn.FixedProb(0.1) - - indices, indptr = conn(*shape).require('pre2post') - indices = bm.as_jax(indices) - indptr = bm.as_jax(indptr) - - heter_data = bm.ones(indices.shape).value * homo_data - - vector = rng.random(shape[0] if transpose else shape[1]) - vector = bm.as_jax(vector) - r1 = cusparse_csr_matvec(homo_data, indices, indptr, vector, shape=shape, transpose=transpose) - r2 = cusparse_csr_matvec(heter_data, indices, indptr, vector, shape=shape, transpose=transpose) - self.assertTrue(bm.allclose(r1, r2)) - - dense = bm.sparse.csr_to_dense(heter_data, indices, indptr, shape=shape) - r3 = (vector @ dense) if transpose else (dense @ vector) - self.assertTrue(bm.allclose(r1, r3)) - - bm.clear_buffer_memory() - - @parameterized.product( - transpose=[True, False], - shape=[(200, 200), (200, 100), (10, 1000), (2, 2000)], - v=[-1., 0., 1.] - ) - def test_homo_vmap(self, transpose, shape, v): - rng = bm.random.RandomState() - conn = bp.conn.FixedProb(0.1) - - indices, indptr = conn(*shape).require('pre2post') - indices = bm.as_jax(indices) - indptr = bm.as_jax(indptr) - vector = rng.random(shape[0] if transpose else shape[1]) - vector = bm.as_jax(vector) - - heter_data = bm.ones((10, indices.shape[0])).value * v - homo_data = bm.ones(10).value * v - dense_data = jax.vmap(lambda a: bm.sparse.csr_to_dense(a, indices, indptr, shape=shape))(heter_data) - - f1 = partial(cusparse_csr_matvec, indices=indices, indptr=indptr, vector=vector, - shape=shape, transpose=transpose) - f2 = lambda a: (a.T @ vector) if transpose else (a @ vector) - - r1 = jax.vmap(f1)(homo_data) - r2 = jax.vmap(f1)(heter_data) - self.assertTrue(bm.allclose(r1, r2)) - - r3 = jax.vmap(f2)(dense_data) - self.assertTrue(bm.allclose(r1, r3)) - - bm.clear_buffer_memory() - - @parameterized.product( - transpose=[True, False], - shape=[(200, 200), (200, 100), (10, 1000), (2, 2000)], - homo_data=[-1., 0., 1.] - ) - def test_homo_grad(self, transpose, shape, homo_data): - rng = bm.random.RandomState() - conn = bp.conn.FixedProb(0.1) - - indices, indptr = conn(*shape).require('pre2post') - indices = bm.as_jax(indices) - indptr = bm.as_jax(indptr) - dense = bm.sparse.csr_to_dense(bm.ones(indices.shape).value, - indices, - indptr, - shape=shape) - vector = rng.random(shape[0] if transpose else shape[1]) - vector = bm.as_jax(vector) - - csr_f1 = jax.grad(lambda a: cusparse_csr_matvec(a, indices, indptr, vector, - shape=shape, transpose=transpose).sum(), - argnums=0) - dense_f1 = jax.grad(lambda a: ((vector @ (dense * a)).sum() - if transpose else - ((dense * a) @ vector).sum()), - argnums=0) - - r1 = csr_f1(homo_data) - r2 = dense_f1(homo_data) - self.assertTrue(bm.allclose(r1, r2)) - - csr_f2 = jax.grad(lambda v: cusparse_csr_matvec(homo_data, indices, indptr, v, - shape=shape, transpose=transpose).sum()) - dense_data = dense * homo_data - dense_f2 = jax.grad(lambda v: ((v @ dense_data).sum() if transpose else (dense_data @ v).sum())) - - r3 = csr_f2(vector) - r4 = dense_f2(vector) - self.assertTrue(bm.allclose(r3, r4)) - - csr_f3 = jax.grad(lambda a, v: cusparse_csr_matvec(a, indices, indptr, v, - shape=shape, transpose=transpose).sum(), - argnums=(0, 1)) - dense_f3 = jax.grad(lambda a, v: ((v @ (dense * a)).sum() - if transpose else - ((dense * a) @ v).sum()), - argnums=(0, 1)) - - r5 = csr_f3(homo_data, vector) - r6 = dense_f3(homo_data, vector) - self.assertTrue(bm.allclose(r5[0], r6[0])) - self.assertTrue(bm.allclose(r5[1], r6[1])) - - bm.clear_buffer_memory() - - @parameterized.product( - transpose=[True, False], - shape=[(200, 200), (200, 100), (10, 1000), (2, 2000)], - ) - def test_heter(self, transpose, shape): - rng = bm.random.RandomState() - conn = bp.conn.FixedProb(0.1) - - indices, indptr = conn(*shape).require('pre2post') - indices = bm.as_jax(indices) - indptr = bm.as_jax(indptr) - - heter_data = rng.random(indices.shape) - heter_data = bm.as_jax(heter_data) - - vector = rng.random(shape[0] if transpose else shape[1]) - vector = bm.as_jax(vector) - r1 = cusparse_csr_matvec(heter_data, indices, indptr, vector, - shape=shape, transpose=transpose) - dense = bm.sparse.csr_to_dense(heter_data, indices, indptr, shape=shape) - r2 = (vector @ dense) if transpose else (dense @ vector) - self.assertTrue(bm.allclose(r1, r2)) - - bm.clear_buffer_memory() - - @parameterized.product( - transpose=[True, False], - shape=[(200, 200), (200, 100), (10, 1000), (2, 2000)] - ) - def test_heter_vmap(self, transpose, shape): - rng = bm.random.RandomState() - conn = bp.conn.FixedProb(0.1) - - indices, indptr = conn(*shape).require('pre2post') - indices = bm.as_jax(indices) - indptr = bm.as_jax(indptr) - vector = rng.random(shape[0] if transpose else shape[1]) - vector = bm.as_jax(vector) - - heter_data = rng.random((10, indices.shape[0])) - heter_data = bm.as_jax(heter_data) - dense_data = jax.vmap(lambda a: bm.sparse.csr_to_dense(a, indices, indptr, - shape=shape))(heter_data) - - f1 = partial(cusparse_csr_matvec, indices=indices, indptr=indptr, vector=vector, - shape=shape, transpose=transpose) - f2 = lambda a: (a.T @ vector) if transpose else (a @ vector) - - r1 = jax.vmap(f1)(heter_data) - r2 = jax.vmap(f2)(dense_data) - self.assertTrue(bm.allclose(r1, r2)) - - bm.clear_buffer_memory() - - @parameterized.product( - transpose=[True, False], - shape=[(200, 200), (200, 100), (10, 1000), (2, 2000)] - ) - def test_heter_grad(self, transpose, shape): - rng = bm.random.RandomState() - conn = bp.conn.FixedProb(0.1) - - indices, indptr = conn(*shape).require('pre2post') - indices = bm.as_jax(indices) - indptr = bm.as_jax(indptr) - heter_data = rng.random(indices.shape) - heter_data = bm.as_jax(heter_data) - dense_data = bm.sparse.csr_to_dense(heter_data, indices, indptr, shape=shape) - vector = rng.random(shape[0] if transpose else shape[1]) - vector = bm.as_jax(vector) - - csr_f1 = jax.grad(lambda a: cusparse_csr_matvec(a, indices, indptr, vector, - shape=shape, - transpose=transpose).sum(), - argnums=0) - dense_f1 = jax.grad(lambda a: ((vector @ a).sum() if transpose else (a @ vector).sum()), - argnums=0) - - r1 = csr_f1(heter_data) - r2 = dense_f1(dense_data) - rows, cols = bm.sparse.csr_to_coo(indices, indptr) - r2 = r2[rows, cols] - self.assertTrue(bm.allclose(r1, r2)) - - csr_f2 = jax.grad(lambda v: cusparse_csr_matvec(heter_data, indices, indptr, v, - shape=shape, - transpose=transpose).sum(), - argnums=0) - dense_f2 = jax.grad(lambda v: ((v @ dense_data).sum() if transpose else (dense_data @ v).sum()), - argnums=0) - r3 = csr_f2(vector) - r4 = dense_f2(vector) - self.assertTrue(bm.allclose(r3, r4)) - - bm.clear_buffer_memory() - - -class Test_csrmv(parameterized.TestCase): - def __init__(self, *args, platform='cpu', **kwargs): - super(Test_csrmv, self).__init__(*args, **kwargs) - - print() - bm.set_platform(platform) - - @parameterized.product( - homo_data=[-1., 0., 0.1, 1.], - shape=[(100, 200), (10, 1000), (2, 2000)], - ) - def test_homo(self, shape, homo_data): - conn = bp.conn.FixedProb(0.1) - - # matrix - indices, indptr = conn(*shape).require('pre2post') - indices = bm.as_jax(indices) - indptr = bm.as_jax(indptr) - # vector - rng = bm.random.RandomState(123) - vector = rng.random(shape[1]) - vector = bm.as_jax(vector) - - # csrmv - r1 = scalar_csr_matvec(homo_data, indices, indptr, vector, shape=shape) - r2 = cusparse_csr_matvec(homo_data, indices, indptr, vector, shape=shape) - r3 = vector_csr_matvec(homo_data, indices, indptr, vector, shape=shape) - self.assertTrue(bm.allclose(r1, r2)) - self.assertTrue(bm.allclose(r1, r3)) - - heter_data = bm.ones(indices.shape).to_jax() * homo_data - r4 = scalar_csr_matvec(heter_data, indices, indptr, vector, shape=shape) - r5 = cusparse_csr_matvec(heter_data, indices, indptr, vector, shape=shape) - r6 = vector_csr_matvec(heter_data, indices, indptr, vector, shape=shape) - self.assertTrue(bm.allclose(r1, r4)) - self.assertTrue(bm.allclose(r1, r5)) - self.assertTrue(bm.allclose(r1, r6)) - - dense = bm.sparse.csr_to_dense(heter_data, indices, indptr, shape=shape) - rdense = dense @ vector - self.assertTrue(bm.allclose(r1, rdense)) - - bm.clear_buffer_memory() - - @parameterized.product( - shape=[(100, 200), (200, 100), (10, 1000), (2, 2000)] - ) - def test_heter(self, shape): - rng = bm.random.RandomState() - conn = bp.conn.FixedProb(0.1) - - indices, indptr = conn(*shape).require('pre2post') - indices = bm.as_jax(indices) - indptr = bm.as_jax(indptr) - heter_data = bm.as_jax(rng.random(indices.shape)) - vector = bm.as_jax(rng.random(shape[1])) - - r1 = scalar_csr_matvec(heter_data, indices, indptr, vector, shape=shape) - r2 = cusparse_csr_matvec(heter_data, indices, indptr, vector, shape=shape) - r3 = vector_csr_matvec(heter_data, indices, indptr, vector, shape=shape) - - dense = bm.sparse.csr_to_dense(heter_data, indices, indptr, shape=shape) - r4 = dense @ vector - self.assertTrue(bm.allclose(r1, r2)) - self.assertTrue(bm.allclose(r1, r3)) - self.assertTrue(bm.allclose(r1, r4)) - - bm.clear_buffer_memory() - - @parameterized.product( - shape=[(200, 200), (200, 100), (10, 1000), (2, 2000)] - ) - def test_heter_grad(self, shape): - rng = bm.random.RandomState() - conn = bp.conn.FixedProb(0.1) - - indices, indptr = conn(*shape).require('pre2post') - heter_data = rng.random(indices.shape) - dense_data = bm.sparse.csr_to_dense(heter_data, indices, indptr, shape=shape) - vector = rng.random(shape[1]) - - csr_f1 = jax.grad(lambda a: cusparse_csr_matvec(a, indices, indptr, vector, shape=shape).sum()) - csr_f2 = jax.grad(lambda a: scalar_csr_matvec(a, indices, indptr, vector, shape=shape).sum()) - csr_f3 = jax.grad(lambda a: vector_csr_matvec(a, indices, indptr, vector, shape=shape).sum()) - dense_f1 = jax.grad(lambda a: (a @ vector).sum()) - - r1 = csr_f1(heter_data) - r2 = csr_f2(heter_data) - r3 = csr_f3(heter_data) - - d1 = dense_f1(dense_data) - rows, cols = bm.sparse.csr_to_coo(indices, indptr) - d1 = d1[rows, cols] - self.assertTrue(bm.allclose(r1, r2)) - self.assertTrue(bm.allclose(r1, r3)) - self.assertTrue(bm.allclose(r1, d1)) - - # csr_f4 = jax.grad(lambda v: cusparse_csr_matvec(heter_data, indices, indptr, v, shape=shape).sum()) - # csr_f5 = jax.grad(lambda v: scalar_csr_matvec(heter_data, indices, indptr, v, shape=shape).sum()) - # csr_f6 = jax.grad(lambda v: vector_csr_matvec(heter_data, indices, indptr, v, shape=shape).sum()) - # dense_f2 = jax.grad(lambda v: (dense_data @ v).sum()) - # r4 = csr_f4(vector) - # r5 = csr_f5(vector) - # r6 = csr_f6(vector) - # d2 = dense_f2(vector) - # self.assertTrue(bm.allclose(r4, r5)) - # self.assertTrue(bm.allclose(r4, r6)) - # self.assertTrue(bm.allclose(r4, d2)) - - bm.clear_buffer_memory() - - From 64958fce7a5f3cd951bacc133a76ec0361382dd5 Mon Sep 17 00:00:00 2001 From: He Sichao <1310722434@qq.com> Date: Thu, 22 Feb 2024 20:32:05 +0800 Subject: [PATCH 04/16] [dependency] remove all numba and taichi dependency --- brainpy/_src/connect/random_conn.py | 2760 +++++++++-------- brainpy/_src/dependency_check.py | 218 +- brainpy/_src/dnn/linear.py | 321 +- brainpy/_src/math/defaults.py | 7 +- brainpy/_src/math/environment.py | 4 +- brainpy/_src/math/event/__init__.py | 4 +- brainpy/_src/math/jitconn/__init__.py | 7 +- brainpy/_src/math/op_register/__init__.py | 22 +- .../op_register/numba_approach/__init__.py | 365 +-- .../numba_approach/cpu_translation.py | 301 +- brainpy/_src/math/sparse/__init__.py | 10 +- brainpy/_src/math/sparse/_csr_mv.py | 56 +- brainpy/_src/math/tifunc.py | 4 +- brainpy/math/__init__.py | 86 +- brainpy/math/event.py | 8 +- brainpy/math/jitconn.py | 22 +- brainpy/math/op_register.py | 29 +- brainpy/math/sparse.py | 16 +- brainpy/math/tifunc.py | 53 +- 19 files changed, 2204 insertions(+), 2089 deletions(-) diff --git a/brainpy/_src/connect/random_conn.py b/brainpy/_src/connect/random_conn.py index 1f5b1db6d..a132135cc 100644 --- a/brainpy/_src/connect/random_conn.py +++ b/brainpy/_src/connect/random_conn.py @@ -1,1372 +1,1388 @@ -# -*- coding: utf-8 -*- -from functools import partial -from typing import Optional - -from jax import vmap, jit, numpy as jnp -import numpy as np -from numba import njit - -import brainpy.math as bm -from brainpy.errors import ConnectorError -from brainpy.tools import numba_seed, numba_jit, numba_range, format_seed -from brainpy._src.tools.package import SUPPORT_NUMBA -from .base import * - -__all__ = [ - 'FixedProb', - 'FixedPreNum', - 'FixedPostNum', - 'FixedTotalNum', - 'GaussianProb', - 'ProbDist', - - 'SmallWorld', - 'ScaleFreeBA', - 'ScaleFreeBADual', - 'PowerLaw', -] - - -class FixedProb(TwoEndConnector): - """Connect the post-synaptic neurons with fixed probability. - - Parameters - ---------- - prob: float - The conn probability. - pre_ratio: float - The ratio of pre-synaptic neurons to connect. - include_self : bool - Whether create (i, i) conn? - allow_multi_conn: bool - Allow one pre-synaptic neuron connects to multiple post-synaptic neurons? - - .. versionadded:: 2.2.3.2 - - seed : optional, int - Seed the random generator. - """ - - def __init__(self, - prob, - pre_ratio=1., - include_self=True, - allow_multi_conn=False, - seed=None, - **kwargs): - super(FixedProb, self).__init__(**kwargs) - assert 0. <= prob <= 1. - assert 0. <= pre_ratio <= 1. - self.prob = prob - self.pre_ratio = pre_ratio - self.include_self = include_self - self.seed = format_seed(seed) - self.allow_multi_conn = allow_multi_conn - self._jaxrand = bm.random.default_rng(self.seed) - self._nprand = np.random.RandomState(self.seed) - - def __repr__(self): - return (f'{self.__class__.__name__}(prob={self.prob}, pre_ratio={self.pre_ratio}, ' - f'include_self={self.include_self}, allow_multi_conn={self.allow_multi_conn}, ' - f'seed={self.seed})') - - def _iii(self): - if (not self.include_self) and (self.pre_num != self.post_num): - raise ConnectorError(f'We found pre_num != post_num ({self.pre_num} != {self.post_num}). ' - f'But `include_self` is set to True.') - - if self.pre_ratio < 1.: - pre_num_to_select = int(self.pre_num * self.pre_ratio) - pre_ids = self._jaxrand.choice(self.pre_num, size=(pre_num_to_select,), replace=False) - else: - pre_num_to_select = self.pre_num - pre_ids = jnp.arange(self.pre_num) - - post_num_total = self.post_num - post_num_to_select = int(self.post_num * self.prob) - - if self.allow_multi_conn: - selected_post_ids = self._jaxrand.randint(0, post_num_total, (pre_num_to_select, post_num_to_select)) - - else: - if SUPPORT_NUMBA: - rng = np.random - numba_seed(self._nprand.randint(0, int(1e8))) - else: - rng = self._nprand - - @numba_jit # (parallel=True, nogil=True) - def single_conn(): - posts = np.zeros((pre_num_to_select, post_num_to_select), dtype=IDX_DTYPE) - for i in numba_range(pre_num_to_select): - posts[i] = rng.choice(post_num_total, post_num_to_select, replace=False) - return posts - - selected_post_ids = jnp.asarray(single_conn()) - return pre_num_to_select, post_num_to_select, bm.as_jax(selected_post_ids), bm.as_jax(pre_ids) - - def build_coo(self): - _, post_num_to_select, selected_post_ids, pre_ids = self._iii() - selected_post_ids = selected_post_ids.flatten() - selected_pre_ids = jnp.repeat(pre_ids, post_num_to_select) - if not self.include_self: - true_ids = selected_pre_ids != selected_post_ids - selected_pre_ids = selected_pre_ids[true_ids] - selected_post_ids = selected_post_ids[true_ids] - return selected_pre_ids.astype(get_idx_type()), selected_post_ids.astype(get_idx_type()) - - def build_csr(self): - pre_num_to_select, post_num_to_select, selected_post_ids, pre_ids = self._iii() - pre_nums = jnp.ones(pre_num_to_select) * post_num_to_select - if not self.include_self: - true_ids = selected_post_ids == jnp.reshape(pre_ids, (-1, 1)) - pre_nums -= jnp.sum(true_ids, axis=1) - selected_post_ids = selected_post_ids.flatten()[jnp.logical_not(true_ids).flatten()] - else: - selected_post_ids = selected_post_ids.flatten() - selected_pre_inptr = jnp.cumsum(jnp.concatenate([jnp.zeros(1), pre_nums])) - return selected_post_ids.astype(get_idx_type()), selected_pre_inptr.astype(get_idx_type()) - - def build_mat(self): - if self.pre_ratio < 1.: - pre_state = self._jaxrand.uniform(size=(self.pre_num, 1)) < self.pre_ratio - mat = (self._jaxrand.uniform(size=(self.pre_num, self.post_num)) < self.prob) * pre_state - else: - mat = (self._jaxrand.uniform(size=(self.pre_num, self.post_num)) < self.prob) - mat = bm.asarray(mat) - if not self.include_self: - bm.fill_diagonal(mat, False) - return mat.astype(MAT_DTYPE) - - -class FixedTotalNum(TwoEndConnector): - """Connect the synaptic neurons with fixed total number. - - Parameters - ---------- - num : float,int - The conn total number. - allow_multi_conn : bool, optional - Whether allow one pre-synaptic neuron connects to multiple post-synaptic neurons. - seed: int, optional - The random number seed. - """ - - def __init__(self, - num, - allow_multi_conn=False, - seed=None, **kwargs): - super().__init__(**kwargs) - if isinstance(num, int): - assert num >= 0, '"num" must be a non-negative integer.' - elif isinstance(num, float): - assert 0. <= num <= 1., '"num" must be in [0., 1.).' - else: - raise ConnectorError(f'Unknown type: {type(num)}') - self.num = num - self.seed = format_seed(seed) - self.allow_multi_conn = allow_multi_conn - self.rng = bm.random.RandomState(self.seed) - - def build_coo(self): - mat_element_num = self.pre_num * self.post_num - if self.num > mat_element_num: - raise ConnectorError(f'"num" must be smaller than "all2all num", ' - f'but got {self.num} > {mat_element_num}') - if self.allow_multi_conn: - selected_pre_ids = self.rng.randint(0, self.pre_num, (self.num,)) - selected_post_ids = self.rng.randint(0, self.post_num, (self.num,)) - else: - index = self.rng.choice(mat_element_num, size=(self.num,), replace=False) - selected_pre_ids = index // self.post_num - selected_post_ids = index % self.post_num - return selected_pre_ids.astype(get_idx_type()), selected_post_ids.astype(get_idx_type()) - - def __repr__(self): - return f'{self.__class__.__name__}(num={self.num}, seed={self.seed})' - - -class FixedNum(TwoEndConnector): - def __init__(self, - num, - include_self=True, - allow_multi_conn=False, - seed=None, - **kwargs): - super(FixedNum, self).__init__(**kwargs) - if isinstance(num, int): - assert num >= 0, '"num" must be a non-negative integer.' - elif isinstance(num, float): - assert 0. <= num <= 1., '"num" must be in [0., 1.).' - else: - raise ConnectorError(f'Unknown type: {type(num)}') - self.num = num - self.seed = format_seed(seed) - self.include_self = include_self - self.allow_multi_conn = allow_multi_conn - self.rng = bm.random.RandomState(self.seed) if allow_multi_conn else np.random.RandomState(self.seed) - - def __repr__(self): - return f'{self.__class__.__name__}(num={self.num}, include_self={self.include_self}, seed={self.seed})' - - -class FixedPreNum(FixedNum): - """Connect a fixed number pf pre-synaptic neurons for each post-synaptic neuron. - - Parameters - ---------- - num : float, int - The conn probability (if "num" is float) or the fixed number of - connectivity (if "num" is int). - include_self : bool - Whether create (i, i) conn ? - seed : None, int - Seed the random generator. - allow_multi_conn: bool - Allow one pre-synaptic neuron connects to multiple post-synaptic neurons? - - .. versionadded:: 2.2.3.2 - - """ - - def build_coo(self): - if isinstance(self.num, int) and self.num > self.pre_num: - raise ConnectorError(f'"num" must be smaller than "pre_num", ' - f'but got {self.num} > {self.pre_num}') - if (not self.include_self) and (self.pre_num != self.post_num): - raise ConnectorError(f'We found pre_num != post_num ({self.pre_num} != {self.post_num}). ' - f'But `include_self` is set to True.') - pre_num_to_select = int(self.pre_num * self.num) if isinstance(self.num, float) else self.num - pre_num_total = self.pre_num - post_num_total = self.post_num - - if self.allow_multi_conn: - selected_pre_ids = self.rng.randint(0, pre_num_total, (post_num_total, pre_num_to_select,)) - - else: - if SUPPORT_NUMBA: - rng = np.random - numba_seed(self.rng.randint(0, int(1e8))) - else: - rng = self.rng - - @numba_jit # (parallel=True, nogil=True) - def single_conn(): - posts = np.zeros((post_num_total, pre_num_to_select), dtype=IDX_DTYPE) - for i in numba_range(post_num_total): - posts[i] = rng.choice(pre_num_total, pre_num_to_select, replace=False) - return posts - - selected_pre_ids = jnp.asarray(single_conn()) - - post_nums = jnp.ones((post_num_total,), dtype=get_idx_type()) * pre_num_to_select - if not self.include_self: - true_ids = selected_pre_ids == jnp.reshape(jnp.arange(pre_num_total), (-1, 1)) - post_nums -= jnp.sum(true_ids, axis=1) - selected_pre_ids = selected_pre_ids.flatten()[jnp.logical_not(true_ids).flatten()] - else: - selected_pre_ids = selected_pre_ids.flatten() - selected_post_ids = jnp.repeat(jnp.arange(post_num_total), post_nums) - return selected_pre_ids.astype(get_idx_type()), selected_post_ids.astype(get_idx_type()) - - -class FixedPostNum(FixedNum): - """Connect the fixed number of post-synaptic neurons for each pre-synaptic neuron. - - Parameters - ---------- - num : float, int - The conn probability (if "num" is float) or the fixed number of - connectivity (if "num" is int). - include_self : bool - Whether create (i, i) conn ? - seed : None, int - Seed the random generator. - allow_multi_conn: bool - Allow one pre-synaptic neuron connects to multiple post-synaptic neurons? - - .. versionadded:: 2.2.3.2 - - """ - - def _ii(self): - if isinstance(self.num, int) and self.num > self.post_num: - raise ConnectorError(f'"num" must be smaller than "post_num", ' - f'but got {self.num} > {self.post_num}') - if (not self.include_self) and (self.pre_num != self.post_num): - raise ConnectorError(f'We found pre_num != post_num ({self.pre_num} != {self.post_num}). ' - f'But `include_self` is set to True.') - post_num_to_select = int(self.post_num * self.num) if isinstance(self.num, float) else self.num - pre_num_to_select = self.pre_num - pre_ids = jnp.arange(self.pre_num) - post_num_total = self.post_num - - if self.allow_multi_conn: - selected_post_ids = self.rng.randint(0, post_num_total, (pre_num_to_select, post_num_to_select,)) - - else: - if SUPPORT_NUMBA: - rng = np.random - numba_seed(self.rng.randint(0, int(1e8))) - else: - rng = self.rng - - @numba_jit # (parallel=True, nogil=True) - def single_conn(): - posts = np.zeros((pre_num_to_select, post_num_to_select), dtype=IDX_DTYPE) - for i in numba_range(pre_num_to_select): - posts[i] = rng.choice(post_num_total, post_num_to_select, replace=False) - return posts - - selected_post_ids = jnp.asarray(single_conn()) - return pre_num_to_select, post_num_to_select, bm.as_jax(selected_post_ids), bm.as_jax(pre_ids) - - def build_coo(self): - _, post_num_to_select, selected_post_ids, pre_ids = self._ii() - selected_post_ids = selected_post_ids.flatten() - selected_pre_ids = jnp.repeat(pre_ids, post_num_to_select) - if not self.include_self: - true_ids = selected_pre_ids != selected_post_ids - selected_pre_ids = selected_pre_ids[true_ids] - selected_post_ids = selected_post_ids[true_ids] - return selected_pre_ids.astype(get_idx_type()), selected_post_ids.astype(get_idx_type()) - - def build_csr(self): - pre_num_to_select, post_num_to_select, selected_post_ids, pre_ids = self._ii() - pre_nums = jnp.ones(pre_num_to_select) * post_num_to_select - if not self.include_self: - true_ids = selected_post_ids == jnp.reshape(pre_ids, (-1, 1)) - pre_nums -= jnp.sum(true_ids, axis=1) - selected_post_ids = selected_post_ids.flatten()[jnp.logical_not(true_ids).flatten()] - else: - selected_post_ids = selected_post_ids.flatten() - selected_pre_inptr = jnp.cumsum(jnp.concatenate([jnp.zeros(1), pre_nums])) - return selected_post_ids.astype(get_idx_type()), selected_pre_inptr.astype(get_idx_type()) - -@jit -@partial(vmap, in_axes=(0, None, None)) -def gaussian_prob_dist_cal1(i_value, post_values, sigma): - dists = jnp.abs(i_value - post_values) - exp_dists = jnp.exp(-(jnp.sqrt(jnp.sum(dists ** 2, axis=0)) / sigma) ** 2 / 2) - return bm.asarray(exp_dists) - -@jit -@partial(vmap, in_axes=(0, None, None, None)) -def gaussian_prob_dist_cal2(i_value, post_values, value_sizes, sigma): - dists = jnp.abs(i_value - post_values) - dists = jnp.where(dists > (value_sizes / 2), value_sizes - dists, dists) - exp_dists = jnp.exp(-(jnp.sqrt(jnp.sum(dists ** 2, axis=0)) / sigma) ** 2 / 2) - return bm.asarray(exp_dists) - - -class GaussianProb(OneEndConnector): - r"""Builds a Gaussian connectivity pattern within a population of neurons, - where the connection probability decay according to the gaussian function. - - Specifically, for any pair of neurons :math:`(i, j)`, - - .. math:: - - p(i, j)=\exp(-\frac{\sum_{k=1}^n |v_k^i - v_k^j|^2 }{2\sigma^2}) - - where :math:`v_k^i` is the :math:`i`-th neuron's encoded value at dimension :math:`k`. - - Parameters - ---------- - sigma : float - Width of the Gaussian function. - encoding_values : optional, list, tuple, int, float - The value ranges to encode for neurons at each axis. - - - If `values` is not provided, the neuron only encodes each positional - information, i.e., :math:`(i, j, k, ...)`, where :math:`i, j, k` is - the index in the high-dimensional space. - - If `values` is a single tuple/list of int/float, neurons at each dimension - will encode the same range of values. For example, ``values=(0, np.pi)``, - neurons at each dimension will encode a continuous value space ``[0, np.pi]``. - - If `values` is a tuple/list of list/tuple, it means the value space will be - different for each dimension. For example, ``values=((-np.pi, np.pi), (10, 20), (0, 2 * np.pi))``. - - periodic_boundary : bool - Whether the neuron encode the value space with the periodic boundary. - normalize : bool - Whether normalize the connection probability . - include_self : bool - Whether create the connection at the same position. - seed : int - The random seed. - """ - - def __init__( - self, - sigma: float, - encoding_values: Optional[np.ndarray] = None, - normalize: bool = True, - include_self: bool = True, - periodic_boundary: bool = False, - seed: int = None, - **kwargs - ): - super(GaussianProb, self).__init__(**kwargs) - self.sigma = sigma - self.encoding_values = encoding_values - self.normalize = normalize - self.include_self = include_self - self.periodic_boundary = periodic_boundary - self.seed = format_seed(seed) - self.rng = np.random.RandomState(self.seed) - - def __repr__(self): - return (f'{self.__class__.__name__}(sigma={self.sigma}, ' - f'normalize={self.normalize}, ' - f'periodic_boundary={self.periodic_boundary}, ' - f'include_self={self.include_self}, ' - f'seed={self.seed})') - - def build_mat(self, isOptimized=True): - self.rng = np.random.RandomState(self.seed) - # value range to encode - if self.encoding_values is None: - value_ranges = tuple([(0, s) for s in self.pre_size]) - elif isinstance(self.encoding_values, (tuple, list)): - if len(self.encoding_values) == 0: - raise ConnectorError(f'encoding_values has a length of 0.') - elif isinstance(self.encoding_values[0], (int, float)): - assert len(self.encoding_values) == 2 - assert self.encoding_values[0] < self.encoding_values[1] - value_ranges = tuple([self.encoding_values for _ in self.pre_size]) - elif isinstance(self.encoding_values[0], (tuple, list)): - if len(self.encoding_values) != len(self.pre_size): - raise ConnectorError(f'The network size has {len(self.pre_size)} dimensions, while ' - f'the encoded values provided only has {len(self.encoding_values)}-D. ' - f'Error in {str(self)}.') - for v in self.encoding_values: - assert isinstance(v[0], (int, float)) - assert len(v) == 2 - value_ranges = tuple(self.encoding_values) - else: - raise ConnectorError(f'Unsupported encoding values: {self.encoding_values}') - else: - raise ConnectorError(f'Unsupported encoding values: {self.encoding_values}') - - # values - values = [np.linspace(vs[0], vs[1], n + 1)[:n] for vs, n in zip(value_ranges, self.pre_size)] - # post_values = np.stack([v.flatten() for v in np.meshgrid(*values, indexing='ij')]) - post_values = np.stack([v.flatten() for v in np.meshgrid(*values)]) - value_sizes = np.array([v[1] - v[0] for v in value_ranges]) - if value_sizes.ndim < post_values.ndim: - value_sizes = np.expand_dims(value_sizes, axis=tuple([i + 1 for i in range(post_values.ndim - 1)])) - - # probability of connections - if isOptimized: - i_value_list = np.zeros(shape=(self.pre_num, len(self.pre_size), 1)) - for i in range(self.pre_num): - list_index = i - # values for node i - i_coordinate = tuple() - for s in self.pre_size[:-1]: - i, pos = divmod(i, s) - i_coordinate += (pos,) - i_coordinate += (i,) - i_value = np.array([values[i][c] for i, c in enumerate(i_coordinate)]) - if i_value.ndim < post_values.ndim: - i_value = np.expand_dims(i_value, axis=tuple([i + 1 for i in range(post_values.ndim - 1)])) - i_value_list[list_index] = i_value - - if self.periodic_boundary: - prob_mat = gaussian_prob_dist_cal2(i_value_list, post_values, value_sizes, self.sigma) - else: - prob_mat = gaussian_prob_dist_cal1(i_value_list, post_values, self.sigma) - else: - prob_mat = [] - for i in range(self.pre_num): - # values for node i - i_coordinate = tuple() - for s in self.pre_size[:-1]: - i, pos = divmod(i, s) - i_coordinate += (pos,) - i_coordinate += (i,) - i_value = np.array([values[i][c] for i, c in enumerate(i_coordinate)]) - if i_value.ndim < post_values.ndim: - i_value = np.expand_dims(i_value, axis=tuple([i + 1 for i in range(post_values.ndim - 1)])) - # distances - dists = np.abs(i_value - post_values) - if self.periodic_boundary: - dists = np.where(dists > value_sizes / 2, value_sizes - dists, dists) - exp_dists = np.exp(-(np.linalg.norm(dists, axis=0) / self.sigma) ** 2 / 2) - prob_mat.append(exp_dists) - prob_mat = np.stack(prob_mat) - - if self.normalize: - prob_mat /= prob_mat.max() - - # connectivity - conn_mat = np.asarray(prob_mat) >= self.rng.random(prob_mat.shape) - if not self.include_self: - np.fill_diagonal(conn_mat, False) - return conn_mat - - -class SmallWorld(TwoEndConnector): - """Build a Watts–Strogatz small-world graph. - - Parameters - ---------- - num_neighbor : int - Each node is joined with its `k` nearest neighbors in a ring - topology. - prob : float - The probability of rewiring each edge - directed : bool - Whether the graph is a directed graph. - include_self : bool - Whether include the node self. - - Notes - ----- - First create a ring over :math:`num\_node` nodes [1]_. Then each node in the ring is - joined to its :math:`num\_neighbor` nearest neighbors (or :math:`num\_neighbor - 1` neighbors - if :math:`num\_neighbor` is odd). Then shortcuts are created by replacing some edges as - follows: for each edge :math:`(u, v)` in the underlying ":math:`num\_node`-ring with - :math:`num\_neighbor` nearest neighbors" with probability :math:`prob` replace it with a new - edge :math:`(u, w)` with uniformly random choice of existing node :math:`w`. - - References - ---------- - .. [1] Duncan J. Watts and Steven H. Strogatz, - Collective dynamics of small-world networks, - Nature, 393, pp. 440--442, 1998. - """ - - def __init__( - self, - num_neighbor, - prob, - directed=False, - include_self=False, - seed=None, - **kwargs - ): - super(SmallWorld, self).__init__(**kwargs) - self.prob = prob - self.directed = directed - self.num_neighbor = num_neighbor - self.include_self = include_self - - self.seed = format_seed(seed) - self.rng = np.random.RandomState(seed=self.seed) - rng = np.random if SUPPORT_NUMBA else self.rng - - def _smallworld_rewire(i, all_j): - if rng.random(1) < prob: - non_connected = np.where(np.logical_not(all_j))[0] - if len(non_connected) <= 1: - return -1 - # Enforce no self-loops or multiple edges - w = rng.choice(non_connected) - while (not include_self) and w == i: - # non_connected.remove(w) - w = rng.choice(non_connected) - return w - else: - return -1 - - self._connect = numba_jit(_smallworld_rewire) - - def __repr__(self): - return (f'{self.__class__.__name__}(prob={self.prob}, ' - f'directed={self.directed}, ' - f'num_neighbor={self.num_neighbor}, ' - f'include_self={self.include_self}, ' - f'seed={self.seed})') - - def build_conn(self): - assert self.pre_size == self.post_size - - # seed - self.seed = self.rng.randint(1, int(1e7)) - numba_seed(self.seed) - - if isinstance(self.pre_size, int) or (isinstance(self.pre_size, (tuple, list)) and len(self.pre_size) == 1): - num_node = self.pre_num - - if self.num_neighbor > num_node: - raise ConnectorError("num_neighbor > num_node, choose smaller num_neighbor or larger num_node") - # If k == n, the graph is complete not Watts-Strogatz - if self.num_neighbor == num_node: - conn = np.ones((num_node, num_node), dtype=MAT_DTYPE) - else: - conn = np.zeros((num_node, num_node), dtype=MAT_DTYPE) - nodes = np.array(list(range(num_node))) # nodes are labeled 0 to n-1 - # connect each node to k/2 neighbors - for j in range(1, self.num_neighbor // 2 + 1): - targets = np.concatenate([nodes[j:], nodes[0:j]]) # first j nodes are now last in list - conn[nodes, targets] = True - conn[targets, nodes] = True - - # rewire edges from each node - # loop over all nodes in order (label) and neighbors in order (distance) - # no self loops or multiple edges allowed - for j in range(1, self.num_neighbor // 2 + 1): # outer loop is neighbors - targets = np.concatenate([nodes[j:], nodes[0:j]]) # first j nodes are now last in list - if self.directed: - # inner loop in node order - for u, v in zip(nodes, targets): - w = self._connect(prob=self.prob, i=u, all_j=conn[u]) - if w != -1: - conn[u, v] = False - conn[u, w] = True - w = self._connect(prob=self.prob, i=u, all_j=conn[:, u]) - if w != -1: - conn[v, u] = False - conn[w, u] = True - else: - # inner loop in node order - for u, v in zip(nodes, targets): - w = self._connect(i=u, all_j=conn[u]) - if w != -1: - conn[u, v] = False - conn[v, u] = False - conn[u, w] = True - conn[w, u] = True - # conn = np.asarray(conn, dtype=MAT_DTYPE) - else: - raise ConnectorError('Currently only support 1D ring connection.') - - return 'mat', conn - - -# def _random_subset(seq, m, rng): -# """Return m unique elements from seq. -# -# This differs from random.sample which can return repeated -# elements if seq holds repeated elements. -# -# Note: rng is a random.Random or numpy.random.RandomState instance. -# """ -# targets = set() -# while len(targets) < m: -# x = rng.choice(seq) -# targets.add(x) -# return targets - - -class ScaleFreeBA(TwoEndConnector): - """Build a random graph according to the Barabási–Albert preferential - attachment model. - - A graph of :math:`num\_node` nodes is grown by attaching new nodes each with - :math:`m` edges that are preferentially attached to existing nodes - with high degree. - - Parameters - ---------- - m : int - Number of edges to attach from a new node to existing nodes - seed : integer, random_state, or None (default) - Indicator of random number generation state. - - Raises - ------ - ConnectorError - If `m` does not satisfy ``1 <= m < n``. - - References - ---------- - .. [1] A. L. Barabási and R. Albert "Emergence of scaling in - random networks", Science 286, pp 509-512, 1999. - """ - - def __init__(self, m, directed=False, seed=None, **kwargs): - super(ScaleFreeBA, self).__init__(**kwargs) - self.m = m - self.directed = directed - self.seed = format_seed(seed) - self.rng = np.random.RandomState(self.seed) - rng = np.random if SUPPORT_NUMBA else self.rng - - def _random_subset(seq, m): - targets = set() - while len(targets) < m: - x = rng.choice(seq) - targets.add(x) - return targets - - self._connect = numba_jit(_random_subset) - - def __repr__(self): - return (f'{self.__class__.__name__}(m={self.m}, ' - f'directed={self.directed}, ' - f'seed={self.seed})') - - def build_mat(self, isOptimized=True): - assert self.pre_num == self.post_num - - # seed - self.rng = np.random.RandomState(self.seed) - numba_seed(self.seed) - - num_node = self.pre_num - if self.m < 1 or self.m >= num_node: - raise ConnectorError(f"Barabási–Albert network must have m >= 1 and " - f"m < n, while m = {self.m} and n = {num_node}") - - # Add m initial nodes (m0 in barabasi-speak) - conn = np.zeros((num_node, num_node), dtype=MAT_DTYPE) - # Target nodes for new edges - targets = list(range(self.m)) - # List of existing nodes, with nodes repeated once for each adjacent edge - - if not isOptimized: - repeated_nodes = [] - # Start adding the other n-m nodes. The first node is m. - source = self.m - while source < num_node: - # Add edges to m nodes from the source. - origins = [source] * self.m - conn[origins, targets] = True - if not self.directed: - conn[targets, origins] = True - # Add one node to the list for each new edge just created. - repeated_nodes.extend(targets) - # And the new node "source" has m edges to add to the list. - repeated_nodes.extend([source] * self.m) - # Now choose m unique nodes from the existing nodes - # Pick uniformly from repeated_nodes (preferential attachment) - targets = list(self._connect(np.asarray(repeated_nodes), self.m)) - source += 1 - return conn - - # List of existing nodes, with nodes repeated once for each adjacent edge - # Preallocate repeated_nodes as a numpy array - repeated_nodes = np.empty(2 * num_node * self.m, dtype=int) - size_repeated_nodes = 0 - # Start adding the other n-m nodes. The first node is m. - source = self.m - while source < num_node: - # Add edges to m nodes from the source. - origins = [source] * self.m - conn[origins, targets] = True - if not self.directed: - conn[targets, origins] = True - # Add one node to the list for each new edge just created. - repeated_nodes[size_repeated_nodes:size_repeated_nodes + self.m] = targets - size_repeated_nodes += self.m - # And the new node "source" has m edges to add to the list. - repeated_nodes[size_repeated_nodes:size_repeated_nodes + self.m] = source - size_repeated_nodes += self.m - # Now choose m unique nodes from the existing nodes - # Pick uniformly from repeated_nodes (preferential attachment) - targets = list(self._connect(repeated_nodes[:size_repeated_nodes], self.m)) - source += 1 - - return conn - - -class ScaleFreeBADual(TwoEndConnector): - r"""Build a random graph according to the dual Barabási–Albert preferential - attachment model. - - A graph of :math::`num\_node` nodes is grown by attaching new nodes each with either $m_1$ - edges (with probability :math:`p`) or :math:`m_2` edges (with probability :math:`1-p`) that - are preferentially attached to existing nodes with high degree. - - Parameters - ---------- - m1 : int - Number of edges to attach from a new node to existing nodes with probability :math:`p` - m2 : int - Number of edges to attach from a new node to existing nodes with probability :math:`1-p` - p : float - The probability of attaching :math:`m\_1` edges (as opposed to :math:`m\_2` edges) - seed : integer, random_state, or None (default) - Indicator of random number generation state. - - Raises - ------ - ConnectorError - If `m1` and `m2` do not satisfy ``1 <= m1,m2 < n`` or `p` does not satisfy ``0 <= p <= 1``. - - References - ---------- - .. [1] N. Moshiri "The dual-Barabasi-Albert model", arXiv:1810.10538. - """ - - def __init__(self, m1, m2, p, directed=False, seed=None, **kwargs): - super(ScaleFreeBADual, self).__init__(**kwargs) - self.m1 = m1 - self.m2 = m2 - self.p = p - self.directed = directed - self.seed = format_seed(seed) - self.rng = np.random.RandomState(self.seed) - rng = np.random if SUPPORT_NUMBA else self.rng - - def _random_subset(seq, m): - targets = set() - while len(targets) < m: - x = rng.choice(seq) - targets.add(x) - return targets - - self._connect = numba_jit(_random_subset) - - def __repr__(self): - return (f'{self.__class__.__name__}(m1={self.m1}, m2={self.m2}, ' - f'p={self.p}, directed={self.directed}, seed={self.seed})') - - def build_mat(self, isOptimized=True): - assert self.pre_num == self.post_num - # seed - self.rng = np.random.RandomState(self.seed) - numba_seed(self.seed) - - num_node = self.pre_num - if self.m1 < 1 or self.m1 >= num_node: - raise ConnectorError(f"Dual Barabási–Albert network must have m1 >= 1 and m1 < num_node, " - f"while m1 = {self.m1} and num_node = {num_node}.") - if self.m2 < 1 or self.m2 >= num_node: - raise ConnectorError(f"Dual Barabási–Albert network must have m2 >= 1 and m2 < num_node, " - f"while m2 = {self.m2} and num_node = {num_node}.") - if self.p < 0 or self.p > 1: - raise ConnectorError(f"Dual Barabási–Albert network must have 0 <= p <= 1, while p = {self.p}") - - # Add max(m1,m2) initial nodes (m0 in barabasi-speak) - conn = np.zeros((num_node, num_node), dtype=MAT_DTYPE) - - if not isOptimized: - # List of existing nodes, with nodes repeated once for each adjacent edge - repeated_nodes = [] - # Start adding the remaining nodes. - source = max(self.m1, self.m2) - # Pick which m to use first time (m1 or m2) - m = self.m1 if self.rng.random() < self.p else self.m2 - # Target nodes for new edges - targets = list(range(m)) - while source < num_node: - # Add edges to m nodes from the source. - origins = [source] * m - conn[origins, targets] = True - if not self.directed: - conn[targets, origins] = True - # Add one node to the list for each new edge just created. - repeated_nodes.extend(targets) - # And the new node "source" has m edges to add to the list. - repeated_nodes.extend([source] * m) - # Pick which m to use next time (m1 or m2) - m = self.m1 if self.rng.random() < self.p else self.m2 - # Now choose m unique nodes from the existing nodes - # Pick uniformly from repeated_nodes (preferential attachment) - targets = list(self._connect(np.asarray(repeated_nodes), m)) - source += 1 - return conn - - # List of existing nodes, with nodes repeated once for each adjacent edge - # Preallocate repeated_nodes as a numpy array - repeated_nodes = np.empty(2 * num_node * max(self.m1, self.m2), dtype=int) - size_repeated_nodes = 0 - # Start adding the remaining nodes. - source = max(self.m1, self.m2) - # Pick which m to use first time (m1 or m2) - m = self.m1 if self.rng.random() < self.p else self.m2 - # Target nodes for new edges - targets = list(range(m)) - while source < num_node: - # Add edges to m nodes from the source. - origins = [source] * m - conn[origins, targets] = True - if not self.directed: - conn[targets, origins] = True - # Add one node to the list for each new edge just created. - repeated_nodes[size_repeated_nodes:size_repeated_nodes + m] = targets - size_repeated_nodes += m - # And the new node "source" has m edges to add to the list. - repeated_nodes[size_repeated_nodes:size_repeated_nodes + m] = source - size_repeated_nodes += m - # Pick which m to use next time (m1 or m2) - m = self.m1 if self.rng.random() < self.p else self.m2 - # Now choose m unique nodes from the existing nodes - # Pick uniformly from repeated_nodes (preferential attachment) - targets = list(self._connect(repeated_nodes[:size_repeated_nodes], m)) - source += 1 - - return conn - - -class PowerLaw(TwoEndConnector): - """Holme and Kim algorithm for growing graphs with powerlaw - degree distribution and approximate average clustering. - - Parameters - ---------- - m : int - the number of random edges to add for each new node - p : float, - Probability of adding a triangle after adding a random edge - seed : integer, random_state, or None (default) - Indicator of random number generation state. - - Notes - ----- - The average clustering has a hard time getting above a certain - cutoff that depends on :math:`m`. This cutoff is often quite low. The - transitivity (fraction of triangles to possible triangles) seems to - decrease with network size. - - It is essentially the Barabási–Albert (BA) growth model with an - extra step that each random edge is followed by a chance of - making an edge to one of its neighbors too (and thus a triangle). - - This algorithm improves on BA in the sense that it enables a - higher average clustering to be attained if desired. - - It seems possible to have a disconnected graph with this algorithm - since the initial :math:`m` nodes may not be all linked to a new node - on the first iteration like the BA model. - - Raises - ------ - ConnectorError - If :math:`m` does not satisfy :math:`1 <= m <= n` or :math:`p` does not - satisfy :math:`0 <= p <= 1`. - - References - ---------- - .. [1] P. Holme and B. J. Kim, - "Growing scale-free networks with tunable clustering", - Phys. Rev. E, 65, 026107, 2002. - """ - - def __init__(self, m: int, p: float, directed=False, seed=None, **kwargs): - super(PowerLaw, self).__init__(**kwargs) - self.m = m - self.p = p - if self.p > 1 or self.p < 0: - raise ConnectorError(f"p must be in [0,1], while p={self.p}") - self.directed = directed - self.seed = format_seed(seed) - self.rng = np.random.RandomState(self.seed) - rng = np.random if SUPPORT_NUMBA else self.rng - - def _random_subset(seq, m): - targets = set() - while len(targets) < m: - x = rng.choice(seq) - targets.add(x) - return targets - - self._connect = numba_jit(_random_subset) - - def __repr__(self): - return (f'{self.__class__.__name__}(m={self.m}, p={self.p}, directed={self.directed}, seed={self.seed})') - - def build_mat(self, isOptimized=True): - assert self.pre_num == self.post_num - # seed - self.rng = np.random.RandomState(self.seed) - numba_seed(self.seed) - num_node = self.pre_num - if self.m < 1 or num_node < self.m: - raise ConnectorError(f"Must have m>1 and m 1 else p.flatten() for p in pre_ids]) - size = np.prod(pre_size) - - for i in range(size): - pre_pos = np.asarray([p[i] for p in pre_ids]) - pres, posts = f(pre_pos, pre_size=pre_size, post_size=post_size, n_dim=n_dim) - connected_pres.extend(pres) - connected_posts.extend(posts) - return np.asarray(connected_pres), np.asarray(connected_posts) +# -*- coding: utf-8 -*- +from functools import partial +from typing import Optional + +from jax import vmap, jit, numpy as jnp +import numpy as np + +import brainpy.math as bm +from brainpy.errors import ConnectorError +from brainpy.tools import numba_seed, numba_jit, numba_range, format_seed +from brainpy._src.tools.package import SUPPORT_NUMBA +from brainpy._src.dependency_check import import_numba_else_None +from .base import * + +numba = import_numba_else_None() + +__all__ = [ + 'FixedProb', + 'FixedPreNum', + 'FixedPostNum', + 'FixedTotalNum', + 'GaussianProb', + 'ProbDist', + + 'SmallWorld', + 'ScaleFreeBA', + 'ScaleFreeBADual', + 'PowerLaw', +] + + +class FixedProb(TwoEndConnector): + """Connect the post-synaptic neurons with fixed probability. + + Parameters + ---------- + prob: float + The conn probability. + pre_ratio: float + The ratio of pre-synaptic neurons to connect. + include_self : bool + Whether create (i, i) conn? + allow_multi_conn: bool + Allow one pre-synaptic neuron connects to multiple post-synaptic neurons? + + .. versionadded:: 2.2.3.2 + + seed : optional, int + Seed the random generator. + """ + + def __init__(self, + prob, + pre_ratio=1., + include_self=True, + allow_multi_conn=False, + seed=None, + **kwargs): + super(FixedProb, self).__init__(**kwargs) + assert 0. <= prob <= 1. + assert 0. <= pre_ratio <= 1. + self.prob = prob + self.pre_ratio = pre_ratio + self.include_self = include_self + self.seed = format_seed(seed) + self.allow_multi_conn = allow_multi_conn + self._jaxrand = bm.random.default_rng(self.seed) + self._nprand = np.random.RandomState(self.seed) + + def __repr__(self): + return (f'{self.__class__.__name__}(prob={self.prob}, pre_ratio={self.pre_ratio}, ' + f'include_self={self.include_self}, allow_multi_conn={self.allow_multi_conn}, ' + f'seed={self.seed})') + + def _iii(self): + if (not self.include_self) and (self.pre_num != self.post_num): + raise ConnectorError(f'We found pre_num != post_num ({self.pre_num} != {self.post_num}). ' + f'But `include_self` is set to True.') + + if self.pre_ratio < 1.: + pre_num_to_select = int(self.pre_num * self.pre_ratio) + pre_ids = self._jaxrand.choice(self.pre_num, size=(pre_num_to_select,), replace=False) + else: + pre_num_to_select = self.pre_num + pre_ids = jnp.arange(self.pre_num) + + post_num_total = self.post_num + post_num_to_select = int(self.post_num * self.prob) + + if self.allow_multi_conn: + selected_post_ids = self._jaxrand.randint(0, post_num_total, (pre_num_to_select, post_num_to_select)) + + else: + if SUPPORT_NUMBA: + rng = np.random + numba_seed(self._nprand.randint(0, int(1e8))) + else: + rng = self._nprand + + @numba_jit # (parallel=True, nogil=True) + def single_conn(): + posts = np.zeros((pre_num_to_select, post_num_to_select), dtype=IDX_DTYPE) + for i in numba_range(pre_num_to_select): + posts[i] = rng.choice(post_num_total, post_num_to_select, replace=False) + return posts + + selected_post_ids = jnp.asarray(single_conn()) + return pre_num_to_select, post_num_to_select, bm.as_jax(selected_post_ids), bm.as_jax(pre_ids) + + def build_coo(self): + _, post_num_to_select, selected_post_ids, pre_ids = self._iii() + selected_post_ids = selected_post_ids.flatten() + selected_pre_ids = jnp.repeat(pre_ids, post_num_to_select) + if not self.include_self: + true_ids = selected_pre_ids != selected_post_ids + selected_pre_ids = selected_pre_ids[true_ids] + selected_post_ids = selected_post_ids[true_ids] + return selected_pre_ids.astype(get_idx_type()), selected_post_ids.astype(get_idx_type()) + + def build_csr(self): + pre_num_to_select, post_num_to_select, selected_post_ids, pre_ids = self._iii() + pre_nums = jnp.ones(pre_num_to_select) * post_num_to_select + if not self.include_self: + true_ids = selected_post_ids == jnp.reshape(pre_ids, (-1, 1)) + pre_nums -= jnp.sum(true_ids, axis=1) + selected_post_ids = selected_post_ids.flatten()[jnp.logical_not(true_ids).flatten()] + else: + selected_post_ids = selected_post_ids.flatten() + selected_pre_inptr = jnp.cumsum(jnp.concatenate([jnp.zeros(1), pre_nums])) + return selected_post_ids.astype(get_idx_type()), selected_pre_inptr.astype(get_idx_type()) + + def build_mat(self): + if self.pre_ratio < 1.: + pre_state = self._jaxrand.uniform(size=(self.pre_num, 1)) < self.pre_ratio + mat = (self._jaxrand.uniform(size=(self.pre_num, self.post_num)) < self.prob) * pre_state + else: + mat = (self._jaxrand.uniform(size=(self.pre_num, self.post_num)) < self.prob) + mat = bm.asarray(mat) + if not self.include_self: + bm.fill_diagonal(mat, False) + return mat.astype(MAT_DTYPE) + + +class FixedTotalNum(TwoEndConnector): + """Connect the synaptic neurons with fixed total number. + + Parameters + ---------- + num : float,int + The conn total number. + allow_multi_conn : bool, optional + Whether allow one pre-synaptic neuron connects to multiple post-synaptic neurons. + seed: int, optional + The random number seed. + """ + + def __init__(self, + num, + allow_multi_conn=False, + seed=None, **kwargs): + super().__init__(**kwargs) + if isinstance(num, int): + assert num >= 0, '"num" must be a non-negative integer.' + elif isinstance(num, float): + assert 0. <= num <= 1., '"num" must be in [0., 1.).' + else: + raise ConnectorError(f'Unknown type: {type(num)}') + self.num = num + self.seed = format_seed(seed) + self.allow_multi_conn = allow_multi_conn + self.rng = bm.random.RandomState(self.seed) + + def build_coo(self): + mat_element_num = self.pre_num * self.post_num + if self.num > mat_element_num: + raise ConnectorError(f'"num" must be smaller than "all2all num", ' + f'but got {self.num} > {mat_element_num}') + if self.allow_multi_conn: + selected_pre_ids = self.rng.randint(0, self.pre_num, (self.num,)) + selected_post_ids = self.rng.randint(0, self.post_num, (self.num,)) + else: + index = self.rng.choice(mat_element_num, size=(self.num,), replace=False) + selected_pre_ids = index // self.post_num + selected_post_ids = index % self.post_num + return selected_pre_ids.astype(get_idx_type()), selected_post_ids.astype(get_idx_type()) + + def __repr__(self): + return f'{self.__class__.__name__}(num={self.num}, seed={self.seed})' + + +class FixedNum(TwoEndConnector): + def __init__(self, + num, + include_self=True, + allow_multi_conn=False, + seed=None, + **kwargs): + super(FixedNum, self).__init__(**kwargs) + if isinstance(num, int): + assert num >= 0, '"num" must be a non-negative integer.' + elif isinstance(num, float): + assert 0. <= num <= 1., '"num" must be in [0., 1.).' + else: + raise ConnectorError(f'Unknown type: {type(num)}') + self.num = num + self.seed = format_seed(seed) + self.include_self = include_self + self.allow_multi_conn = allow_multi_conn + self.rng = bm.random.RandomState(self.seed) if allow_multi_conn else np.random.RandomState(self.seed) + + def __repr__(self): + return f'{self.__class__.__name__}(num={self.num}, include_self={self.include_self}, seed={self.seed})' + + +class FixedPreNum(FixedNum): + """Connect a fixed number pf pre-synaptic neurons for each post-synaptic neuron. + + Parameters + ---------- + num : float, int + The conn probability (if "num" is float) or the fixed number of + connectivity (if "num" is int). + include_self : bool + Whether create (i, i) conn ? + seed : None, int + Seed the random generator. + allow_multi_conn: bool + Allow one pre-synaptic neuron connects to multiple post-synaptic neurons? + + .. versionadded:: 2.2.3.2 + + """ + + def build_coo(self): + if isinstance(self.num, int) and self.num > self.pre_num: + raise ConnectorError(f'"num" must be smaller than "pre_num", ' + f'but got {self.num} > {self.pre_num}') + if (not self.include_self) and (self.pre_num != self.post_num): + raise ConnectorError(f'We found pre_num != post_num ({self.pre_num} != {self.post_num}). ' + f'But `include_self` is set to True.') + pre_num_to_select = int(self.pre_num * self.num) if isinstance(self.num, float) else self.num + pre_num_total = self.pre_num + post_num_total = self.post_num + + if self.allow_multi_conn: + selected_pre_ids = self.rng.randint(0, pre_num_total, (post_num_total, pre_num_to_select,)) + + else: + if SUPPORT_NUMBA: + rng = np.random + numba_seed(self.rng.randint(0, int(1e8))) + else: + rng = self.rng + + @numba_jit # (parallel=True, nogil=True) + def single_conn(): + posts = np.zeros((post_num_total, pre_num_to_select), dtype=IDX_DTYPE) + for i in numba_range(post_num_total): + posts[i] = rng.choice(pre_num_total, pre_num_to_select, replace=False) + return posts + + selected_pre_ids = jnp.asarray(single_conn()) + + post_nums = jnp.ones((post_num_total,), dtype=get_idx_type()) * pre_num_to_select + if not self.include_self: + true_ids = selected_pre_ids == jnp.reshape(jnp.arange(pre_num_total), (-1, 1)) + post_nums -= jnp.sum(true_ids, axis=1) + selected_pre_ids = selected_pre_ids.flatten()[jnp.logical_not(true_ids).flatten()] + else: + selected_pre_ids = selected_pre_ids.flatten() + selected_post_ids = jnp.repeat(jnp.arange(post_num_total), post_nums) + return selected_pre_ids.astype(get_idx_type()), selected_post_ids.astype(get_idx_type()) + + +class FixedPostNum(FixedNum): + """Connect the fixed number of post-synaptic neurons for each pre-synaptic neuron. + + Parameters + ---------- + num : float, int + The conn probability (if "num" is float) or the fixed number of + connectivity (if "num" is int). + include_self : bool + Whether create (i, i) conn ? + seed : None, int + Seed the random generator. + allow_multi_conn: bool + Allow one pre-synaptic neuron connects to multiple post-synaptic neurons? + + .. versionadded:: 2.2.3.2 + + """ + + def _ii(self): + if isinstance(self.num, int) and self.num > self.post_num: + raise ConnectorError(f'"num" must be smaller than "post_num", ' + f'but got {self.num} > {self.post_num}') + if (not self.include_self) and (self.pre_num != self.post_num): + raise ConnectorError(f'We found pre_num != post_num ({self.pre_num} != {self.post_num}). ' + f'But `include_self` is set to True.') + post_num_to_select = int(self.post_num * self.num) if isinstance(self.num, float) else self.num + pre_num_to_select = self.pre_num + pre_ids = jnp.arange(self.pre_num) + post_num_total = self.post_num + + if self.allow_multi_conn: + selected_post_ids = self.rng.randint(0, post_num_total, (pre_num_to_select, post_num_to_select,)) + + else: + if SUPPORT_NUMBA: + rng = np.random + numba_seed(self.rng.randint(0, int(1e8))) + else: + rng = self.rng + + @numba_jit # (parallel=True, nogil=True) + def single_conn(): + posts = np.zeros((pre_num_to_select, post_num_to_select), dtype=IDX_DTYPE) + for i in numba_range(pre_num_to_select): + posts[i] = rng.choice(post_num_total, post_num_to_select, replace=False) + return posts + + selected_post_ids = jnp.asarray(single_conn()) + return pre_num_to_select, post_num_to_select, bm.as_jax(selected_post_ids), bm.as_jax(pre_ids) + + def build_coo(self): + _, post_num_to_select, selected_post_ids, pre_ids = self._ii() + selected_post_ids = selected_post_ids.flatten() + selected_pre_ids = jnp.repeat(pre_ids, post_num_to_select) + if not self.include_self: + true_ids = selected_pre_ids != selected_post_ids + selected_pre_ids = selected_pre_ids[true_ids] + selected_post_ids = selected_post_ids[true_ids] + return selected_pre_ids.astype(get_idx_type()), selected_post_ids.astype(get_idx_type()) + + def build_csr(self): + pre_num_to_select, post_num_to_select, selected_post_ids, pre_ids = self._ii() + pre_nums = jnp.ones(pre_num_to_select) * post_num_to_select + if not self.include_self: + true_ids = selected_post_ids == jnp.reshape(pre_ids, (-1, 1)) + pre_nums -= jnp.sum(true_ids, axis=1) + selected_post_ids = selected_post_ids.flatten()[jnp.logical_not(true_ids).flatten()] + else: + selected_post_ids = selected_post_ids.flatten() + selected_pre_inptr = jnp.cumsum(jnp.concatenate([jnp.zeros(1), pre_nums])) + return selected_post_ids.astype(get_idx_type()), selected_pre_inptr.astype(get_idx_type()) + + +@jit +@partial(vmap, in_axes=(0, None, None)) +def gaussian_prob_dist_cal1(i_value, post_values, sigma): + dists = jnp.abs(i_value - post_values) + exp_dists = jnp.exp(-(jnp.sqrt(jnp.sum(dists ** 2, axis=0)) / sigma) ** 2 / 2) + return bm.asarray(exp_dists) + + +@jit +@partial(vmap, in_axes=(0, None, None, None)) +def gaussian_prob_dist_cal2(i_value, post_values, value_sizes, sigma): + dists = jnp.abs(i_value - post_values) + dists = jnp.where(dists > (value_sizes / 2), value_sizes - dists, dists) + exp_dists = jnp.exp(-(jnp.sqrt(jnp.sum(dists ** 2, axis=0)) / sigma) ** 2 / 2) + return bm.asarray(exp_dists) + + +class GaussianProb(OneEndConnector): + r"""Builds a Gaussian connectivity pattern within a population of neurons, + where the connection probability decay according to the gaussian function. + + Specifically, for any pair of neurons :math:`(i, j)`, + + .. math:: + + p(i, j)=\exp(-\frac{\sum_{k=1}^n |v_k^i - v_k^j|^2 }{2\sigma^2}) + + where :math:`v_k^i` is the :math:`i`-th neuron's encoded value at dimension :math:`k`. + + Parameters + ---------- + sigma : float + Width of the Gaussian function. + encoding_values : optional, list, tuple, int, float + The value ranges to encode for neurons at each axis. + + - If `values` is not provided, the neuron only encodes each positional + information, i.e., :math:`(i, j, k, ...)`, where :math:`i, j, k` is + the index in the high-dimensional space. + - If `values` is a single tuple/list of int/float, neurons at each dimension + will encode the same range of values. For example, ``values=(0, np.pi)``, + neurons at each dimension will encode a continuous value space ``[0, np.pi]``. + - If `values` is a tuple/list of list/tuple, it means the value space will be + different for each dimension. For example, ``values=((-np.pi, np.pi), (10, 20), (0, 2 * np.pi))``. + + periodic_boundary : bool + Whether the neuron encode the value space with the periodic boundary. + normalize : bool + Whether normalize the connection probability . + include_self : bool + Whether create the connection at the same position. + seed : int + The random seed. + """ + + def __init__( + self, + sigma: float, + encoding_values: Optional[np.ndarray] = None, + normalize: bool = True, + include_self: bool = True, + periodic_boundary: bool = False, + seed: int = None, + **kwargs + ): + super(GaussianProb, self).__init__(**kwargs) + self.sigma = sigma + self.encoding_values = encoding_values + self.normalize = normalize + self.include_self = include_self + self.periodic_boundary = periodic_boundary + self.seed = format_seed(seed) + self.rng = np.random.RandomState(self.seed) + + def __repr__(self): + return (f'{self.__class__.__name__}(sigma={self.sigma}, ' + f'normalize={self.normalize}, ' + f'periodic_boundary={self.periodic_boundary}, ' + f'include_self={self.include_self}, ' + f'seed={self.seed})') + + def build_mat(self, isOptimized=True): + self.rng = np.random.RandomState(self.seed) + # value range to encode + if self.encoding_values is None: + value_ranges = tuple([(0, s) for s in self.pre_size]) + elif isinstance(self.encoding_values, (tuple, list)): + if len(self.encoding_values) == 0: + raise ConnectorError(f'encoding_values has a length of 0.') + elif isinstance(self.encoding_values[0], (int, float)): + assert len(self.encoding_values) == 2 + assert self.encoding_values[0] < self.encoding_values[1] + value_ranges = tuple([self.encoding_values for _ in self.pre_size]) + elif isinstance(self.encoding_values[0], (tuple, list)): + if len(self.encoding_values) != len(self.pre_size): + raise ConnectorError(f'The network size has {len(self.pre_size)} dimensions, while ' + f'the encoded values provided only has {len(self.encoding_values)}-D. ' + f'Error in {str(self)}.') + for v in self.encoding_values: + assert isinstance(v[0], (int, float)) + assert len(v) == 2 + value_ranges = tuple(self.encoding_values) + else: + raise ConnectorError(f'Unsupported encoding values: {self.encoding_values}') + else: + raise ConnectorError(f'Unsupported encoding values: {self.encoding_values}') + + # values + values = [np.linspace(vs[0], vs[1], n + 1)[:n] for vs, n in zip(value_ranges, self.pre_size)] + # post_values = np.stack([v.flatten() for v in np.meshgrid(*values, indexing='ij')]) + post_values = np.stack([v.flatten() for v in np.meshgrid(*values)]) + value_sizes = np.array([v[1] - v[0] for v in value_ranges]) + if value_sizes.ndim < post_values.ndim: + value_sizes = np.expand_dims(value_sizes, axis=tuple([i + 1 for i in range(post_values.ndim - 1)])) + + # probability of connections + if isOptimized: + i_value_list = np.zeros(shape=(self.pre_num, len(self.pre_size), 1)) + for i in range(self.pre_num): + list_index = i + # values for node i + i_coordinate = tuple() + for s in self.pre_size[:-1]: + i, pos = divmod(i, s) + i_coordinate += (pos,) + i_coordinate += (i,) + i_value = np.array([values[i][c] for i, c in enumerate(i_coordinate)]) + if i_value.ndim < post_values.ndim: + i_value = np.expand_dims(i_value, axis=tuple([i + 1 for i in range(post_values.ndim - 1)])) + i_value_list[list_index] = i_value + + if self.periodic_boundary: + prob_mat = gaussian_prob_dist_cal2(i_value_list, post_values, value_sizes, self.sigma) + else: + prob_mat = gaussian_prob_dist_cal1(i_value_list, post_values, self.sigma) + else: + prob_mat = [] + for i in range(self.pre_num): + # values for node i + i_coordinate = tuple() + for s in self.pre_size[:-1]: + i, pos = divmod(i, s) + i_coordinate += (pos,) + i_coordinate += (i,) + i_value = np.array([values[i][c] for i, c in enumerate(i_coordinate)]) + if i_value.ndim < post_values.ndim: + i_value = np.expand_dims(i_value, axis=tuple([i + 1 for i in range(post_values.ndim - 1)])) + # distances + dists = np.abs(i_value - post_values) + if self.periodic_boundary: + dists = np.where(dists > value_sizes / 2, value_sizes - dists, dists) + exp_dists = np.exp(-(np.linalg.norm(dists, axis=0) / self.sigma) ** 2 / 2) + prob_mat.append(exp_dists) + prob_mat = np.stack(prob_mat) + + if self.normalize: + prob_mat /= prob_mat.max() + + # connectivity + conn_mat = np.asarray(prob_mat) >= self.rng.random(prob_mat.shape) + if not self.include_self: + np.fill_diagonal(conn_mat, False) + return conn_mat + + +class SmallWorld(TwoEndConnector): + """Build a Watts–Strogatz small-world graph. + + Parameters + ---------- + num_neighbor : int + Each node is joined with its `k` nearest neighbors in a ring + topology. + prob : float + The probability of rewiring each edge + directed : bool + Whether the graph is a directed graph. + include_self : bool + Whether include the node self. + + Notes + ----- + First create a ring over :math:`num\_node` nodes [1]_. Then each node in the ring is + joined to its :math:`num\_neighbor` nearest neighbors (or :math:`num\_neighbor - 1` neighbors + if :math:`num\_neighbor` is odd). Then shortcuts are created by replacing some edges as + follows: for each edge :math:`(u, v)` in the underlying ":math:`num\_node`-ring with + :math:`num\_neighbor` nearest neighbors" with probability :math:`prob` replace it with a new + edge :math:`(u, w)` with uniformly random choice of existing node :math:`w`. + + References + ---------- + .. [1] Duncan J. Watts and Steven H. Strogatz, + Collective dynamics of small-world networks, + Nature, 393, pp. 440--442, 1998. + """ + + def __init__( + self, + num_neighbor, + prob, + directed=False, + include_self=False, + seed=None, + **kwargs + ): + super(SmallWorld, self).__init__(**kwargs) + self.prob = prob + self.directed = directed + self.num_neighbor = num_neighbor + self.include_self = include_self + + self.seed = format_seed(seed) + self.rng = np.random.RandomState(seed=self.seed) + rng = np.random if SUPPORT_NUMBA else self.rng + + def _smallworld_rewire(i, all_j): + if rng.random(1) < prob: + non_connected = np.where(np.logical_not(all_j))[0] + if len(non_connected) <= 1: + return -1 + # Enforce no self-loops or multiple edges + w = rng.choice(non_connected) + while (not include_self) and w == i: + # non_connected.remove(w) + w = rng.choice(non_connected) + return w + else: + return -1 + + self._connect = numba_jit(_smallworld_rewire) + + def __repr__(self): + return (f'{self.__class__.__name__}(prob={self.prob}, ' + f'directed={self.directed}, ' + f'num_neighbor={self.num_neighbor}, ' + f'include_self={self.include_self}, ' + f'seed={self.seed})') + + def build_conn(self): + assert self.pre_size == self.post_size + + # seed + self.seed = self.rng.randint(1, int(1e7)) + numba_seed(self.seed) + + if isinstance(self.pre_size, int) or (isinstance(self.pre_size, (tuple, list)) and len(self.pre_size) == 1): + num_node = self.pre_num + + if self.num_neighbor > num_node: + raise ConnectorError("num_neighbor > num_node, choose smaller num_neighbor or larger num_node") + # If k == n, the graph is complete not Watts-Strogatz + if self.num_neighbor == num_node: + conn = np.ones((num_node, num_node), dtype=MAT_DTYPE) + else: + conn = np.zeros((num_node, num_node), dtype=MAT_DTYPE) + nodes = np.array(list(range(num_node))) # nodes are labeled 0 to n-1 + # connect each node to k/2 neighbors + for j in range(1, self.num_neighbor // 2 + 1): + targets = np.concatenate([nodes[j:], nodes[0:j]]) # first j nodes are now last in list + conn[nodes, targets] = True + conn[targets, nodes] = True + + # rewire edges from each node + # loop over all nodes in order (label) and neighbors in order (distance) + # no self loops or multiple edges allowed + for j in range(1, self.num_neighbor // 2 + 1): # outer loop is neighbors + targets = np.concatenate([nodes[j:], nodes[0:j]]) # first j nodes are now last in list + if self.directed: + # inner loop in node order + for u, v in zip(nodes, targets): + w = self._connect(prob=self.prob, i=u, all_j=conn[u]) + if w != -1: + conn[u, v] = False + conn[u, w] = True + w = self._connect(prob=self.prob, i=u, all_j=conn[:, u]) + if w != -1: + conn[v, u] = False + conn[w, u] = True + else: + # inner loop in node order + for u, v in zip(nodes, targets): + w = self._connect(i=u, all_j=conn[u]) + if w != -1: + conn[u, v] = False + conn[v, u] = False + conn[u, w] = True + conn[w, u] = True + # conn = np.asarray(conn, dtype=MAT_DTYPE) + else: + raise ConnectorError('Currently only support 1D ring connection.') + + return 'mat', conn + + +# def _random_subset(seq, m, rng): +# """Return m unique elements from seq. +# +# This differs from random.sample which can return repeated +# elements if seq holds repeated elements. +# +# Note: rng is a random.Random or numpy.random.RandomState instance. +# """ +# targets = set() +# while len(targets) < m: +# x = rng.choice(seq) +# targets.add(x) +# return targets + + +class ScaleFreeBA(TwoEndConnector): + """Build a random graph according to the Barabási–Albert preferential + attachment model. + + A graph of :math:`num\_node` nodes is grown by attaching new nodes each with + :math:`m` edges that are preferentially attached to existing nodes + with high degree. + + Parameters + ---------- + m : int + Number of edges to attach from a new node to existing nodes + seed : integer, random_state, or None (default) + Indicator of random number generation state. + + Raises + ------ + ConnectorError + If `m` does not satisfy ``1 <= m < n``. + + References + ---------- + .. [1] A. L. Barabási and R. Albert "Emergence of scaling in + random networks", Science 286, pp 509-512, 1999. + """ + + def __init__(self, m, directed=False, seed=None, **kwargs): + super(ScaleFreeBA, self).__init__(**kwargs) + self.m = m + self.directed = directed + self.seed = format_seed(seed) + self.rng = np.random.RandomState(self.seed) + rng = np.random if SUPPORT_NUMBA else self.rng + + def _random_subset(seq, m): + targets = set() + while len(targets) < m: + x = rng.choice(seq) + targets.add(x) + return targets + + self._connect = numba_jit(_random_subset) + + def __repr__(self): + return (f'{self.__class__.__name__}(m={self.m}, ' + f'directed={self.directed}, ' + f'seed={self.seed})') + + def build_mat(self, isOptimized=True): + assert self.pre_num == self.post_num + + # seed + self.rng = np.random.RandomState(self.seed) + numba_seed(self.seed) + + num_node = self.pre_num + if self.m < 1 or self.m >= num_node: + raise ConnectorError(f"Barabási–Albert network must have m >= 1 and " + f"m < n, while m = {self.m} and n = {num_node}") + + # Add m initial nodes (m0 in barabasi-speak) + conn = np.zeros((num_node, num_node), dtype=MAT_DTYPE) + # Target nodes for new edges + targets = list(range(self.m)) + # List of existing nodes, with nodes repeated once for each adjacent edge + + if not isOptimized: + repeated_nodes = [] + # Start adding the other n-m nodes. The first node is m. + source = self.m + while source < num_node: + # Add edges to m nodes from the source. + origins = [source] * self.m + conn[origins, targets] = True + if not self.directed: + conn[targets, origins] = True + # Add one node to the list for each new edge just created. + repeated_nodes.extend(targets) + # And the new node "source" has m edges to add to the list. + repeated_nodes.extend([source] * self.m) + # Now choose m unique nodes from the existing nodes + # Pick uniformly from repeated_nodes (preferential attachment) + targets = list(self._connect(np.asarray(repeated_nodes), self.m)) + source += 1 + return conn + + # List of existing nodes, with nodes repeated once for each adjacent edge + # Preallocate repeated_nodes as a numpy array + repeated_nodes = np.empty(2 * num_node * self.m, dtype=int) + size_repeated_nodes = 0 + # Start adding the other n-m nodes. The first node is m. + source = self.m + while source < num_node: + # Add edges to m nodes from the source. + origins = [source] * self.m + conn[origins, targets] = True + if not self.directed: + conn[targets, origins] = True + # Add one node to the list for each new edge just created. + repeated_nodes[size_repeated_nodes:size_repeated_nodes + self.m] = targets + size_repeated_nodes += self.m + # And the new node "source" has m edges to add to the list. + repeated_nodes[size_repeated_nodes:size_repeated_nodes + self.m] = source + size_repeated_nodes += self.m + # Now choose m unique nodes from the existing nodes + # Pick uniformly from repeated_nodes (preferential attachment) + targets = list(self._connect(repeated_nodes[:size_repeated_nodes], self.m)) + source += 1 + + return conn + + +class ScaleFreeBADual(TwoEndConnector): + r"""Build a random graph according to the dual Barabási–Albert preferential + attachment model. + + A graph of :math::`num\_node` nodes is grown by attaching new nodes each with either $m_1$ + edges (with probability :math:`p`) or :math:`m_2` edges (with probability :math:`1-p`) that + are preferentially attached to existing nodes with high degree. + + Parameters + ---------- + m1 : int + Number of edges to attach from a new node to existing nodes with probability :math:`p` + m2 : int + Number of edges to attach from a new node to existing nodes with probability :math:`1-p` + p : float + The probability of attaching :math:`m\_1` edges (as opposed to :math:`m\_2` edges) + seed : integer, random_state, or None (default) + Indicator of random number generation state. + + Raises + ------ + ConnectorError + If `m1` and `m2` do not satisfy ``1 <= m1,m2 < n`` or `p` does not satisfy ``0 <= p <= 1``. + + References + ---------- + .. [1] N. Moshiri "The dual-Barabasi-Albert model", arXiv:1810.10538. + """ + + def __init__(self, m1, m2, p, directed=False, seed=None, **kwargs): + super(ScaleFreeBADual, self).__init__(**kwargs) + self.m1 = m1 + self.m2 = m2 + self.p = p + self.directed = directed + self.seed = format_seed(seed) + self.rng = np.random.RandomState(self.seed) + rng = np.random if SUPPORT_NUMBA else self.rng + + def _random_subset(seq, m): + targets = set() + while len(targets) < m: + x = rng.choice(seq) + targets.add(x) + return targets + + self._connect = numba_jit(_random_subset) + + def __repr__(self): + return (f'{self.__class__.__name__}(m1={self.m1}, m2={self.m2}, ' + f'p={self.p}, directed={self.directed}, seed={self.seed})') + + def build_mat(self, isOptimized=True): + assert self.pre_num == self.post_num + # seed + self.rng = np.random.RandomState(self.seed) + numba_seed(self.seed) + + num_node = self.pre_num + if self.m1 < 1 or self.m1 >= num_node: + raise ConnectorError(f"Dual Barabási–Albert network must have m1 >= 1 and m1 < num_node, " + f"while m1 = {self.m1} and num_node = {num_node}.") + if self.m2 < 1 or self.m2 >= num_node: + raise ConnectorError(f"Dual Barabási–Albert network must have m2 >= 1 and m2 < num_node, " + f"while m2 = {self.m2} and num_node = {num_node}.") + if self.p < 0 or self.p > 1: + raise ConnectorError(f"Dual Barabási–Albert network must have 0 <= p <= 1, while p = {self.p}") + + # Add max(m1,m2) initial nodes (m0 in barabasi-speak) + conn = np.zeros((num_node, num_node), dtype=MAT_DTYPE) + + if not isOptimized: + # List of existing nodes, with nodes repeated once for each adjacent edge + repeated_nodes = [] + # Start adding the remaining nodes. + source = max(self.m1, self.m2) + # Pick which m to use first time (m1 or m2) + m = self.m1 if self.rng.random() < self.p else self.m2 + # Target nodes for new edges + targets = list(range(m)) + while source < num_node: + # Add edges to m nodes from the source. + origins = [source] * m + conn[origins, targets] = True + if not self.directed: + conn[targets, origins] = True + # Add one node to the list for each new edge just created. + repeated_nodes.extend(targets) + # And the new node "source" has m edges to add to the list. + repeated_nodes.extend([source] * m) + # Pick which m to use next time (m1 or m2) + m = self.m1 if self.rng.random() < self.p else self.m2 + # Now choose m unique nodes from the existing nodes + # Pick uniformly from repeated_nodes (preferential attachment) + targets = list(self._connect(np.asarray(repeated_nodes), m)) + source += 1 + return conn + + # List of existing nodes, with nodes repeated once for each adjacent edge + # Preallocate repeated_nodes as a numpy array + repeated_nodes = np.empty(2 * num_node * max(self.m1, self.m2), dtype=int) + size_repeated_nodes = 0 + # Start adding the remaining nodes. + source = max(self.m1, self.m2) + # Pick which m to use first time (m1 or m2) + m = self.m1 if self.rng.random() < self.p else self.m2 + # Target nodes for new edges + targets = list(range(m)) + while source < num_node: + # Add edges to m nodes from the source. + origins = [source] * m + conn[origins, targets] = True + if not self.directed: + conn[targets, origins] = True + # Add one node to the list for each new edge just created. + repeated_nodes[size_repeated_nodes:size_repeated_nodes + m] = targets + size_repeated_nodes += m + # And the new node "source" has m edges to add to the list. + repeated_nodes[size_repeated_nodes:size_repeated_nodes + m] = source + size_repeated_nodes += m + # Pick which m to use next time (m1 or m2) + m = self.m1 if self.rng.random() < self.p else self.m2 + # Now choose m unique nodes from the existing nodes + # Pick uniformly from repeated_nodes (preferential attachment) + targets = list(self._connect(repeated_nodes[:size_repeated_nodes], m)) + source += 1 + + return conn + + +class PowerLaw(TwoEndConnector): + """Holme and Kim algorithm for growing graphs with powerlaw + degree distribution and approximate average clustering. + + Parameters + ---------- + m : int + the number of random edges to add for each new node + p : float, + Probability of adding a triangle after adding a random edge + seed : integer, random_state, or None (default) + Indicator of random number generation state. + + Notes + ----- + The average clustering has a hard time getting above a certain + cutoff that depends on :math:`m`. This cutoff is often quite low. The + transitivity (fraction of triangles to possible triangles) seems to + decrease with network size. + + It is essentially the Barabási–Albert (BA) growth model with an + extra step that each random edge is followed by a chance of + making an edge to one of its neighbors too (and thus a triangle). + + This algorithm improves on BA in the sense that it enables a + higher average clustering to be attained if desired. + + It seems possible to have a disconnected graph with this algorithm + since the initial :math:`m` nodes may not be all linked to a new node + on the first iteration like the BA model. + + Raises + ------ + ConnectorError + If :math:`m` does not satisfy :math:`1 <= m <= n` or :math:`p` does not + satisfy :math:`0 <= p <= 1`. + + References + ---------- + .. [1] P. Holme and B. J. Kim, + "Growing scale-free networks with tunable clustering", + Phys. Rev. E, 65, 026107, 2002. + """ + + def __init__(self, m: int, p: float, directed=False, seed=None, **kwargs): + super(PowerLaw, self).__init__(**kwargs) + self.m = m + self.p = p + if self.p > 1 or self.p < 0: + raise ConnectorError(f"p must be in [0,1], while p={self.p}") + self.directed = directed + self.seed = format_seed(seed) + self.rng = np.random.RandomState(self.seed) + rng = np.random if SUPPORT_NUMBA else self.rng + + def _random_subset(seq, m): + targets = set() + while len(targets) < m: + x = rng.choice(seq) + targets.add(x) + return targets + + self._connect = numba_jit(_random_subset) + + def __repr__(self): + return (f'{self.__class__.__name__}(m={self.m}, p={self.p}, directed={self.directed}, seed={self.seed})') + + def build_mat(self, isOptimized=True): + assert self.pre_num == self.post_num + # seed + self.rng = np.random.RandomState(self.seed) + numba_seed(self.seed) + num_node = self.pre_num + if self.m < 1 or num_node < self.m: + raise ConnectorError(f"Must have m>1 and m 1 else p.flatten() for p in pre_ids]) + size = np.prod(pre_size) + + for i in range(size): + pre_pos = np.asarray([p[i] for p in pre_ids]) + pres, posts = f(pre_pos, pre_size=pre_size, post_size=post_size, n_dim=n_dim) + connected_pres.extend(pres) + connected_posts.extend(posts) + return np.asarray(connected_pres), np.asarray(connected_posts) diff --git a/brainpy/_src/dependency_check.py b/brainpy/_src/dependency_check.py index f3651b109..715e78c9b 100644 --- a/brainpy/_src/dependency_check.py +++ b/brainpy/_src/dependency_check.py @@ -1,87 +1,131 @@ -import os -import sys -from jax.lib import xla_client - -__all__ = [ - 'import_taichi', - 'import_brainpylib_cpu_ops', - 'import_brainpylib_gpu_ops', -] - -_minimal_brainpylib_version = '0.1.10' -_minimal_taichi_version = (1, 7, 0) - -taichi = None -brainpylib_cpu_ops = None -brainpylib_gpu_ops = None - -taichi_install_info = (f'We need taichi=={_minimal_taichi_version}. ' - f'Currently you can install taichi=={_minimal_taichi_version} through:\n\n' - '> pip install taichi==1.7.0') -os.environ["TI_LOG_LEVEL"] = "error" - - -def import_taichi(): - global taichi - if taichi is None: - with open(os.devnull, 'w') as devnull: - old_stdout = sys.stdout - sys.stdout = devnull - try: - import taichi as taichi # noqa - except ModuleNotFoundError: - raise ModuleNotFoundError(taichi_install_info) - finally: - sys.stdout = old_stdout - - if taichi.__version__ != _minimal_taichi_version: - raise RuntimeError(taichi_install_info) - return taichi - - -def is_brainpylib_gpu_installed(): - return False if brainpylib_gpu_ops is None else True - - -def import_brainpylib_cpu_ops(): - global brainpylib_cpu_ops - if brainpylib_cpu_ops is None: - try: - from brainpylib import cpu_ops as brainpylib_cpu_ops - - for _name, _value in brainpylib_cpu_ops.registrations().items(): - xla_client.register_custom_call_target(_name, _value, platform="cpu") - - import brainpylib - if brainpylib.__version__ < _minimal_brainpylib_version: - raise SystemError(f'This version of brainpy needs brainpylib >= {_minimal_brainpylib_version}.') - if hasattr(brainpylib, 'check_brainpy_version'): - brainpylib.check_brainpy_version() - - except ImportError: - raise ImportError('Please install brainpylib. \n' - 'See https://brainpy.readthedocs.io for installation instructions.') - - return brainpylib_cpu_ops - - -def import_brainpylib_gpu_ops(): - global brainpylib_gpu_ops - if brainpylib_gpu_ops is None: - try: - from brainpylib import gpu_ops as brainpylib_gpu_ops - - for _name, _value in brainpylib_gpu_ops.registrations().items(): - xla_client.register_custom_call_target(_name, _value, platform="gpu") - - import brainpylib - if brainpylib.__version__ < _minimal_brainpylib_version: - raise SystemError(f'This version of brainpy needs brainpylib >= {_minimal_brainpylib_version}.') - if hasattr(brainpylib, 'check_brainpy_version'): - brainpylib.check_brainpy_version() - - except ImportError: - raise ImportError('Please install GPU version of brainpylib. \n' - 'See https://brainpy.readthedocs.io for installation instructions.') - - return brainpylib_gpu_ops +import os +import sys +from jax.lib import xla_client + +__all__ = [ + 'import_taichi', + 'import_taichi_else_None', + 'import_numba', + 'import_numba_else_None', + 'import_brainpylib_cpu_ops', + 'import_brainpylib_gpu_ops', +] + +_minimal_brainpylib_version = '0.1.10' +_minimal_taichi_version = (1, 7, 0) + +taichi = None +numba = None +brainpylib_cpu_ops = None +brainpylib_gpu_ops = None + +taichi_install_info = (f'We need taichi=={_minimal_taichi_version}. ' + f'Currently you can install taichi=={_minimal_taichi_version} through:\n\n' + '> pip install taichi==1.7.0') +os.environ["TI_LOG_LEVEL"] = "error" + + +def import_taichi(): + global taichi + if taichi is None: + with open(os.devnull, 'w') as devnull: + old_stdout = sys.stdout + sys.stdout = devnull + try: + import taichi as taichi # noqa + except ModuleNotFoundError: + raise ModuleNotFoundError(taichi_install_info) + finally: + sys.stdout = old_stdout + + if taichi.__version__ != _minimal_taichi_version: + raise RuntimeError(taichi_install_info) + return taichi + + +def import_taichi_else_None(): + global taichi + if taichi is None: + with open(os.devnull, 'w') as devnull: + old_stdout = sys.stdout + sys.stdout = devnull + try: + import taichi as taichi # noqa + except: + return None + finally: + sys.stdout = old_stdout + + if taichi.__version__ != _minimal_taichi_version: + raise RuntimeError(taichi_install_info) + return taichi + + +def import_numba(): + global numba + if numba is None: + try: + import numba as numba + except ModuleNotFoundError: + raise ModuleNotFoundError('We need numba. Please install numba by pip . \n' + '> pip install numba' + ) + return numba + + +def import_numba_else_None(): + global numba + if numba is None: + try: + import numba as numba + except: + return None + return numba + + +def is_brainpylib_gpu_installed(): + return False if brainpylib_gpu_ops is None else True + + +def import_brainpylib_cpu_ops(): + global brainpylib_cpu_ops + if brainpylib_cpu_ops is None: + try: + from brainpylib import cpu_ops as brainpylib_cpu_ops + + for _name, _value in brainpylib_cpu_ops.registrations().items(): + xla_client.register_custom_call_target(_name, _value, platform="cpu") + + import brainpylib + if brainpylib.__version__ < _minimal_brainpylib_version: + raise SystemError(f'This version of brainpy needs brainpylib >= {_minimal_brainpylib_version}.') + if hasattr(brainpylib, 'check_brainpy_version'): + brainpylib.check_brainpy_version() + + except ImportError: + raise ImportError('Please install brainpylib. \n' + 'See https://brainpy.readthedocs.io for installation instructions.') + + return brainpylib_cpu_ops + + +def import_brainpylib_gpu_ops(): + global brainpylib_gpu_ops + if brainpylib_gpu_ops is None: + try: + from brainpylib import gpu_ops as brainpylib_gpu_ops + + for _name, _value in brainpylib_gpu_ops.registrations().items(): + xla_client.register_custom_call_target(_name, _value, platform="gpu") + + import brainpylib + if brainpylib.__version__ < _minimal_brainpylib_version: + raise SystemError(f'This version of brainpy needs brainpylib >= {_minimal_brainpylib_version}.') + if hasattr(brainpylib, 'check_brainpy_version'): + brainpylib.check_brainpy_version() + + except ImportError: + raise ImportError('Please install GPU version of brainpylib. \n' + 'See https://brainpy.readthedocs.io for installation instructions.') + + return brainpylib_gpu_ops diff --git a/brainpy/_src/dnn/linear.py b/brainpy/_src/dnn/linear.py index 6a37bdcba..c23b6e21f 100644 --- a/brainpy/_src/dnn/linear.py +++ b/brainpy/_src/dnn/linear.py @@ -6,7 +6,6 @@ import jax import jax.numpy as jnp -import numba import numpy as np from brainpy import math as bm @@ -14,14 +13,15 @@ from brainpy._src.context import share from brainpy._src.dnn.base import Layer from brainpy._src.mixin import SupportOnline, SupportOffline, SupportSTDP -from brainpy._src.dependency_check import import_taichi +from brainpy._src.dependency_check import import_taichi_else_None, import_numba_else_None, import_taichi, import_numba from brainpy.check import is_initializer from brainpy.connect import csr2csc from brainpy.errors import MathError from brainpy.initialize import XavierNormal, ZeroInit, Initializer, parameter from brainpy.types import ArrayType, Sharding -ti = import_taichi() +ti = import_taichi_else_None() +numba = import_numba_else_None() __all__ = [ 'Dense', 'Linear', @@ -246,56 +246,58 @@ def update(self, x): # if spike[i]: # out_w[i] = np.clip(out_w[i] + trace, w_min, w_max) -@ti.kernel -def _cpu_dense_on_pre(weight: ti.types.ndarray(ndim=2), - spike: ti.types.ndarray(ndim=1), - trace: ti.types.ndarray(ndim=1), - w_min: ti.types.ndarray(ndim=1), - w_max: ti.types.ndarray(ndim=1), - out_w: ti.types.ndarray(ndim=2)): - trace0 = trace[0] - w_min0 = w_min[0] - w_max0 = w_max[0] - for i, j in ti.ndrange(out_w.shape[0], out_w.shape[1]): - out_w[i, j] = weight[i, j] - for i in range(spike.shape[0]): - if spike[i]: - for j in range(out_w.shape[1]): - new_value = out_w[i, j] + trace0 - if new_value < w_min0: - out_w[i, j] = w_min0 - elif new_value > w_max0: - out_w[i, j] = w_max0 - else: +dense_on_pre_prim = None +if ti is not None: + @ti.kernel + def _cpu_dense_on_pre(weight: ti.types.ndarray(ndim=2), + spike: ti.types.ndarray(ndim=1), + trace: ti.types.ndarray(ndim=1), + w_min: ti.types.ndarray(ndim=1), + w_max: ti.types.ndarray(ndim=1), + out_w: ti.types.ndarray(ndim=2)): + trace0 = trace[0] + w_min0 = w_min[0] + w_max0 = w_max[0] + for i, j in ti.ndrange(out_w.shape[0], out_w.shape[1]): + out_w[i, j] = weight[i, j] + for i in range(spike.shape[0]): + if spike[i]: + for j in range(out_w.shape[1]): + new_value = out_w[i, j] + trace0 + if new_value < w_min0: + out_w[i, j] = w_min0 + elif new_value > w_max0: + out_w[i, j] = w_max0 + else: + out_w[i, j] = new_value + + + @ti.kernel + def _gpu_dense_on_pre(weight: ti.types.ndarray(ndim=1), + spike: ti.types.ndarray(ndim=1), + trace: ti.types.ndarray(ndim=1), + w_min: ti.types.ndarray(ndim=1), + w_max: ti.types.ndarray(ndim=1), + out_w: ti.types.ndarray(ndim=1)): + trace0 = trace[0] + w_min0 = w_min[0] + w_max0 = w_max[0] + for i, j in ti.ndrange(out_w.shape[0], out_w.shape[1]): + out_w[i, j] = weight[i, j] + for i in range(spike.shape[0]): + if spike[i]: + for j in range(out_w.shape[1]): + new_value = out_w[i, j] + trace0 + if new_value < w_min0: + out_w[i, j] = w_min0 + elif new_value > w_max0: + out_w[i, j] = w_max0 + else: out_w[i, j] = new_value - - -@ti.kernel -def _gpu_dense_on_pre(weight: ti.types.ndarray(ndim=1), - spike: ti.types.ndarray(ndim=1), - trace: ti.types.ndarray(ndim=1), - w_min: ti.types.ndarray(ndim=1), - w_max: ti.types.ndarray(ndim=1), - out_w: ti.types.ndarray(ndim=1)): - trace0 = trace[0] - w_min0 = w_min[0] - w_max0 = w_max[0] - for i, j in ti.ndrange(out_w.shape[0], out_w.shape[1]): - out_w[i, j] = weight[i, j] - for i in range(spike.shape[0]): - if spike[i]: - for j in range(out_w.shape[1]): - new_value = out_w[i, j] + trace0 - if new_value < w_min0: - out_w[i, j] = w_min0 - elif new_value > w_max0: - out_w[i, j] = w_max0 - else: - out_w[i, j] = new_value -dense_on_pre_prim = bm.XLACustomOp(cpu_kernel=_cpu_dense_on_pre, - gpu_kernel=_gpu_dense_on_pre) + dense_on_pre_prim = bm.XLACustomOp(cpu_kernel=_cpu_dense_on_pre, + gpu_kernel=_gpu_dense_on_pre) def dense_on_pre(weight, spike, trace, w_min, w_max): @@ -306,6 +308,8 @@ def dense_on_pre(weight, spike, trace, w_min, w_max): trace = jnp.atleast_1d(trace) w_min = jnp.atleast_1d(w_min) w_max = jnp.atleast_1d(w_max) + if dense_on_pre_prim is None: + import_taichi() return dense_on_pre_prim(weight, spike, trace, w_min, w_max, outs=[jax.ShapeDtypeStruct(weight.shape, weight.dtype)])[0] @@ -317,54 +321,56 @@ def dense_on_pre(weight, spike, trace, w_min, w_max): # if spike[i]: # out_w[:, i] = np.clip(out_w[:, i] + trace, w_min, w_max) -@ti.kernel -def _cpu_dense_on_post(weight: ti.types.ndarray(ndim=2), - spike: ti.types.ndarray(ndim=1), - trace: ti.types.ndarray(ndim=1), - w_min: ti.types.ndarray(ndim=1), - w_max: ti.types.ndarray(ndim=1), - out_w: ti.types.ndarray(ndim=2)): - trace0 = trace[0] - w_min0 = w_min[0] - w_max0 = w_max[0] - for i, j in ti.ndrange(out_w.shape[0], out_w.shape[1]): - out_w[i, j] = weight[i, j] - for i in range(spike.shape[0]): - if spike[i]: - for j in range(out_w.shape[0]): - new_value = out_w[j, i] + trace0 - if new_value < w_min0: - out_w[j, i] = w_min0 - elif new_value > w_max0: - out_w[j, i] = w_max0 - else: - out_w[j, i] = new_value - -@ti.kernel -def _gpu_dense_on_post(weight: ti.types.ndarray(ndim=2), - spike: ti.types.ndarray(ndim=1), - trace: ti.types.ndarray(ndim=1), - w_min: ti.types.ndarray(ndim=1), - w_max: ti.types.ndarray(ndim=1), - out_w: ti.types.ndarray(ndim=2)): - trace0 = trace[0] - w_min0 = w_min[0] - w_max0 = w_max[0] - for i, j in ti.ndrange(out_w.shape[0], out_w.shape[1]): - out_w[i, j] = weight[i, j] - for i in range(spike.shape[0]): - if spike[i]: - for j in range(out_w.shape[0]): - new_value = out_w[j, i] + trace0 - if new_value < w_min0: - out_w[j, i] = w_min0 - elif new_value > w_max0: - out_w[j, i] = w_max0 - else: - out_w[j, i] = new_value - -dense_on_post_prim = bm.XLACustomOp(cpu_kernel=_cpu_dense_on_post, - gpu_kernel=_gpu_dense_on_post) +dense_on_post_prim = None +if ti is not None: + @ti.kernel + def _cpu_dense_on_post(weight: ti.types.ndarray(ndim=2), + spike: ti.types.ndarray(ndim=1), + trace: ti.types.ndarray(ndim=1), + w_min: ti.types.ndarray(ndim=1), + w_max: ti.types.ndarray(ndim=1), + out_w: ti.types.ndarray(ndim=2)): + trace0 = trace[0] + w_min0 = w_min[0] + w_max0 = w_max[0] + for i, j in ti.ndrange(out_w.shape[0], out_w.shape[1]): + out_w[i, j] = weight[i, j] + for i in range(spike.shape[0]): + if spike[i]: + for j in range(out_w.shape[0]): + new_value = out_w[j, i] + trace0 + if new_value < w_min0: + out_w[j, i] = w_min0 + elif new_value > w_max0: + out_w[j, i] = w_max0 + else: + out_w[j, i] = new_value + + @ti.kernel + def _gpu_dense_on_post(weight: ti.types.ndarray(ndim=2), + spike: ti.types.ndarray(ndim=1), + trace: ti.types.ndarray(ndim=1), + w_min: ti.types.ndarray(ndim=1), + w_max: ti.types.ndarray(ndim=1), + out_w: ti.types.ndarray(ndim=2)): + trace0 = trace[0] + w_min0 = w_min[0] + w_max0 = w_max[0] + for i, j in ti.ndrange(out_w.shape[0], out_w.shape[1]): + out_w[i, j] = weight[i, j] + for i in range(spike.shape[0]): + if spike[i]: + for j in range(out_w.shape[0]): + new_value = out_w[j, i] + trace0 + if new_value < w_min0: + out_w[j, i] = w_min0 + elif new_value > w_max0: + out_w[j, i] = w_max0 + else: + out_w[j, i] = new_value + + dense_on_post_prim = bm.XLACustomOp(cpu_kernel=_cpu_dense_on_post, + gpu_kernel=_gpu_dense_on_post) def dense_on_post(weight, spike, trace, w_min, w_max): @@ -375,6 +381,8 @@ def dense_on_post(weight, spike, trace, w_min, w_max): trace = jnp.atleast_1d(trace) w_min = jnp.atleast_1d(w_min) w_max = jnp.atleast_1d(w_max) + if dense_on_post_prim is None: + import_taichi() return dense_on_post_prim(weight, spike, trace, w_min, w_max, outs=[jax.ShapeDtypeStruct(weight.shape, weight.dtype)])[0] @@ -756,49 +764,50 @@ def _batch_csrmv(self, x): # # out_w[k] = np.clip(out_w[k] + trace[j], w_min, w_max) # out_w[k] = np.minimum(np.maximum(out_w[k] + trace[j], w_min), w_max) - -@ti.kernel -def _cpu_csr_on_pre_update(w: ti.types.ndarray(ndim=1), - indices: ti.types.ndarray(ndim=1), - indptr: ti.types.ndarray(ndim=1), - spike: ti.types.ndarray(ndim=1), - trace: ti.types.ndarray(ndim=1), - w_min: ti.types.ndarray(ndim=1), - w_max: ti.types.ndarray(ndim=1), - out_w: ti.types.ndarray(ndim=1)): - trace0 = trace[0] - w_min0 = w_min[0] - w_max0 = w_max[0] - for i in range(out_w.shape[0]): - out_w[i] = w[i] - for i in range(spike.shape[0]): - if spike[i]: - for k in range(indptr[i], indptr[i + 1]): - j = indices[k] - out_w[k] = min(max(out_w[k] + trace[j], w_min0), w_max0) -@ti.kernel -def _gpu_csr_on_pre_update(w: ti.types.ndarray(ndim=1), - indices: ti.types.ndarray(ndim=1), - indptr: ti.types.ndarray(ndim=1), - spike: ti.types.ndarray(ndim=1), - trace: ti.types.ndarray(ndim=1), - w_min: ti.types.ndarray(ndim=1), - w_max: ti.types.ndarray(ndim=1), - out_w: ti.types.ndarray(ndim=1)): - trace0 = trace[0] - w_min0 = w_min[0] - w_max0 = w_max[0] - for i in range(out_w.shape[0]): - out_w[i] = w[i] - for i in range(spike.shape[0]): - if spike[i]: - for k in range(indptr[i], indptr[i + 1]): - j = indices[k] - out_w[k] = min(max(out_w[k] + trace[j], w_min0), w_max0) - - -csr_on_pre_update_prim = bm.XLACustomOp(cpu_kernel=_cpu_csr_on_pre_update, - gpu_kernel=_gpu_csr_on_pre_update) +csr_on_pre_update_prim = None +if ti is not None: + @ti.kernel + def _cpu_csr_on_pre_update(w: ti.types.ndarray(ndim=1), + indices: ti.types.ndarray(ndim=1), + indptr: ti.types.ndarray(ndim=1), + spike: ti.types.ndarray(ndim=1), + trace: ti.types.ndarray(ndim=1), + w_min: ti.types.ndarray(ndim=1), + w_max: ti.types.ndarray(ndim=1), + out_w: ti.types.ndarray(ndim=1)): + trace0 = trace[0] + w_min0 = w_min[0] + w_max0 = w_max[0] + for i in range(out_w.shape[0]): + out_w[i] = w[i] + for i in range(spike.shape[0]): + if spike[i]: + for k in range(indptr[i], indptr[i + 1]): + j = indices[k] + out_w[k] = min(max(out_w[k] + trace[j], w_min0), w_max0) + @ti.kernel + def _gpu_csr_on_pre_update(w: ti.types.ndarray(ndim=1), + indices: ti.types.ndarray(ndim=1), + indptr: ti.types.ndarray(ndim=1), + spike: ti.types.ndarray(ndim=1), + trace: ti.types.ndarray(ndim=1), + w_min: ti.types.ndarray(ndim=1), + w_max: ti.types.ndarray(ndim=1), + out_w: ti.types.ndarray(ndim=1)): + trace0 = trace[0] + w_min0 = w_min[0] + w_max0 = w_max[0] + for i in range(out_w.shape[0]): + out_w[i] = w[i] + for i in range(spike.shape[0]): + if spike[i]: + for k in range(indptr[i], indptr[i + 1]): + j = indices[k] + out_w[k] = min(max(out_w[k] + trace[j], w_min0), w_max0) + + + csr_on_pre_update_prim = bm.XLACustomOp(cpu_kernel=_cpu_csr_on_pre_update, + gpu_kernel=_gpu_csr_on_pre_update) def csr_on_pre_update(w, indices, indptr, spike, trace, w_min=None, w_max=None): @@ -809,23 +818,27 @@ def csr_on_pre_update(w, indices, indptr, spike, trace, w_min=None, w_max=None): trace = jnp.atleast_1d(trace) w_min = jnp.atleast_1d(w_min) w_max = jnp.atleast_1d(w_max) + if csr_on_pre_update_prim is None: + import_taichi() return csr_on_pre_update_prim(w, indices, indptr, spike, trace, w_min, w_max, outs=[jax.ShapeDtypeStruct(w.shape, w.dtype)])[0] -@numba.njit(nogil=True, fastmath=True, parallel=False) -def _cpu_csc_on_pre_update(w, post_ids, indptr, w_ids, spike, trace, w_min, w_max, out_w): - out_w[:] = w - w_min = w_min[()] - w_max = w_max[()] - for i in numba.prange(spike.shape[0]): # post id - if spike[i]: - for k in range(indptr[i], indptr[i + 1]): - j = post_ids[k] # pre id - l = w_ids[k] # syn id - out_w[l] = np.minimum(np.maximum(out_w[l] + trace[j], w_min), w_max) +csc_on_pre_update_prim = None +if numba is not None: + @numba.njit(nogil=True, fastmath=True, parallel=False) + def _cpu_csc_on_pre_update(w, post_ids, indptr, w_ids, spike, trace, w_min, w_max, out_w): + out_w[:] = w + w_min = w_min[()] + w_max = w_max[()] + for i in numba.prange(spike.shape[0]): # post id + if spike[i]: + for k in range(indptr[i], indptr[i + 1]): + j = post_ids[k] # pre id + l = w_ids[k] # syn id + out_w[l] = np.minimum(np.maximum(out_w[l] + trace[j], w_min), w_max) -csc_on_pre_update_prim = bm.XLACustomOp(_cpu_csc_on_pre_update) + csc_on_pre_update_prim = bm.XLACustomOp(_cpu_csc_on_pre_update) def csc_on_post_update(w, post_ids, indptr, w_ids, spike, trace, w_min=None, w_max=None): @@ -833,6 +846,8 @@ def csc_on_post_update(w, post_ids, indptr, w_ids, spike, trace, w_min=None, w_m w_min = -np.inf if w_max is None: w_max = np.inf + if csc_on_pre_update_prim is None: + import_numba() return csc_on_pre_update_prim(w, post_ids, indptr, w_ids, spike, trace, w_min, w_max, outs=[jax.ShapeDtypeStruct(w.shape, w.dtype)])[0] diff --git a/brainpy/_src/math/defaults.py b/brainpy/_src/math/defaults.py index dae0f1bcd..beca46e79 100644 --- a/brainpy/_src/math/defaults.py +++ b/brainpy/_src/math/defaults.py @@ -1,13 +1,13 @@ import jax.numpy as jnp from jax import config -from brainpy._src.dependency_check import import_taichi +from brainpy._src.dependency_check import import_taichi_else_None from .modes import NonBatchingMode from .scales import IdScaling __all__ = ['mode', 'membrane_scaling', 'dt', 'bool_', 'int_', 'ti_int', 'float_', 'ti_float', 'complex_'] -ti = import_taichi() +ti = import_taichi_else_None() # Default computation mode. mode = NonBatchingMode() @@ -30,7 +30,6 @@ # '''Default complex data type.''' complex_ = jnp.complex128 if config.read('jax_enable_x64') else jnp.complex64 - if ti is not None: # '''Default integer data type in Taichi.''' ti_int = ti.int64 if config.read('jax_enable_x64') else ti.int32 @@ -40,4 +39,4 @@ else: ti_int = None - ti_float = None \ No newline at end of file + ti_float = None diff --git a/brainpy/_src/math/environment.py b/brainpy/_src/math/environment.py index 757c19b8d..70bc7a0e6 100644 --- a/brainpy/_src/math/environment.py +++ b/brainpy/_src/math/environment.py @@ -16,9 +16,9 @@ from . import modes from . import scales from . import defaults -from brainpy._src.dependency_check import import_taichi +from brainpy._src.dependency_check import import_taichi_else_None -ti = import_taichi() +ti = import_taichi_else_None() __all__ = [ # context manage for environment setting diff --git a/brainpy/_src/math/event/__init__.py b/brainpy/_src/math/event/__init__.py index e61dc10cf..a790c05e7 100644 --- a/brainpy/_src/math/event/__init__.py +++ b/brainpy/_src/math/event/__init__.py @@ -1,3 +1,5 @@ +from brainpy._src.dependency_check import import_taichi_else_None -from ._csr_matvec import * +if import_taichi_else_None() is not None: + from ._csr_matvec import * diff --git a/brainpy/_src/math/jitconn/__init__.py b/brainpy/_src/math/jitconn/__init__.py index a79cdc982..ea087c467 100644 --- a/brainpy/_src/math/jitconn/__init__.py +++ b/brainpy/_src/math/jitconn/__init__.py @@ -1,3 +1,4 @@ - -from ._matvec import * -from ._event_matvec import * \ No newline at end of file +from brainpy._src.dependency_check import import_taichi_else_None +if import_taichi_else_None() is not None: + from ._matvec import * + from ._event_matvec import * \ No newline at end of file diff --git a/brainpy/_src/math/op_register/__init__.py b/brainpy/_src/math/op_register/__init__.py index 01f77dbca..93666197e 100644 --- a/brainpy/_src/math/op_register/__init__.py +++ b/brainpy/_src/math/op_register/__init__.py @@ -1,7 +1,15 @@ - -from .numba_approach import (CustomOpByNumba, - register_op_with_numba, - compile_cpu_signature_with_numba) -from .taichi_aot_based import clean_caches, check_kernels_count -from .base import XLACustomOp -from .utils import register_general_batching +from brainpy._src.dependency_check import import_numba_else_None, import_taichi_else_None + +numba = import_numba_else_None() +taichi = import_taichi_else_None() + +if numba is not None: + from .numba_approach import (CustomOpByNumba, + register_op_with_numba, + compile_cpu_signature_with_numba) + from .base import XLACustomOp + from .utils import register_general_batching +if taichi is not None: + from .taichi_aot_based import clean_caches, check_kernels_count + from .base import XLACustomOp + from .utils import register_general_batching diff --git a/brainpy/_src/math/op_register/numba_approach/__init__.py b/brainpy/_src/math/op_register/numba_approach/__init__.py index cc2ce5b4c..13d4f66e7 100644 --- a/brainpy/_src/math/op_register/numba_approach/__init__.py +++ b/brainpy/_src/math/op_register/numba_approach/__init__.py @@ -1,200 +1,205 @@ # -*- coding: utf-8 -*- -import warnings from functools import partial from typing import Callable from typing import Union, Sequence -import numba import jax from jax.interpreters import xla, batching, ad from jax.tree_util import tree_map -from numba.core.dispatcher import Dispatcher +from brainpy._src.dependency_check import import_numba_else_None from brainpy._src.math.ndarray import Array from brainpy._src.math.object_transform.base import BrainPyObject -from .cpu_translation import _cpu_translation, compile_cpu_signature_with_numba - -__all__ = [ - 'CustomOpByNumba', - 'register_op_with_numba', - 'compile_cpu_signature_with_numba', -] - - -class CustomOpByNumba(BrainPyObject): - """Creating a XLA custom call operator with Numba JIT on CPU backend. - - Parameters - ---------- - name: str - The name of operator. - eval_shape: callable - The function to evaluate the shape and dtype of the output according to the input. - This function should receive the abstract information of inputs, and return the - abstract information of the outputs. For example: - - >>> def eval_shape(inp1_info, inp2_info, inp3_info, ...): - >>> return out1_info, out2_info - con_compute: callable - The function to make the concrete computation. This function receives inputs, - and returns outputs. For example: - - >>> def con_compute(inp1, inp2, inp3, ..., out1, out2, ...): - >>> pass - """ - - def __init__( - self, - eval_shape: Callable = None, - con_compute: Callable = None, - name: str = None, + + +numba = import_numba_else_None() + +if numba is not None: + from .cpu_translation import _cpu_translation, compile_cpu_signature_with_numba + from numba.core.dispatcher import Dispatcher + + __all__ = [ + 'CustomOpByNumba', + 'register_op_with_numba', + 'compile_cpu_signature_with_numba', + ] + + + class CustomOpByNumba(BrainPyObject): + """Creating a XLA custom call operator with Numba JIT on CPU backend. + + Parameters + ---------- + name: str + The name of operator. + eval_shape: callable + The function to evaluate the shape and dtype of the output according to the input. + This function should receive the abstract information of inputs, and return the + abstract information of the outputs. For example: + + >>> def eval_shape(inp1_info, inp2_info, inp3_info, ...): + >>> return out1_info, out2_info + con_compute: callable + The function to make the concrete computation. This function receives inputs, + and returns outputs. For example: + + >>> def con_compute(inp1, inp2, inp3, ..., out1, out2, ...): + >>> pass + """ + + def __init__( + self, + eval_shape: Callable = None, + con_compute: Callable = None, + name: str = None, + batching_translation: Callable = None, + jvp_translation: Callable = None, + transpose_translation: Callable = None, + multiple_results: bool = True, + ): + super().__init__(name=name) + + # abstract evaluation function + if eval_shape is None: + raise ValueError('Must provide "eval_shape" for abstract evaluation.') + + # cpu function + cpu_func = con_compute + + # register OP + self.op = register_op_with_numba( + self.name, + cpu_func=cpu_func, + out_shapes=eval_shape, + batching_translation=batching_translation, + jvp_translation=jvp_translation, + transpose_translation=transpose_translation, + multiple_results=multiple_results, + ) + + def __call__(self, *args, **kwargs): + args = tree_map(lambda a: a.value if isinstance(a, Array) else a, + args, is_leaf=lambda a: isinstance(a, Array)) + kwargs = tree_map(lambda a: a.value if isinstance(a, Array) else a, + kwargs, is_leaf=lambda a: isinstance(a, Array)) + res = self.op.bind(*args, **kwargs) + return res + + + def register_op_with_numba( + op_name: str, + cpu_func: Callable, + out_shapes: Union[Callable, jax.core.ShapedArray, Sequence[jax.core.ShapedArray]], + gpu_func_translation: Callable = None, batching_translation: Callable = None, jvp_translation: Callable = None, transpose_translation: Callable = None, - multiple_results: bool = True, + multiple_results: bool = False, ): - super().__init__(name=name) - - # abstract evaluation function - if eval_shape is None: - raise ValueError('Must provide "eval_shape" for abstract evaluation.') + """ + Converting the numba-jitted function in a Jax/XLA compatible primitive. + + Parameters + ---------- + op_name: str + Name of the operators. + + cpu_func: Callable + A callable numba-jitted function or pure function (can be lambda function) running on CPU. + + out_shapes: Callable, ShapedArray, Sequence[ShapedArray], default = None + Outputs shapes of target function. `out_shapes` can be a `ShapedArray` or + a sequence of `ShapedArray`. If it is a function, it takes as input the argument + shapes and dtypes and should return correct output shapes of `ShapedArray`. + + gpu_func_translation: Callable + A callable cuda-jitted kernel running on GPU. + + batching_translation: Callable + The batching translation for the primitive. + + jvp_translation: Callable + The forward autodiff translation rule. + + transpose_translation: Callable + The backward autodiff translation rule. + + multiple_results: bool + Whether the primitive returns multiple results. Default is False. + + Returns + ------- + op: core.Primitive + A JAX Primitive object. + """ + + if jax.__version__ > '0.4.23': + raise RuntimeError(f'{CustomOpByNumba.__name__} and {register_op_with_numba.__name__} are ' + f'only supported in JAX version <= 0.4.23. \n' + f'However, you can use brainpy.math.XLACustomOp to create a custom op with numba syntax. ' + f'For more information, please refer to the documentation: ' + f'https://brainpy.readthedocs.io/en/latest/tutorial_advanced/operator_custom_with_taichi.html.') + + if out_shapes is None: + raise RuntimeError('out_shapes cannot be None. It can be a `ShapedArray` or ' + 'a sequence of `ShapedArray`. If it is a function, it takes as input the argument ' + 'shapes and dtypes and should return correct output shapes of `ShapedArray`.') + + prim = jax.core.Primitive(op_name) + prim.multiple_results = multiple_results + + # user defined function + if not isinstance(cpu_func, Dispatcher): + cpu_func = numba.jit(fastmath=True, nopython=True)(cpu_func) + + # output shape evaluation function + def abs_eval_rule(*input_shapes, **info): + if callable(out_shapes): + shapes = out_shapes(*input_shapes, **info) + else: + shapes = out_shapes + + if isinstance(shapes, jax.core.ShapedArray): + assert not multiple_results, "multiple_results is True, while the abstract evaluation returns only one data." + elif isinstance(shapes, (tuple, list)): + assert multiple_results, "multiple_results is False, while the abstract evaluation returns multiple data." + for elem in shapes: + if not isinstance(elem, jax.core.ShapedArray): + raise ValueError(f'Elements in "out_shapes" must be instances of ' + f'jax.abstract_arrays.ShapedArray, but we got ' + f'{type(elem)}: {elem}') + else: + raise ValueError(f'Unknown type {type(shapes)}, only ' + f'supports function, ShapedArray or ' + f'list/tuple of ShapedArray.') + return shapes # cpu function - cpu_func = con_compute - - # register OP - self.op = register_op_with_numba( - self.name, - cpu_func=cpu_func, - out_shapes=eval_shape, - batching_translation=batching_translation, - jvp_translation=jvp_translation, - transpose_translation=transpose_translation, - multiple_results=multiple_results, - ) - - def __call__(self, *args, **kwargs): - args = tree_map(lambda a: a.value if isinstance(a, Array) else a, - args, is_leaf=lambda a: isinstance(a, Array)) - kwargs = tree_map(lambda a: a.value if isinstance(a, Array) else a, - kwargs, is_leaf=lambda a: isinstance(a, Array)) - res = self.op.bind(*args, **kwargs) - return res - - -def register_op_with_numba( - op_name: str, - cpu_func: Callable, - out_shapes: Union[Callable, jax.core.ShapedArray, Sequence[jax.core.ShapedArray]], - gpu_func_translation: Callable = None, - batching_translation: Callable = None, - jvp_translation: Callable = None, - transpose_translation: Callable = None, - multiple_results: bool = False, -): - """ - Converting the numba-jitted function in a Jax/XLA compatible primitive. - - Parameters - ---------- - op_name: str - Name of the operators. - - cpu_func: Callable - A callable numba-jitted function or pure function (can be lambda function) running on CPU. - - out_shapes: Callable, ShapedArray, Sequence[ShapedArray], default = None - Outputs shapes of target function. `out_shapes` can be a `ShapedArray` or - a sequence of `ShapedArray`. If it is a function, it takes as input the argument - shapes and dtypes and should return correct output shapes of `ShapedArray`. - - gpu_func_translation: Callable - A callable cuda-jitted kernel running on GPU. - - batching_translation: Callable - The batching translation for the primitive. - - jvp_translation: Callable - The forward autodiff translation rule. - - transpose_translation: Callable - The backward autodiff translation rule. - - multiple_results: bool - Whether the primitive returns multiple results. Default is False. - - Returns - ------- - op: core.Primitive - A JAX Primitive object. - """ - - if jax.__version__ > '0.4.23': - raise RuntimeError(f'{CustomOpByNumba.__name__} and {register_op_with_numba.__name__} are ' - f'only supported in JAX version <= 0.4.23. \n' - f'However, you can use brainpy.math.XLACustomOp to create a custom op with numba syntax. ' - f'For more information, please refer to the documentation: ' - f'https://brainpy.readthedocs.io/en/latest/tutorial_advanced/operator_custom_with_taichi.html.') - - if out_shapes is None: - raise RuntimeError('out_shapes cannot be None. It can be a `ShapedArray` or ' - 'a sequence of `ShapedArray`. If it is a function, it takes as input the argument ' - 'shapes and dtypes and should return correct output shapes of `ShapedArray`.') - - prim = jax.core.Primitive(op_name) - prim.multiple_results = multiple_results - - # user defined function - if not isinstance(cpu_func, Dispatcher): - cpu_func = numba.jit(fastmath=True, nopython=True)(cpu_func) - - # output shape evaluation function - def abs_eval_rule(*input_shapes, **info): - if callable(out_shapes): - shapes = out_shapes(*input_shapes, **info) - else: - shapes = out_shapes - - if isinstance(shapes, jax.core.ShapedArray): - assert not multiple_results, "multiple_results is True, while the abstract evaluation returns only one data." - elif isinstance(shapes, (tuple, list)): - assert multiple_results, "multiple_results is False, while the abstract evaluation returns multiple data." - for elem in shapes: - if not isinstance(elem, jax.core.ShapedArray): - raise ValueError(f'Elements in "out_shapes" must be instances of ' - f'jax.abstract_arrays.ShapedArray, but we got ' - f'{type(elem)}: {elem}') - else: - raise ValueError(f'Unknown type {type(shapes)}, only ' - f'supports function, ShapedArray or ' - f'list/tuple of ShapedArray.') - return shapes - - # cpu function - prim.def_abstract_eval(abs_eval_rule) - prim.def_impl(partial(xla.apply_primitive, prim)) - xla.backend_specific_translations['cpu'][prim] = partial(_cpu_translation, - cpu_func, - abs_eval_rule, - multiple_results) - - # gpu function - if gpu_func_translation is not None: - xla.backend_specific_translations['gpu'][prim] = gpu_func_translation - - # batching - if batching_translation is not None: - batching.primitive_batchers[prim] = batching_translation - - # jvp - if jvp_translation is not None: - ad.primitive_jvps[prim] = jvp_translation - - # transpose - if transpose_translation is not None: - ad.primitive_transposes[prim] = transpose_translation - - return prim + prim.def_abstract_eval(abs_eval_rule) + prim.def_impl(partial(xla.apply_primitive, prim)) + xla.backend_specific_translations['cpu'][prim] = partial(_cpu_translation, + cpu_func, + abs_eval_rule, + multiple_results) + + # gpu function + if gpu_func_translation is not None: + xla.backend_specific_translations['gpu'][prim] = gpu_func_translation + + # batching + if batching_translation is not None: + batching.primitive_batchers[prim] = batching_translation + + # jvp + if jvp_translation is not None: + ad.primitive_jvps[prim] = jvp_translation + + # transpose + if transpose_translation is not None: + ad.primitive_transposes[prim] = transpose_translation + return prim +else: + __all__ = [] diff --git a/brainpy/_src/math/op_register/numba_approach/cpu_translation.py b/brainpy/_src/math/op_register/numba_approach/cpu_translation.py index 13974b5b2..02f74a237 100644 --- a/brainpy/_src/math/op_register/numba_approach/cpu_translation.py +++ b/brainpy/_src/math/op_register/numba_approach/cpu_translation.py @@ -1,146 +1,155 @@ -# -*- coding: utf-8 -*- - -import ctypes - -from jax import dtypes, numpy as jnp -from jax.core import ShapedArray -from jax.lib import xla_client -from numba import types, carray, cfunc - -__all__ = [ - 'compile_cpu_signature_with_numba' -] - -ctypes.pythonapi.PyCapsule_New.argtypes = [ - ctypes.c_void_p, # void* pointer - ctypes.c_char_p, # const char *name - ctypes.c_void_p, # PyCapsule_Destructor destructor -] -ctypes.pythonapi.PyCapsule_New.restype = ctypes.py_object - - -def _cpu_translation(func, abs_eval_fn, multiple_results, c, *inputs, **info): - target_name, inputs, input_shapes, xla_output_shapes = \ - compile_cpu_signature_with_numba(c, func, abs_eval_fn, multiple_results, inputs, info) - return xla_client.ops.CustomCallWithLayout( - c, - target_name, - operands=inputs, - operand_shapes_with_layout=input_shapes, - shape_with_layout=xla_output_shapes, - ) - - -def _cpu_signature( - func, - input_dtypes, - input_shapes, - output_dtypes, - output_shapes, - multiple_results: bool, - debug: bool = False -): - code_scope = dict( - func_to_call=func, - input_shapes=input_shapes, - input_dtypes=input_dtypes, - output_shapes=output_shapes, - output_dtypes=output_dtypes, - carray=carray, - ) - - # inputs - if len(input_shapes) > 1: - args_in = [ - f'carray(input_ptrs[{i}], input_shapes[{i}], dtype=input_dtypes[{i}]),' - for i in range(len(input_shapes)) - ] - args_in = '(\n ' + "\n ".join(args_in) + '\n )' - else: - args_in = 'carray(input_ptrs[0], input_shapes[0], dtype=input_dtypes[0])' - - # outputs - if multiple_results: - args_out = [ - f'carray(output_ptrs[{i}], output_shapes[{i}], dtype=output_dtypes[{i}]),' - for i in range(len(output_shapes)) - ] - args_out = '(\n ' + "\n ".join(args_out) + '\n )' - else: - args_out = 'carray(output_ptrs, output_shapes[0], dtype=output_dtypes[0])' - - # function body - code_string = ''' -def xla_cpu_custom_call_target(output_ptrs, input_ptrs): - args_out = {args_out} - args_in = {args_in} - func_to_call(args_out, args_in) - '''.format(args_in=args_in, - args_out=args_out) - if debug: print(code_string) - exec(compile(code_string.strip(), '', 'exec'), code_scope) - - new_f = code_scope['xla_cpu_custom_call_target'] - if multiple_results: - xla_c_rule = cfunc(types.void(types.CPointer(types.voidptr), - types.CPointer(types.voidptr)))(new_f) - else: - xla_c_rule = cfunc(types.void(types.voidptr, types.CPointer(types.voidptr)))(new_f) - target_name = xla_c_rule.native_name.encode("ascii") - capsule = ctypes.pythonapi.PyCapsule_New( - xla_c_rule.address, # A CFFI pointer to a function - b"xla._CUSTOM_CALL_TARGET", # A binary string - None # PyCapsule object run at destruction - ) - xla_client.register_custom_call_target(target_name, capsule, "cpu") - return target_name - - -def compile_cpu_signature_with_numba( - c, - func, - abs_eval_fn, - multiple_results, - inputs: tuple, - description: dict = None, -): - input_layouts = [c.get_shape(arg) for arg in inputs] - info_inputs = [] - if description is None: description = dict() - for v in description.values(): - if isinstance(v, (int, float)): - input_layouts.append(xla_client.Shape.array_shape(dtypes.canonicalize_dtype(type(v)), (), ())) - info_inputs.append(xla_client.ops.ConstantLiteral(c, v)) - elif isinstance(v, (tuple, list)): - v = jnp.asarray(v) - input_layouts.append(xla_client.Shape.array_shape(v.dtype, v.shape, tuple(range(len(v.shape) - 1, -1, -1)))) - info_inputs.append(xla_client.ops.Constant(c, v)) - else: - raise TypeError - input_layouts = tuple(input_layouts) - input_dtypes = tuple(shape.element_type() for shape in input_layouts) - input_dimensions = tuple(shape.dimensions() for shape in input_layouts) - output_abstract_arrays = abs_eval_fn(*tuple(ShapedArray(shape.dimensions(), shape.element_type()) - for shape in input_layouts[:len(inputs)]), - **description) - if isinstance(output_abstract_arrays, ShapedArray): - output_abstract_arrays = (output_abstract_arrays,) - assert not multiple_results - else: - assert multiple_results - output_shapes = tuple(array.shape for array in output_abstract_arrays) - output_dtypes = tuple(array.dtype for array in output_abstract_arrays) - output_layouts = map(lambda shape: range(len(shape) - 1, -1, -1), output_shapes) - target_name = _cpu_signature(func, - input_dtypes, - input_dimensions, - output_dtypes, - output_shapes, - multiple_results, - debug=False) - output_layouts = [xla_client.Shape.array_shape(*arg) - for arg in zip(output_dtypes, output_shapes, output_layouts)] - output_layouts = (xla_client.Shape.tuple_shape(output_layouts) - if multiple_results else - output_layouts[0]) - return target_name, tuple(inputs) + tuple(info_inputs), input_layouts, output_layouts +# -*- coding: utf-8 -*- + +import ctypes + +from jax import dtypes, numpy as jnp +from jax.core import ShapedArray +from jax.lib import xla_client + +from brainpy._src.dependency_check import import_numba_else_None + +numba = import_numba_else_None() + +if numba is not None: + from numba import types, carray, cfunc + + __all__ = [ + '_cpu_translation', + 'compile_cpu_signature_with_numba', + ] + + ctypes.pythonapi.PyCapsule_New.argtypes = [ + ctypes.c_void_p, # void* pointer + ctypes.c_char_p, # const char *name + ctypes.c_void_p, # PyCapsule_Destructor destructor + ] + ctypes.pythonapi.PyCapsule_New.restype = ctypes.py_object + + + def _cpu_translation(func, abs_eval_fn, multiple_results, c, *inputs, **info): + target_name, inputs, input_shapes, xla_output_shapes = \ + compile_cpu_signature_with_numba(c, func, abs_eval_fn, multiple_results, inputs, info) + return xla_client.ops.CustomCallWithLayout( + c, + target_name, + operands=inputs, + operand_shapes_with_layout=input_shapes, + shape_with_layout=xla_output_shapes, + ) + + + def _cpu_signature( + func, + input_dtypes, + input_shapes, + output_dtypes, + output_shapes, + multiple_results: bool, + debug: bool = False + ): + code_scope = dict( + func_to_call=func, + input_shapes=input_shapes, + input_dtypes=input_dtypes, + output_shapes=output_shapes, + output_dtypes=output_dtypes, + carray=carray, + ) + + # inputs + if len(input_shapes) > 1: + args_in = [ + f'carray(input_ptrs[{i}], input_shapes[{i}], dtype=input_dtypes[{i}]),' + for i in range(len(input_shapes)) + ] + args_in = '(\n ' + "\n ".join(args_in) + '\n )' + else: + args_in = 'carray(input_ptrs[0], input_shapes[0], dtype=input_dtypes[0])' + + # outputs + if multiple_results: + args_out = [ + f'carray(output_ptrs[{i}], output_shapes[{i}], dtype=output_dtypes[{i}]),' + for i in range(len(output_shapes)) + ] + args_out = '(\n ' + "\n ".join(args_out) + '\n )' + else: + args_out = 'carray(output_ptrs, output_shapes[0], dtype=output_dtypes[0])' + + # function body + code_string = ''' + def xla_cpu_custom_call_target(output_ptrs, input_ptrs): + args_out = {args_out} + args_in = {args_in} + func_to_call(args_out, args_in) + '''.format(args_in=args_in, + args_out=args_out) + if debug: print(code_string) + exec(compile(code_string.strip(), '', 'exec'), code_scope) + + new_f = code_scope['xla_cpu_custom_call_target'] + if multiple_results: + xla_c_rule = cfunc(types.void(types.CPointer(types.voidptr), + types.CPointer(types.voidptr)))(new_f) + else: + xla_c_rule = cfunc(types.void(types.voidptr, types.CPointer(types.voidptr)))(new_f) + target_name = xla_c_rule.native_name.encode("ascii") + capsule = ctypes.pythonapi.PyCapsule_New( + xla_c_rule.address, # A CFFI pointer to a function + b"xla._CUSTOM_CALL_TARGET", # A binary string + None # PyCapsule object run at destruction + ) + xla_client.register_custom_call_target(target_name, capsule, "cpu") + return target_name + + + def compile_cpu_signature_with_numba( + c, + func, + abs_eval_fn, + multiple_results, + inputs: tuple, + description: dict = None, + ): + input_layouts = [c.get_shape(arg) for arg in inputs] + info_inputs = [] + if description is None: description = dict() + for v in description.values(): + if isinstance(v, (int, float)): + input_layouts.append(xla_client.Shape.array_shape(dtypes.canonicalize_dtype(type(v)), (), ())) + info_inputs.append(xla_client.ops.ConstantLiteral(c, v)) + elif isinstance(v, (tuple, list)): + v = jnp.asarray(v) + input_layouts.append(xla_client.Shape.array_shape(v.dtype, v.shape, tuple(range(len(v.shape) - 1, -1, -1)))) + info_inputs.append(xla_client.ops.Constant(c, v)) + else: + raise TypeError + input_layouts = tuple(input_layouts) + input_dtypes = tuple(shape.element_type() for shape in input_layouts) + input_dimensions = tuple(shape.dimensions() for shape in input_layouts) + output_abstract_arrays = abs_eval_fn(*tuple(ShapedArray(shape.dimensions(), shape.element_type()) + for shape in input_layouts[:len(inputs)]), + **description) + if isinstance(output_abstract_arrays, ShapedArray): + output_abstract_arrays = (output_abstract_arrays,) + assert not multiple_results + else: + assert multiple_results + output_shapes = tuple(array.shape for array in output_abstract_arrays) + output_dtypes = tuple(array.dtype for array in output_abstract_arrays) + output_layouts = map(lambda shape: range(len(shape) - 1, -1, -1), output_shapes) + target_name = _cpu_signature(func, + input_dtypes, + input_dimensions, + output_dtypes, + output_shapes, + multiple_results, + debug=False) + output_layouts = [xla_client.Shape.array_shape(*arg) + for arg in zip(output_dtypes, output_shapes, output_layouts)] + output_layouts = (xla_client.Shape.tuple_shape(output_layouts) + if multiple_results else + output_layouts[0]) + return target_name, tuple(inputs) + tuple(info_inputs), input_layouts, output_layouts +else: + __all__ = [] diff --git a/brainpy/_src/math/sparse/__init__.py b/brainpy/_src/math/sparse/__init__.py index 6c13ac19a..8a522ccb7 100644 --- a/brainpy/_src/math/sparse/__init__.py +++ b/brainpy/_src/math/sparse/__init__.py @@ -1,9 +1,13 @@ +from brainpy._src.dependency_check import import_taichi_else_None, import_numba_else_None # from ._coo_mv import * # from ._bsr_mv import * -from ._csr_mv import * -from ._utils import * -from ._bsr_mm import * +if import_taichi_else_None() is not None: + from ._csr_mv import * + from ._utils import * +if import_numba_else_None() is not None: + from ._bsr_mm import * + from ._jax_prim import * diff --git a/brainpy/_src/math/sparse/_csr_mv.py b/brainpy/_src/math/sparse/_csr_mv.py index 27f10f4b9..73eb48dcb 100644 --- a/brainpy/_src/math/sparse/_csr_mv.py +++ b/brainpy/_src/math/sparse/_csr_mv.py @@ -1,25 +1,18 @@ # -*- coding: utf-8 -*- -from functools import partial from typing import Union, Tuple -import jax -import numba -import numpy as np import brainpy.math as bm -from jax import core, dtypes +import jax from jax import numpy as jnp -from jax.interpreters import ad, mlir, xla -from jax.lib import xla_client from jax.experimental.sparse import csr -from jaxlib import gpu_sparse +from jax.interpreters import ad -from brainpy._src.dependency_check import import_brainpylib_gpu_ops, import_taichi +from brainpy._src.dependency_check import import_taichi from brainpy._src.math.interoperability import as_jax from brainpy._src.math.ndarray import Array -from brainpy._src.math.op_register import (compile_cpu_signature_with_numba, - register_general_batching, +from brainpy._src.math.op_register import (register_general_batching, XLACustomOp) from brainpy._src.math.sparse._utils import csr_to_coo from brainpy.errors import PackageMissingError @@ -79,6 +72,7 @@ def csrmv( """ return csrmv_taichi(data, indices, indptr, vector, shape=shape, transpose=transpose) + ### TAICHI ### def csrmv_taichi( @@ -145,6 +139,7 @@ def csrmv_taichi( return raw_csrmv_taichi(data, indices, indptr, vector, shape=shape, transpose=transpose)[0] + def raw_csrmv_taichi( data: Union[float, jnp.ndarray, Array], indices: Union[jnp.ndarray, Array], @@ -155,7 +150,7 @@ def raw_csrmv_taichi( transpose: bool = False, ): if ti is None: - raise PackageMissingError(name='taichi==1.7.0', purpose='customized operators') + raise PackageMissingError(name='taichi', purpose='customized operators') out_shape = shape[1] if transpose else shape[0] if data.shape[0] != 1: if bm.get_platform() == 'gpu': @@ -200,10 +195,10 @@ def _sparse_csr_matvec_transpose_homo_cpu(values: ti.types.ndarray(ndim=1), @ti.kernel def _sparse_csr_matvec_transpose_heter_cpu(values: ti.types.ndarray(ndim=1), - col_indices: ti.types.ndarray(ndim=1), - row_ptr: ti.types.ndarray(ndim=1), - vector: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): + col_indices: ti.types.ndarray(ndim=1), + row_ptr: ti.types.ndarray(ndim=1), + vector: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1)): ti.loop_config(serialize=True) for row_i in range(row_ptr.shape[0] - 1): for j in range(row_ptr[row_i], row_ptr[row_i + 1]): @@ -227,10 +222,10 @@ def _sparse_csr_matvec_homo_cpu(values: ti.types.ndarray(ndim=1), @ti.kernel def _sparse_csr_matvec_heter_cpu(values: ti.types.ndarray(ndim=1), - col_indices: ti.types.ndarray(ndim=1), - row_ptr: ti.types.ndarray(ndim=1), - vector: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): + col_indices: ti.types.ndarray(ndim=1), + row_ptr: ti.types.ndarray(ndim=1), + vector: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1)): # ti.loop_config(serialize=True) for row_i in range(row_ptr.shape[0] - 1): r = 0. @@ -243,7 +238,6 @@ def _sparse_csr_matvec_heter_cpu(values: ti.types.ndarray(ndim=1), # GPU operators # ------------- - @ti.kernel def _sparse_csr_matvec_transpose_homo_gpu(values: ti.types.ndarray(ndim=1), col_indices: ti.types.ndarray(ndim=1), @@ -282,10 +276,10 @@ def _sparse_csr_matvec_homo_gpu(values: ti.types.ndarray(ndim=1), @ti.kernel def _sparse_csr_matvec_transpose_heter_gpu(values: ti.types.ndarray(ndim=1), - col_indices: ti.types.ndarray(ndim=1), - row_ptr: ti.types.ndarray(ndim=1), - vector: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): + col_indices: ti.types.ndarray(ndim=1), + row_ptr: ti.types.ndarray(ndim=1), + vector: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1)): for i in range((row_ptr.shape[0] - 1) * 32): row_i = i >> 5 index = i & 31 @@ -298,10 +292,10 @@ def _sparse_csr_matvec_transpose_heter_gpu(values: ti.types.ndarray(ndim=1), @ti.kernel def _sparse_csr_matvec_heter_gpu(values: ti.types.ndarray(ndim=1), - col_indices: ti.types.ndarray(ndim=1), - row_ptr: ti.types.ndarray(ndim=1), - vector: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): + col_indices: ti.types.ndarray(ndim=1), + row_ptr: ti.types.ndarray(ndim=1), + vector: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1)): for i in range((row_ptr.shape[0] - 1) * 32): row_i = i >> 5 index = i & 31 @@ -362,11 +356,11 @@ def _define_op(cpu_kernel, gpu_kernel): # transpose heter _csr_matvec_transpose_heter_p = _define_op(cpu_kernel=_sparse_csr_matvec_transpose_heter_cpu, - gpu_kernel=_sparse_csr_matvec_transpose_heter_gpu) + gpu_kernel=_sparse_csr_matvec_transpose_heter_gpu) # no transpose heter _csr_matvec_heter_p = _define_op(cpu_kernel=_sparse_csr_matvec_heter_cpu, - gpu_kernel=_sparse_csr_matvec_heter_gpu) + gpu_kernel=_sparse_csr_matvec_heter_gpu) # heter cusparse _csr_matvec_cusparse_p = csr.csr_matvec_p diff --git a/brainpy/_src/math/tifunc.py b/brainpy/_src/math/tifunc.py index 928cb345a..d5c1c0399 100644 --- a/brainpy/_src/math/tifunc.py +++ b/brainpy/_src/math/tifunc.py @@ -1,7 +1,7 @@ -from brainpy._src.dependency_check import import_taichi +from brainpy._src.dependency_check import import_taichi_else_None from . import defaults -ti = import_taichi() +ti = import_taichi_else_None() if ti is not None: diff --git a/brainpy/math/__init__.py b/brainpy/math/__init__.py index 8bec65599..7126bbaac 100644 --- a/brainpy/math/__init__.py +++ b/brainpy/math/__init__.py @@ -51,47 +51,55 @@ spike_with_mg_grad as spike_with_mg_grad, ) +from brainpy._src.dependency_check import import_taichi_else_None from brainpy._src.math import defaults from brainpy._src.deprecations import deprecation_getattr -__deprecations = { - "sparse_matmul": ("brainpy.math.sparse_matmul is deprecated. Use brainpy.math.sparse.seg_matmul instead.", - sparse.seg_matmul), - 'csr_matvec': ("brainpy.math.csr_matvec is deprecated. Use brainpy.math.sparse.csrmv instead.", - sparse.csrmv), - 'event_matvec_prob_conn_homo_weight': ("brainpy.math.event_matvec_prob_conn_homo_weight is deprecated. " - "Use brainpy.math.jitconn.event_mv_prob_homo instead.", - jitconn.event_mv_prob_homo), - 'event_matvec_prob_conn_uniform_weight': ("brainpy.math.event_matvec_prob_conn_uniform_weight is deprecated. " - "Use brainpy.math.jitconn.event_mv_prob_uniform instead.", - jitconn.event_mv_prob_uniform), - 'event_matvec_prob_conn_normal_weight': ("brainpy.math.event_matvec_prob_conn_normal_weight is deprecated. " - "Use brainpy.math.jitconn.event_mv_prob_normal instead.", - jitconn.event_mv_prob_normal), - 'matvec_prob_conn_homo_weight': ("brainpy.math.matvec_prob_conn_homo_weight is deprecated. " - "Use brainpy.math.jitconn.mv_prob_homo instead.", - jitconn.mv_prob_homo), - 'matvec_prob_conn_uniform_weight': ("brainpy.math.matvec_prob_conn_uniform_weight is deprecated. " - "Use brainpy.math.jitconn.mv_prob_uniform instead.", - jitconn.mv_prob_uniform), - 'matvec_prob_conn_normal_weight': ("brainpy.math.matvec_prob_conn_normal_weight is deprecated. " - "Use brainpy.math.jitconn.mv_prob_normal instead.", - jitconn.mv_prob_normal), - 'cusparse_csr_matvec': ("brainpy.math.cusparse_csr_matvec is deprecated. " - "Use brainpy.math.sparse.csrmv instead.", - sparse.csrmv), - 'coo_to_csr': ("brainpy.math.coo_to_csr is deprecated. " - "Use brainpy.math.sparse.coo_to_csr instead.", - sparse.coo_to_csr), - 'csr_to_coo': ("brainpy.math.csr_to_coo is deprecated. " - "Use brainpy.math.sparse.csr_to_coo instead.", - sparse.csr_to_coo), - 'csr_to_dense': ("brainpy.math.csr_to_dense is deprecated. " - "Use brainpy.math.sparse.csr_to_dense instead.", - sparse.csr_to_dense), - 'event_csr_matvec': ("brainpy.math.event_csr_matvec is deprecated. " - "Use brainpy.math.event.csr_to_dense instead.", - event.csrmv), -} + +if import_taichi_else_None() is not None: + __deprecations = { + "sparse_matmul": ("brainpy.math.sparse_matmul is deprecated. Use brainpy.math.sparse.seg_matmul instead.", + sparse.seg_matmul), + 'csr_matvec': ("brainpy.math.csr_matvec is deprecated. Use brainpy.math.sparse.csrmv instead.", + sparse.csrmv), + 'event_matvec_prob_conn_homo_weight': ("brainpy.math.event_matvec_prob_conn_homo_weight is deprecated. " + "Use brainpy.math.jitconn.event_mv_prob_homo instead.", + jitconn.event_mv_prob_homo), + 'event_matvec_prob_conn_uniform_weight': ("brainpy.math.event_matvec_prob_conn_uniform_weight is deprecated. " + "Use brainpy.math.jitconn.event_mv_prob_uniform instead.", + jitconn.event_mv_prob_uniform), + 'event_matvec_prob_conn_normal_weight': ("brainpy.math.event_matvec_prob_conn_normal_weight is deprecated. " + "Use brainpy.math.jitconn.event_mv_prob_normal instead.", + jitconn.event_mv_prob_normal), + 'matvec_prob_conn_homo_weight': ("brainpy.math.matvec_prob_conn_homo_weight is deprecated. " + "Use brainpy.math.jitconn.mv_prob_homo instead.", + jitconn.mv_prob_homo), + 'matvec_prob_conn_uniform_weight': ("brainpy.math.matvec_prob_conn_uniform_weight is deprecated. " + "Use brainpy.math.jitconn.mv_prob_uniform instead.", + jitconn.mv_prob_uniform), + 'matvec_prob_conn_normal_weight': ("brainpy.math.matvec_prob_conn_normal_weight is deprecated. " + "Use brainpy.math.jitconn.mv_prob_normal instead.", + jitconn.mv_prob_normal), + 'cusparse_csr_matvec': ("brainpy.math.cusparse_csr_matvec is deprecated. " + "Use brainpy.math.sparse.csrmv instead.", + sparse.csrmv), + 'coo_to_csr': ("brainpy.math.coo_to_csr is deprecated. " + "Use brainpy.math.sparse.coo_to_csr instead.", + sparse.coo_to_csr), + 'csr_to_coo': ("brainpy.math.csr_to_coo is deprecated. " + "Use brainpy.math.sparse.csr_to_coo instead.", + sparse.csr_to_coo), + 'csr_to_dense': ("brainpy.math.csr_to_dense is deprecated. " + "Use brainpy.math.sparse.csr_to_dense instead.", + sparse.csr_to_dense), + 'event_csr_matvec': ("brainpy.math.event_csr_matvec is deprecated. " + "Use brainpy.math.event.csr_to_dense instead.", + event.csrmv), + } +else: + __deprecations = { + "sparse_matmul": ("brainpy.math.sparse_matmul is deprecated. Use brainpy.math.sparse.seg_matmul instead.", + sparse.seg_matmul), + } __getattr__ = deprecation_getattr(__name__, __deprecations, redirects=defaults.__all__, redirect_module=defaults) del deprecation_getattr, defaults diff --git a/brainpy/math/event.py b/brainpy/math/event.py index 43d89c1b2..4550a69ee 100644 --- a/brainpy/math/event.py +++ b/brainpy/math/event.py @@ -1,4 +1,6 @@ +from brainpy._src.dependency_check import import_taichi_else_None -from brainpy._src.math.event import ( - csrmv as csrmv, -) +if import_taichi_else_None() is not None: + from brainpy._src.math.event import ( + csrmv as csrmv, + ) diff --git a/brainpy/math/jitconn.py b/brainpy/math/jitconn.py index 90a028b7e..df91f40f8 100644 --- a/brainpy/math/jitconn.py +++ b/brainpy/math/jitconn.py @@ -1,10 +1,12 @@ -from brainpy._src.math.jitconn import ( - event_mv_prob_homo as event_mv_prob_homo, - event_mv_prob_uniform as event_mv_prob_uniform, - event_mv_prob_normal as event_mv_prob_normal, - - mv_prob_homo as mv_prob_homo, - mv_prob_uniform as mv_prob_uniform, - mv_prob_normal as mv_prob_normal, -) - +from brainpy._src.dependency_check import import_taichi_else_None +if import_taichi_else_None() is not None: + from brainpy._src.math.jitconn import ( + event_mv_prob_homo as event_mv_prob_homo, + event_mv_prob_uniform as event_mv_prob_uniform, + event_mv_prob_normal as event_mv_prob_normal, + + mv_prob_homo as mv_prob_homo, + mv_prob_uniform as mv_prob_uniform, + mv_prob_normal as mv_prob_normal, + ) + diff --git a/brainpy/math/op_register.py b/brainpy/math/op_register.py index a48268ef4..fcee2e9a3 100644 --- a/brainpy/math/op_register.py +++ b/brainpy/math/op_register.py @@ -1,14 +1,15 @@ -# -*- coding: utf-8 -*- - - -from brainpy._src.math.op_register import ( - CustomOpByNumba, - compile_cpu_signature_with_numba, - clean_caches, - check_kernels_count, -) - -from brainpy._src.math.op_register.base import XLACustomOp -from brainpy._src.math.op_register.ad_support import defjvp - - +# -*- coding: utf-8 -*- +from brainpy._src.dependency_check import import_taichi_else_None, import_numba_else_None + +if import_taichi_else_None() is not None and import_numba_else_None() is not None: + from brainpy._src.math.op_register import ( + CustomOpByNumba, + compile_cpu_signature_with_numba, + clean_caches, + check_kernels_count, + ) + + from brainpy._src.math.op_register.base import XLACustomOp + from brainpy._src.math.op_register.ad_support import defjvp + + diff --git a/brainpy/math/sparse.py b/brainpy/math/sparse.py index fbe0acbf2..de2264d26 100644 --- a/brainpy/math/sparse.py +++ b/brainpy/math/sparse.py @@ -1,10 +1,14 @@ -from brainpy._src.math.sparse import ( - csrmv, +from brainpy._src.dependency_check import import_taichi_else_None +from brainpy._src.math.sparse import ( seg_matmul, - - csr_to_dense as csr_to_dense, - csr_to_coo as csr_to_coo, - coo_to_csr as coo_to_csr, ) +if import_taichi_else_None() is not None: + from brainpy._src.math.sparse import ( + csrmv, + + csr_to_dense as csr_to_dense, + csr_to_coo as csr_to_coo, + coo_to_csr as coo_to_csr, + ) diff --git a/brainpy/math/tifunc.py b/brainpy/math/tifunc.py index 63f3cbe45..8b58cb03d 100644 --- a/brainpy/math/tifunc.py +++ b/brainpy/math/tifunc.py @@ -1,26 +1,27 @@ -# -*- coding: utf-8 -*- - -from brainpy._src.math.tifunc import ( - taichi_lcg_rand, - - # warp reduction primitives - warp_reduce_sum, - - # random number generator - lfsr88_key, - lfsr88_next_key, - lfsr88_normal, - lfsr88_randn, - lfsr88_random_integers, - lfsr88_randint, - lfsr88_uniform, - lfsr88_rand, - lfsr113_key, - lfsr113_next_key, - lfsr113_normal, - lfsr113_randn, - lfsr113_random_integers, - lfsr113_randint, - lfsr113_uniform, - lfsr113_rand -) +# -*- coding: utf-8 -*- +from brainpy._src.dependency_check import import_taichi_else_None +if import_taichi_else_None() is not None: + from brainpy._src.math.tifunc import ( + taichi_lcg_rand, + + # warp reduction primitives + warp_reduce_sum, + + # random number generator + lfsr88_key, + lfsr88_next_key, + lfsr88_normal, + lfsr88_randn, + lfsr88_random_integers, + lfsr88_randint, + lfsr88_uniform, + lfsr88_rand, + lfsr113_key, + lfsr113_next_key, + lfsr113_normal, + lfsr113_randn, + lfsr113_random_integers, + lfsr113_randint, + lfsr113_uniform, + lfsr113_rand + ) From 014408882e755727dba5027f2dc37d1095fe8d80 Mon Sep 17 00:00:00 2001 From: Chaoming Wang Date: Fri, 23 Feb 2024 09:18:48 +0800 Subject: [PATCH 05/16] fix --- brainpy/_src/dependency_check.py | 9 ++- brainpy/_src/dnn/linear.py | 4 +- brainpy/_src/math/defaults.py | 4 +- brainpy/_src/math/environment.py | 4 +- brainpy/_src/math/event/__init__.py | 5 +- brainpy/_src/math/event/_csr_matvec.py | 2 +- brainpy/_src/math/jitconn/__init__.py | 6 +- brainpy/_src/math/jitconn/_event_matvec.py | 2 +- brainpy/_src/math/jitconn/_matvec.py | 2 +- brainpy/_src/math/op_register/__init__.py | 23 ++---- brainpy/_src/math/sparse/__init__.py | 11 +-- brainpy/_src/math/sparse/_bsr_mm.py | 4 +- brainpy/_src/math/sparse/_csr_mv.py | 2 +- brainpy/_src/math/sparse/_utils.py | 3 +- brainpy/_src/math/tifunc.py | 4 +- brainpy/math/__init__.py | 85 ++++++++++------------ brainpy/math/event.py | 9 +-- brainpy/math/jitconn.py | 18 ++--- brainpy/math/op_register.py | 21 +++--- brainpy/math/sparse.py | 14 ++-- brainpy/math/tifunc.py | 47 ++++++------ 21 files changed, 123 insertions(+), 156 deletions(-) diff --git a/brainpy/_src/dependency_check.py b/brainpy/_src/dependency_check.py index 715e78c9b..66408c460 100644 --- a/brainpy/_src/dependency_check.py +++ b/brainpy/_src/dependency_check.py @@ -1,10 +1,10 @@ import os import sys + from jax.lib import xla_client __all__ = [ 'import_taichi', - 'import_taichi_else_None', 'import_numba', 'import_numba_else_None', 'import_brainpylib_cpu_ops', @@ -25,7 +25,7 @@ os.environ["TI_LOG_LEVEL"] = "error" -def import_taichi(): +def import_taichi(error_if_not_found=True): global taichi if taichi is None: with open(os.devnull, 'w') as devnull: @@ -34,10 +34,13 @@ def import_taichi(): try: import taichi as taichi # noqa except ModuleNotFoundError: - raise ModuleNotFoundError(taichi_install_info) + if error_if_not_found: + raise ModuleNotFoundError(taichi_install_info) finally: sys.stdout = old_stdout + if taichi is None: + return None if taichi.__version__ != _minimal_taichi_version: raise RuntimeError(taichi_install_info) return taichi diff --git a/brainpy/_src/dnn/linear.py b/brainpy/_src/dnn/linear.py index c23b6e21f..b85308614 100644 --- a/brainpy/_src/dnn/linear.py +++ b/brainpy/_src/dnn/linear.py @@ -13,14 +13,14 @@ from brainpy._src.context import share from brainpy._src.dnn.base import Layer from brainpy._src.mixin import SupportOnline, SupportOffline, SupportSTDP -from brainpy._src.dependency_check import import_taichi_else_None, import_numba_else_None, import_taichi, import_numba +from brainpy._src.dependency_check import import_numba_else_None, import_taichi, import_numba from brainpy.check import is_initializer from brainpy.connect import csr2csc from brainpy.errors import MathError from brainpy.initialize import XavierNormal, ZeroInit, Initializer, parameter from brainpy.types import ArrayType, Sharding -ti = import_taichi_else_None() +ti = import_taichi(error_if_not_found=False) numba = import_numba_else_None() __all__ = [ diff --git a/brainpy/_src/math/defaults.py b/brainpy/_src/math/defaults.py index beca46e79..6ebe9dc26 100644 --- a/brainpy/_src/math/defaults.py +++ b/brainpy/_src/math/defaults.py @@ -1,13 +1,13 @@ import jax.numpy as jnp from jax import config -from brainpy._src.dependency_check import import_taichi_else_None +from brainpy._src.dependency_check import import_taichi from .modes import NonBatchingMode from .scales import IdScaling __all__ = ['mode', 'membrane_scaling', 'dt', 'bool_', 'int_', 'ti_int', 'float_', 'ti_float', 'complex_'] -ti = import_taichi_else_None() +ti = import_taichi(error_if_not_found=False) # Default computation mode. mode = NonBatchingMode() diff --git a/brainpy/_src/math/environment.py b/brainpy/_src/math/environment.py index 70bc7a0e6..668f837c0 100644 --- a/brainpy/_src/math/environment.py +++ b/brainpy/_src/math/environment.py @@ -16,9 +16,9 @@ from . import modes from . import scales from . import defaults -from brainpy._src.dependency_check import import_taichi_else_None +from brainpy._src.dependency_check import import_taichi -ti = import_taichi_else_None() +ti = import_taichi(error_if_not_found=False) __all__ = [ # context manage for environment setting diff --git a/brainpy/_src/math/event/__init__.py b/brainpy/_src/math/event/__init__.py index a790c05e7..bdd3102a3 100644 --- a/brainpy/_src/math/event/__init__.py +++ b/brainpy/_src/math/event/__init__.py @@ -1,5 +1,2 @@ -from brainpy._src.dependency_check import import_taichi_else_None - -if import_taichi_else_None() is not None: - from ._csr_matvec import * +from ._csr_matvec import * diff --git a/brainpy/_src/math/event/_csr_matvec.py b/brainpy/_src/math/event/_csr_matvec.py index f4f23fa93..1571ea922 100644 --- a/brainpy/_src/math/event/_csr_matvec.py +++ b/brainpy/_src/math/event/_csr_matvec.py @@ -28,7 +28,7 @@ 'csrmv' ] -ti = import_taichi() +ti = import_taichi(error_if_not_found=False) def csrmv( diff --git a/brainpy/_src/math/jitconn/__init__.py b/brainpy/_src/math/jitconn/__init__.py index ea087c467..6f7cddf6a 100644 --- a/brainpy/_src/math/jitconn/__init__.py +++ b/brainpy/_src/math/jitconn/__init__.py @@ -1,4 +1,2 @@ -from brainpy._src.dependency_check import import_taichi_else_None -if import_taichi_else_None() is not None: - from ._matvec import * - from ._event_matvec import * \ No newline at end of file +from ._matvec import * +from ._event_matvec import * diff --git a/brainpy/_src/math/jitconn/_event_matvec.py b/brainpy/_src/math/jitconn/_event_matvec.py index 33ee9f1b5..f389c3773 100644 --- a/brainpy/_src/math/jitconn/_event_matvec.py +++ b/brainpy/_src/math/jitconn/_event_matvec.py @@ -22,7 +22,7 @@ from brainpy._src.math.op_register import XLACustomOp from brainpy.errors import PackageMissingError -ti = import_taichi() +ti = import_taichi(error_if_not_found=False) __all__ = [ 'event_mv_prob_homo', diff --git a/brainpy/_src/math/jitconn/_matvec.py b/brainpy/_src/math/jitconn/_matvec.py index 84abb9805..4b8fe004a 100644 --- a/brainpy/_src/math/jitconn/_matvec.py +++ b/brainpy/_src/math/jitconn/_matvec.py @@ -14,7 +14,7 @@ from brainpy._src.math.op_register import XLACustomOp from brainpy.errors import PackageMissingError -ti = import_taichi() +ti = import_taichi(error_if_not_found=False) __all__ = [ 'mv_prob_homo', diff --git a/brainpy/_src/math/op_register/__init__.py b/brainpy/_src/math/op_register/__init__.py index 93666197e..ed687eea5 100644 --- a/brainpy/_src/math/op_register/__init__.py +++ b/brainpy/_src/math/op_register/__init__.py @@ -1,15 +1,8 @@ -from brainpy._src.dependency_check import import_numba_else_None, import_taichi_else_None - -numba = import_numba_else_None() -taichi = import_taichi_else_None() - -if numba is not None: - from .numba_approach import (CustomOpByNumba, - register_op_with_numba, - compile_cpu_signature_with_numba) - from .base import XLACustomOp - from .utils import register_general_batching -if taichi is not None: - from .taichi_aot_based import clean_caches, check_kernels_count - from .base import XLACustomOp - from .utils import register_general_batching +from .numba_approach import (CustomOpByNumba, + register_op_with_numba, + compile_cpu_signature_with_numba) +from .base import XLACustomOp +from .utils import register_general_batching +from .taichi_aot_based import clean_caches, check_kernels_count +from .base import XLACustomOp +from .utils import register_general_batching diff --git a/brainpy/_src/math/sparse/__init__.py b/brainpy/_src/math/sparse/__init__.py index 8a522ccb7..d53533247 100644 --- a/brainpy/_src/math/sparse/__init__.py +++ b/brainpy/_src/math/sparse/__init__.py @@ -1,13 +1,8 @@ -from brainpy._src.dependency_check import import_taichi_else_None, import_numba_else_None - # from ._coo_mv import * # from ._bsr_mv import * -if import_taichi_else_None() is not None: - from ._csr_mv import * - from ._utils import * -if import_numba_else_None() is not None: - from ._bsr_mm import * - +from ._csr_mv import * +from ._utils import * +from ._bsr_mm import * from ._jax_prim import * diff --git a/brainpy/_src/math/sparse/_bsr_mm.py b/brainpy/_src/math/sparse/_bsr_mm.py index 453ab387d..6f9d5378c 100644 --- a/brainpy/_src/math/sparse/_bsr_mm.py +++ b/brainpy/_src/math/sparse/_bsr_mm.py @@ -1,7 +1,7 @@ # -*- coding: utf-8 -*- from functools import partial -from typing import Union, Tuple +from typing import Tuple import jax.lax import numba @@ -11,8 +11,8 @@ from jax.interpreters import ad, xla from jax.lib import xla_client -from brainpy._src.math.interoperability import as_jax from brainpy._src.dependency_check import import_brainpylib_gpu_ops +from brainpy._src.math.interoperability import as_jax from brainpy._src.math.op_register import (compile_cpu_signature_with_numba, register_general_batching) from brainpy.errors import GPUOperatorNotFound diff --git a/brainpy/_src/math/sparse/_csr_mv.py b/brainpy/_src/math/sparse/_csr_mv.py index 73eb48dcb..5fdb83443 100644 --- a/brainpy/_src/math/sparse/_csr_mv.py +++ b/brainpy/_src/math/sparse/_csr_mv.py @@ -17,7 +17,7 @@ from brainpy._src.math.sparse._utils import csr_to_coo from brainpy.errors import PackageMissingError -ti = import_taichi() +ti = import_taichi(error_if_not_found=False) __all__ = [ 'csrmv', diff --git a/brainpy/_src/math/sparse/_utils.py b/brainpy/_src/math/sparse/_utils.py index a1dc9190e..f5b74e5eb 100644 --- a/brainpy/_src/math/sparse/_utils.py +++ b/brainpy/_src/math/sparse/_utils.py @@ -3,9 +3,8 @@ import warnings from typing import Tuple -import jax import numpy as np -from jax import core, numpy as jnp, dtypes +from jax import core, numpy as jnp from jax.interpreters import mlir, ad from jaxlib import gpu_sparse diff --git a/brainpy/_src/math/tifunc.py b/brainpy/_src/math/tifunc.py index d5c1c0399..c54f4d6f7 100644 --- a/brainpy/_src/math/tifunc.py +++ b/brainpy/_src/math/tifunc.py @@ -1,7 +1,7 @@ -from brainpy._src.dependency_check import import_taichi_else_None +from brainpy._src.dependency_check import import_taichi from . import defaults -ti = import_taichi_else_None() +ti = import_taichi(error_if_not_found=False) if ti is not None: diff --git a/brainpy/math/__init__.py b/brainpy/math/__init__.py index 7126bbaac..feaa10093 100644 --- a/brainpy/math/__init__.py +++ b/brainpy/math/__init__.py @@ -51,55 +51,48 @@ spike_with_mg_grad as spike_with_mg_grad, ) -from brainpy._src.dependency_check import import_taichi_else_None from brainpy._src.math import defaults from brainpy._src.deprecations import deprecation_getattr -if import_taichi_else_None() is not None: - __deprecations = { - "sparse_matmul": ("brainpy.math.sparse_matmul is deprecated. Use brainpy.math.sparse.seg_matmul instead.", - sparse.seg_matmul), - 'csr_matvec': ("brainpy.math.csr_matvec is deprecated. Use brainpy.math.sparse.csrmv instead.", - sparse.csrmv), - 'event_matvec_prob_conn_homo_weight': ("brainpy.math.event_matvec_prob_conn_homo_weight is deprecated. " - "Use brainpy.math.jitconn.event_mv_prob_homo instead.", - jitconn.event_mv_prob_homo), - 'event_matvec_prob_conn_uniform_weight': ("brainpy.math.event_matvec_prob_conn_uniform_weight is deprecated. " - "Use brainpy.math.jitconn.event_mv_prob_uniform instead.", - jitconn.event_mv_prob_uniform), - 'event_matvec_prob_conn_normal_weight': ("brainpy.math.event_matvec_prob_conn_normal_weight is deprecated. " - "Use brainpy.math.jitconn.event_mv_prob_normal instead.", - jitconn.event_mv_prob_normal), - 'matvec_prob_conn_homo_weight': ("brainpy.math.matvec_prob_conn_homo_weight is deprecated. " - "Use brainpy.math.jitconn.mv_prob_homo instead.", - jitconn.mv_prob_homo), - 'matvec_prob_conn_uniform_weight': ("brainpy.math.matvec_prob_conn_uniform_weight is deprecated. " - "Use brainpy.math.jitconn.mv_prob_uniform instead.", - jitconn.mv_prob_uniform), - 'matvec_prob_conn_normal_weight': ("brainpy.math.matvec_prob_conn_normal_weight is deprecated. " - "Use brainpy.math.jitconn.mv_prob_normal instead.", - jitconn.mv_prob_normal), - 'cusparse_csr_matvec': ("brainpy.math.cusparse_csr_matvec is deprecated. " - "Use brainpy.math.sparse.csrmv instead.", - sparse.csrmv), - 'coo_to_csr': ("brainpy.math.coo_to_csr is deprecated. " - "Use brainpy.math.sparse.coo_to_csr instead.", - sparse.coo_to_csr), - 'csr_to_coo': ("brainpy.math.csr_to_coo is deprecated. " - "Use brainpy.math.sparse.csr_to_coo instead.", - sparse.csr_to_coo), - 'csr_to_dense': ("brainpy.math.csr_to_dense is deprecated. " - "Use brainpy.math.sparse.csr_to_dense instead.", - sparse.csr_to_dense), - 'event_csr_matvec': ("brainpy.math.event_csr_matvec is deprecated. " - "Use brainpy.math.event.csr_to_dense instead.", - event.csrmv), - } -else: - __deprecations = { - "sparse_matmul": ("brainpy.math.sparse_matmul is deprecated. Use brainpy.math.sparse.seg_matmul instead.", - sparse.seg_matmul), - } +__deprecations = { + "sparse_matmul": ("brainpy.math.sparse_matmul is deprecated. Use brainpy.math.sparse.seg_matmul instead.", + sparse.seg_matmul), + 'csr_matvec': ("brainpy.math.csr_matvec is deprecated. Use brainpy.math.sparse.csrmv instead.", + sparse.csrmv), + 'event_matvec_prob_conn_homo_weight': ("brainpy.math.event_matvec_prob_conn_homo_weight is deprecated. " + "Use brainpy.math.jitconn.event_mv_prob_homo instead.", + jitconn.event_mv_prob_homo), + 'event_matvec_prob_conn_uniform_weight': ("brainpy.math.event_matvec_prob_conn_uniform_weight is deprecated. " + "Use brainpy.math.jitconn.event_mv_prob_uniform instead.", + jitconn.event_mv_prob_uniform), + 'event_matvec_prob_conn_normal_weight': ("brainpy.math.event_matvec_prob_conn_normal_weight is deprecated. " + "Use brainpy.math.jitconn.event_mv_prob_normal instead.", + jitconn.event_mv_prob_normal), + 'matvec_prob_conn_homo_weight': ("brainpy.math.matvec_prob_conn_homo_weight is deprecated. " + "Use brainpy.math.jitconn.mv_prob_homo instead.", + jitconn.mv_prob_homo), + 'matvec_prob_conn_uniform_weight': ("brainpy.math.matvec_prob_conn_uniform_weight is deprecated. " + "Use brainpy.math.jitconn.mv_prob_uniform instead.", + jitconn.mv_prob_uniform), + 'matvec_prob_conn_normal_weight': ("brainpy.math.matvec_prob_conn_normal_weight is deprecated. " + "Use brainpy.math.jitconn.mv_prob_normal instead.", + jitconn.mv_prob_normal), + 'cusparse_csr_matvec': ("brainpy.math.cusparse_csr_matvec is deprecated. " + "Use brainpy.math.sparse.csrmv instead.", + sparse.csrmv), + 'coo_to_csr': ("brainpy.math.coo_to_csr is deprecated. " + "Use brainpy.math.sparse.coo_to_csr instead.", + sparse.coo_to_csr), + 'csr_to_coo': ("brainpy.math.csr_to_coo is deprecated. " + "Use brainpy.math.sparse.csr_to_coo instead.", + sparse.csr_to_coo), + 'csr_to_dense': ("brainpy.math.csr_to_dense is deprecated. " + "Use brainpy.math.sparse.csr_to_dense instead.", + sparse.csr_to_dense), + 'event_csr_matvec': ("brainpy.math.event_csr_matvec is deprecated. " + "Use brainpy.math.event.csr_to_dense instead.", + event.csrmv), +} __getattr__ = deprecation_getattr(__name__, __deprecations, redirects=defaults.__all__, redirect_module=defaults) del deprecation_getattr, defaults diff --git a/brainpy/math/event.py b/brainpy/math/event.py index 4550a69ee..02e98b8f3 100644 --- a/brainpy/math/event.py +++ b/brainpy/math/event.py @@ -1,6 +1,3 @@ -from brainpy._src.dependency_check import import_taichi_else_None - -if import_taichi_else_None() is not None: - from brainpy._src.math.event import ( - csrmv as csrmv, - ) +from brainpy._src.math.event import ( + csrmv as csrmv, +) diff --git a/brainpy/math/jitconn.py b/brainpy/math/jitconn.py index df91f40f8..a87d27d58 100644 --- a/brainpy/math/jitconn.py +++ b/brainpy/math/jitconn.py @@ -1,12 +1,10 @@ -from brainpy._src.dependency_check import import_taichi_else_None -if import_taichi_else_None() is not None: - from brainpy._src.math.jitconn import ( - event_mv_prob_homo as event_mv_prob_homo, - event_mv_prob_uniform as event_mv_prob_uniform, - event_mv_prob_normal as event_mv_prob_normal, +from brainpy._src.math.jitconn import ( + event_mv_prob_homo as event_mv_prob_homo, + event_mv_prob_uniform as event_mv_prob_uniform, + event_mv_prob_normal as event_mv_prob_normal, - mv_prob_homo as mv_prob_homo, - mv_prob_uniform as mv_prob_uniform, - mv_prob_normal as mv_prob_normal, - ) + mv_prob_homo as mv_prob_homo, + mv_prob_uniform as mv_prob_uniform, + mv_prob_normal as mv_prob_normal, +) diff --git a/brainpy/math/op_register.py b/brainpy/math/op_register.py index fcee2e9a3..c0fcb67ae 100644 --- a/brainpy/math/op_register.py +++ b/brainpy/math/op_register.py @@ -1,15 +1,12 @@ # -*- coding: utf-8 -*- -from brainpy._src.dependency_check import import_taichi_else_None, import_numba_else_None - -if import_taichi_else_None() is not None and import_numba_else_None() is not None: - from brainpy._src.math.op_register import ( - CustomOpByNumba, - compile_cpu_signature_with_numba, - clean_caches, - check_kernels_count, - ) - - from brainpy._src.math.op_register.base import XLACustomOp - from brainpy._src.math.op_register.ad_support import defjvp +from brainpy._src.math.op_register import ( + CustomOpByNumba, + compile_cpu_signature_with_numba, + clean_caches, + check_kernels_count, +) + +from brainpy._src.math.op_register.base import XLACustomOp +from brainpy._src.math.op_register.ad_support import defjvp diff --git a/brainpy/math/sparse.py b/brainpy/math/sparse.py index de2264d26..aa86679ec 100644 --- a/brainpy/math/sparse.py +++ b/brainpy/math/sparse.py @@ -1,14 +1,12 @@ -from brainpy._src.dependency_check import import_taichi_else_None from brainpy._src.math.sparse import ( seg_matmul, ) -if import_taichi_else_None() is not None: - from brainpy._src.math.sparse import ( - csrmv, +from brainpy._src.math.sparse import ( + csrmv, - csr_to_dense as csr_to_dense, - csr_to_coo as csr_to_coo, - coo_to_csr as coo_to_csr, - ) + csr_to_dense as csr_to_dense, + csr_to_coo as csr_to_coo, + coo_to_csr as coo_to_csr, +) diff --git a/brainpy/math/tifunc.py b/brainpy/math/tifunc.py index 8b58cb03d..e345c6835 100644 --- a/brainpy/math/tifunc.py +++ b/brainpy/math/tifunc.py @@ -1,27 +1,26 @@ # -*- coding: utf-8 -*- -from brainpy._src.dependency_check import import_taichi_else_None -if import_taichi_else_None() is not None: - from brainpy._src.math.tifunc import ( - taichi_lcg_rand, - # warp reduction primitives - warp_reduce_sum, +from brainpy._src.math.tifunc import ( + taichi_lcg_rand, - # random number generator - lfsr88_key, - lfsr88_next_key, - lfsr88_normal, - lfsr88_randn, - lfsr88_random_integers, - lfsr88_randint, - lfsr88_uniform, - lfsr88_rand, - lfsr113_key, - lfsr113_next_key, - lfsr113_normal, - lfsr113_randn, - lfsr113_random_integers, - lfsr113_randint, - lfsr113_uniform, - lfsr113_rand - ) + # warp reduction primitives + warp_reduce_sum, + + # random number generator + lfsr88_key, + lfsr88_next_key, + lfsr88_normal, + lfsr88_randn, + lfsr88_random_integers, + lfsr88_randint, + lfsr88_uniform, + lfsr88_rand, + lfsr113_key, + lfsr113_next_key, + lfsr113_normal, + lfsr113_randn, + lfsr113_random_integers, + lfsr113_randint, + lfsr113_uniform, + lfsr113_rand +) From 4adf3ed1b82d0fb2b41d721dda511920f1715159 Mon Sep 17 00:00:00 2001 From: He Sichao <1310722434@qq.com> Date: Fri, 23 Feb 2024 13:55:13 +0800 Subject: [PATCH 06/16] Update --- brainpy/_src/connect/random_conn.py | 12 +- brainpy/_src/dependency_check.py | 88 +- brainpy/_src/dnn/linear.py | 26 +- brainpy/_src/dnn/tests/test_linear.py | 439 ++--- brainpy/_src/dnn/tests/test_mode.py | 1605 +++++++++-------- .../_src/dyn/projections/tests/test_STDP.py | 10 +- .../_src/dyn/projections/tests/test_aligns.py | 883 ++++----- .../synapses/tests/test_abstract_synapses.py | 256 +-- .../tests/test_biological_synapses.py | 211 +-- brainpy/_src/math/event/_csr_matvec.py | 4 +- .../_src/math/event/tests/test_event_csrmv.py | 2 +- brainpy/_src/math/jitconn/_event_matvec.py | 8 +- brainpy/_src/math/jitconn/_matvec.py | 376 ++-- .../math/jitconn/tests/test_event_matvec.py | 2 +- .../_src/math/jitconn/tests/test_matvec.py | 2 +- brainpy/_src/math/op_register/base.py | 10 +- .../op_register/numba_approach/__init__.py | 361 ++-- .../numba_approach/cpu_translation.py | 280 ++- brainpy/_src/math/op_register/numba_based.py | 9 +- .../math/op_register/tests/test_ad_support.py | 7 +- .../op_register/tests/test_numba_based.py | 7 +- .../op_register/tests/test_taichi_based.py | 7 +- .../tests/test_taichi_clean_cache.py | 110 +- brainpy/_src/math/sparse/_bsr_mm.py | 98 +- brainpy/_src/math/sparse/_csr_mv.py | 4 +- brainpy/_src/math/sparse/tests/test_csrmv.py | 2 +- brainpy/_src/math/tests/test_tifunc.py | 246 +-- brainpy/_src/math/tifunc.py | 53 +- brainpy/_src/tests/test_dyn_runner.py | 267 ++- brainpy/math/__init__.py | 4 + brainpy/math/tifunc.py | 1 - 31 files changed, 2739 insertions(+), 2651 deletions(-) diff --git a/brainpy/_src/connect/random_conn.py b/brainpy/_src/connect/random_conn.py index a132135cc..9438c3306 100644 --- a/brainpy/_src/connect/random_conn.py +++ b/brainpy/_src/connect/random_conn.py @@ -9,10 +9,10 @@ from brainpy.errors import ConnectorError from brainpy.tools import numba_seed, numba_jit, numba_range, format_seed from brainpy._src.tools.package import SUPPORT_NUMBA -from brainpy._src.dependency_check import import_numba_else_None +from brainpy._src.dependency_check import import_numba from .base import * -numba = import_numba_else_None() +numba = import_numba(error_if_not_found=False) __all__ = [ 'FixedProb', @@ -1350,13 +1350,13 @@ def build_coo(self, isOptimized=True): else: if numba is None: if n_dim == 1: - f = self._connect_1d_jit + f = self._connect_1d elif n_dim == 2: - f = self._connect_2d_jit + f = self._connect_2d elif n_dim == 3: - f = self._connect_3d_jit + f = self._connect_3d elif n_dim == 4: - f = self._connect_4d_jit + f = self._connect_4d else: raise NotImplementedError('Does not support the network dimension bigger than 4.') else: diff --git a/brainpy/_src/dependency_check.py b/brainpy/_src/dependency_check.py index 66408c460..183e99d98 100644 --- a/brainpy/_src/dependency_check.py +++ b/brainpy/_src/dependency_check.py @@ -1,3 +1,4 @@ +import functools import os import sys @@ -5,8 +6,13 @@ __all__ = [ 'import_taichi', + 'raise_taichi_not_found', + 'check_taichi_func', + 'check_taichi_class', 'import_numba', - 'import_numba_else_None', + 'raise_numba_not_found', + 'check_numba_func', + 'check_numba_class', 'import_brainpylib_cpu_ops', 'import_brainpylib_gpu_ops', ] @@ -22,6 +28,9 @@ taichi_install_info = (f'We need taichi=={_minimal_taichi_version}. ' f'Currently you can install taichi=={_minimal_taichi_version} through:\n\n' '> pip install taichi==1.7.0') +numba_install_info = ('We need numba. Please install numba by pip . \n' + '> pip install numba' + ) os.environ["TI_LOG_LEVEL"] = "error" @@ -35,7 +44,7 @@ def import_taichi(error_if_not_found=True): import taichi as taichi # noqa except ModuleNotFoundError: if error_if_not_found: - raise ModuleNotFoundError(taichi_install_info) + raise raise_taichi_not_found() finally: sys.stdout = old_stdout @@ -46,44 +55,65 @@ def import_taichi(error_if_not_found=True): return taichi -def import_taichi_else_None(): - global taichi - if taichi is None: - with open(os.devnull, 'w') as devnull: - old_stdout = sys.stdout - sys.stdout = devnull - try: - import taichi as taichi # noqa - except: - return None - finally: - sys.stdout = old_stdout +def raise_taichi_not_found(): + raise ModuleNotFoundError(taichi_install_info) - if taichi.__version__ != _minimal_taichi_version: - raise RuntimeError(taichi_install_info) - return taichi + +def check_taichi_func(func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + if taichi is None: + raise_taichi_not_found() + return func(*args, **kwargs) + + return wrapper -def import_numba(): +def check_taichi_class(cls): + class Wrapper(cls): + def __init__(self, *args, **kwargs): + if taichi is None: + raise_taichi_not_found() + super().__init__(*args, **kwargs) + + return Wrapper + + +def import_numba(error_if_not_found=True): global numba if numba is None: try: import numba as numba except ModuleNotFoundError: - raise ModuleNotFoundError('We need numba. Please install numba by pip . \n' - '> pip install numba' - ) + if error_if_not_found: + raise_numba_not_found() + else: + return None return numba -def import_numba_else_None(): - global numba - if numba is None: - try: - import numba as numba - except: - return None - return numba +def raise_numba_not_found(): + raise ModuleNotFoundError(numba_install_info) + + +def check_numba_func(func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + if numba is None: + raise_numba_not_found() + return func(*args, **kwargs) + + return wrapper + + +def check_numba_class(cls): + class Wrapper(cls): + def __init__(self, *args, **kwargs): + if numba is None: + raise_numba_not_found() + super().__init__(*args, **kwargs) + + return Wrapper def is_brainpylib_gpu_installed(): diff --git a/brainpy/_src/dnn/linear.py b/brainpy/_src/dnn/linear.py index b85308614..7a92bc8b2 100644 --- a/brainpy/_src/dnn/linear.py +++ b/brainpy/_src/dnn/linear.py @@ -13,7 +13,7 @@ from brainpy._src.context import share from brainpy._src.dnn.base import Layer from brainpy._src.mixin import SupportOnline, SupportOffline, SupportSTDP -from brainpy._src.dependency_check import import_numba_else_None, import_taichi, import_numba +from brainpy._src.dependency_check import import_numba, import_taichi, check_numba_func, check_taichi_func from brainpy.check import is_initializer from brainpy.connect import csr2csc from brainpy.errors import MathError @@ -21,7 +21,7 @@ from brainpy.types import ArrayType, Sharding ti = import_taichi(error_if_not_found=False) -numba = import_numba_else_None() +numba = import_numba(error_if_not_found=False) __all__ = [ 'Dense', 'Linear', @@ -269,7 +269,7 @@ def _cpu_dense_on_pre(weight: ti.types.ndarray(ndim=2), elif new_value > w_max0: out_w[i, j] = w_max0 else: - out_w[i, j] = new_value + out_w[i, j] = new_value @ti.kernel @@ -294,12 +294,13 @@ def _gpu_dense_on_pre(weight: ti.types.ndarray(ndim=1), out_w[i, j] = w_max0 else: out_w[i, j] = new_value - + dense_on_pre_prim = bm.XLACustomOp(cpu_kernel=_cpu_dense_on_pre, gpu_kernel=_gpu_dense_on_pre) +@check_taichi_func def dense_on_pre(weight, spike, trace, w_min, w_max): if w_min is None: w_min = -np.inf @@ -308,8 +309,6 @@ def dense_on_pre(weight, spike, trace, w_min, w_max): trace = jnp.atleast_1d(trace) w_min = jnp.atleast_1d(w_min) w_max = jnp.atleast_1d(w_max) - if dense_on_pre_prim is None: - import_taichi() return dense_on_pre_prim(weight, spike, trace, w_min, w_max, outs=[jax.ShapeDtypeStruct(weight.shape, weight.dtype)])[0] @@ -346,6 +345,7 @@ def _cpu_dense_on_post(weight: ti.types.ndarray(ndim=2), else: out_w[j, i] = new_value + @ti.kernel def _gpu_dense_on_post(weight: ti.types.ndarray(ndim=2), spike: ti.types.ndarray(ndim=1), @@ -369,10 +369,12 @@ def _gpu_dense_on_post(weight: ti.types.ndarray(ndim=2), else: out_w[j, i] = new_value + dense_on_post_prim = bm.XLACustomOp(cpu_kernel=_cpu_dense_on_post, gpu_kernel=_gpu_dense_on_post) +@check_taichi_func def dense_on_post(weight, spike, trace, w_min, w_max): if w_min is None: w_min = -np.inf @@ -638,7 +640,7 @@ def stdp_update( raise ValueError(f'The shape of weight should be the same as the shape of sparse weight {self.weight.shape}.') if not isinstance(self.weight, bm.Variable): self.tracing_variable('weight', self.weight, self.weight.shape) - if on_pre is not None: # update on presynaptic spike + if on_pre is not None: # update on presynaptic spike spike = on_pre['spike'] trace = on_pre['trace'] self.weight.value = csr_on_pre_update(self.weight.value, self.indices, self.indptr, spike, trace, w_min, w_max) @@ -703,6 +705,7 @@ def _batch_csrmv(self, x): return bm.sparse.csrmv(self.weight, self.indices, self.indptr, x, shape=(self.conn.pre_num, self.conn.post_num), transpose=self.transpose) + class EventCSRLinear(_CSRLayer): r"""Synaptic matrix multiplication with event CSR sparse computation. @@ -752,6 +755,7 @@ def _batch_csrmv(self, x): shape=(self.conn.pre_num, self.conn.post_num), transpose=self.transpose) + # @numba.njit(nogil=True, fastmath=True, parallel=False) # def _cpu_csr_on_pre_update(w, indices, indptr, spike, trace, w_min, w_max, out_w): # out_w[:] = w @@ -785,6 +789,8 @@ def _cpu_csr_on_pre_update(w: ti.types.ndarray(ndim=1), for k in range(indptr[i], indptr[i + 1]): j = indices[k] out_w[k] = min(max(out_w[k] + trace[j], w_min0), w_max0) + + @ti.kernel def _gpu_csr_on_pre_update(w: ti.types.ndarray(ndim=1), indices: ti.types.ndarray(ndim=1), @@ -810,6 +816,7 @@ def _gpu_csr_on_pre_update(w: ti.types.ndarray(ndim=1), gpu_kernel=_gpu_csr_on_pre_update) +@check_taichi_func def csr_on_pre_update(w, indices, indptr, spike, trace, w_min=None, w_max=None): if w_min is None: w_min = -np.inf @@ -823,6 +830,7 @@ def csr_on_pre_update(w, indices, indptr, spike, trace, w_min=None, w_max=None): return csr_on_pre_update_prim(w, indices, indptr, spike, trace, w_min, w_max, outs=[jax.ShapeDtypeStruct(w.shape, w.dtype)])[0] + csc_on_pre_update_prim = None if numba is not None: @numba.njit(nogil=True, fastmath=True, parallel=False) @@ -841,18 +849,16 @@ def _cpu_csc_on_pre_update(w, post_ids, indptr, w_ids, spike, trace, w_min, w_ma csc_on_pre_update_prim = bm.XLACustomOp(_cpu_csc_on_pre_update) +@check_numba_func def csc_on_post_update(w, post_ids, indptr, w_ids, spike, trace, w_min=None, w_max=None): if w_min is None: w_min = -np.inf if w_max is None: w_max = np.inf - if csc_on_pre_update_prim is None: - import_numba() return csc_on_pre_update_prim(w, post_ids, indptr, w_ids, spike, trace, w_min, w_max, outs=[jax.ShapeDtypeStruct(w.shape, w.dtype)])[0] - class CSCLinear(Layer): r"""Synaptic matrix multiplication with CSC sparse computation. diff --git a/brainpy/_src/dnn/tests/test_linear.py b/brainpy/_src/dnn/tests/test_linear.py index 7fc89526c..41844cc8f 100644 --- a/brainpy/_src/dnn/tests/test_linear.py +++ b/brainpy/_src/dnn/tests/test_linear.py @@ -1,217 +1,222 @@ -import brainpy as bp -from absl.testing import parameterized -from absl.testing import absltest -import brainpy.math as bm - - -class TestLinear(parameterized.TestCase): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - bm.random.seed() - - @parameterized.product( - size=[(10,), - (20, 10), - (5, 8, 10)], - num_out=[20, 10, 5] - ) - def test_Dense1(self, size, num_out): - bm.random.seed() - f = bp.dnn.Linear(10, num_out) - x = bm.random.random(size) - y = f(x) - self.assertTrue(y.shape == size[:-1] + (num_out,)) - bm.clear_buffer_memory() - - @parameterized.product( - size=[(10,), - (20, 10), - (5, 8, 10)], - ) - def test_Identity(self, size): - bm.random.seed() - f = bp.dnn.Identity() - x = bm.random.random(size) - y = f(x) - self.assertTrue(y.shape == size) - bm.clear_buffer_memory() - - def test_AllToAll1(self): - bm.random.seed() - with bm.environment(mode=bm.BatchingMode()): - f = bp.dnn.AllToAll(10, 20, weight=.1, include_self=True) - x = bm.random.random((8, 10)) - y = f(x) - expected = bm.sum(x, axis=1, keepdims=True) * 0.1 - self.assertTrue(bm.allclose(y, expected)) - - with bm.environment(mode=bm.NonBatchingMode()): - f = bp.dnn.AllToAll(10, 20, weight=.1, include_self=True) - x = bm.random.random((10,)) - y = f(x) - expected = bm.sum(x, keepdims=True) * 0.1 - self.assertTrue(bm.allclose(y, expected)) - bm.clear_buffer_memory() - - def test_OneToOne(self): - bm.random.seed() - with bm.environment(mode=bm.BatchingMode()): - f = bp.dnn.OneToOne(10, weight=.1) - x = bm.random.random((8, 10)) - y = f(x) - expected = x * 0.1 - self.assertTrue(bm.allclose(y, expected)) - - with bm.environment(mode=bm.NonBatchingMode()): - f = bp.dnn.OneToOne(10, weight=.1) - x = bm.random.random((10,)) - y = f(x) - expected = x * 0.1 - self.assertTrue(bm.allclose(y, expected)) - bm.clear_buffer_memory() - - @parameterized.product( - conn=[ - # bp.conn.FixedProb(0.1, pre=100, post=100), - bp.conn.GridFour(pre=100, post=100), - bp.conn.GaussianProb(0.1, pre=100, post=100), - ] - ) - def test_MaskedLinear(self, conn): - bm.random.seed() - bm.random.DEFAULT.seed(123) - f = bp.dnn.MaskedLinear(conn, weight=bp.init.XavierNormal(seed=123)) - x = bm.random.random((16, 100)) - y = f(x) - self.assertTrue(y.shape == (16, 100)) - bm.clear_buffer_memory() - - @parameterized.product( - conn=[ - bp.conn.FixedProb(0.1, pre=100, post=100), - bp.conn.GridFour(pre=100, post=100), - bp.conn.GaussianProb(0.1, pre=100, post=100), - ] - ) - def test_CSRLinear(self, conn): - bm.random.seed() - f = bp.dnn.CSRLinear(conn, weight=bp.init.Normal()) - x = bm.random.random((16, 100)) - y = f(x) - self.assertTrue(y.shape == (16, 100)) - - x = bm.random.random((100,)) - y = f(x) - self.assertTrue(y.shape == (100,)) - bm.clear_buffer_memory() - - - @parameterized.product( - conn=[ - bp.conn.FixedProb(0.1, pre=100, post=100), - bp.conn.GridFour(pre=100, post=100), - bp.conn.GaussianProb(0.1, pre=100, post=100), - ] - ) - def test_EventCSRLinear(self,conn): - bm.random.seed() - f=bp.layers.EventCSRLinear(conn,weight=bp.init.Normal()) - x = bm.random.random((16, 100)) - y = f(x) - self.assertTrue(y.shape == (16, 100)) - x = bm.random.random((100,)) - y = f(x) - self.assertTrue(y.shape == (100,)) - bm.clear_buffer_memory() - - - @parameterized.product( - prob=[0.01, 0.05, 0.5], - weight=[0.01, 0.01], - shape=[(), (10,), (10, 20), (10, 20, 25)] - ) - def test_JitFPHomoLinear(self, prob, weight, shape): - bm.random.seed() - f = bp.dnn.JitFPHomoLinear(100, 200, prob, weight, seed=123) - x = bm.random.random(shape + (100,)) - y = f(x) - self.assertTrue(y.shape == shape + (200,)) - bm.clear_buffer_memory() - - @parameterized.product( - prob=[0.01, 0.05, 0.5], - w_low=[-0.01, -0.01], - w_high=[0.01, 0.01], - shape=[(), (10,), (10, 20), (10, 20, 25)] - ) - def test_JitFPUniformLinear(self, prob, w_low, w_high, shape): - bm.random.seed() - f = bp.dnn.JitFPUniformLinear(100, 200, prob, w_low, w_high, seed=123) - x = bm.random.random(shape + (100,)) - y = f(x) - self.assertTrue(y.shape == shape + (200,)) - bm.clear_buffer_memory() - - @parameterized.product( - prob=[0.01, 0.1, 0.5], - w_mu=[-0.01, -0.01], - w_sigma=[0.01, 0.01], - shape=[(), (10,), (10, 20), (10, 20, 25)] - ) - def test_JitFPNormalLinear(self, prob, w_mu, w_sigma, shape): - bm.random.seed() - f = bp.dnn.JitFPNormalLinear(100, 200, prob, w_mu, w_sigma, seed=123) - x = bm.random.random(shape + (100,)) - y = f(x) - self.assertTrue(y.shape == shape + (200,)) - bm.clear_buffer_memory() - - @parameterized.product( - prob=[0.01, 0.05, 0.5], - weight=[0.01, 0.01], - shape=[(), (10,), (10, 20), (10, 20, 25)] - ) - def test_EventJitFPHomoLinear(self, prob, weight, shape): - bm.random.seed() - f = bp.dnn.EventJitFPHomoLinear(100, 200, prob, weight, seed=123) - y = f(bm.random.random(shape + (100,)) < 0.1) - self.assertTrue(y.shape == shape + (200,)) - - y2 = f(bm.as_jax(bm.random.random(shape + (100,)) < 0.1, dtype=float)) - self.assertTrue(y2.shape == shape + (200,)) - bm.clear_buffer_memory() - - @parameterized.product( - prob=[0.01, 0.05, 0.5], - w_low=[-0.01, -0.01], - w_high=[0.01, 0.01], - shape=[(), (10,), (10, 20), (10, 20, 25)] - ) - def test_EventJitFPUniformLinear(self, prob, w_low, w_high, shape): - bm.random.seed() - f = bp.dnn.EventJitFPUniformLinear(100, 200, prob, w_low, w_high, seed=123) - y = f(bm.random.random(shape + (100,)) < 0.1) - self.assertTrue(y.shape == shape + (200,)) - - y2 = f(bm.as_jax(bm.random.random(shape + (100,)) < 0.1, dtype=float)) - self.assertTrue(y2.shape == shape + (200,)) - bm.clear_buffer_memory() - - @parameterized.product( - prob=[0.01, 0.1, 0.5], - w_mu=[-0.01, -0.01], - w_sigma=[0.01, 0.01], - shape=[(), (10,), (10, 20), (10, 20, 25)] - ) - def test_EventJitFPNormalLinear(self, prob, w_mu, w_sigma, shape): - bm.random.seed() - f = bp.dnn.EventJitFPNormalLinear(100, 200, prob, w_mu, w_sigma, seed=123) - y = f(bm.random.random(shape + (100,)) < 0.1) - self.assertTrue(y.shape == shape + (200,)) - - y2 = f(bm.as_jax(bm.random.random(shape + (100,)) < 0.1, dtype=float)) - self.assertTrue(y2.shape == shape + (200,)) - bm.clear_buffer_memory() - -if __name__ == '__main__': - absltest.main() +import pytest +import brainpy as bp +from absl.testing import parameterized +from absl.testing import absltest +import brainpy.math as bm + +from brainpy._src.dependency_check import import_taichi + +if import_taichi(error_if_not_found=False) is None: + pytest.skip('no taichi', allow_module_level=True) + + +class TestLinear(parameterized.TestCase): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + bm.random.seed() + + @parameterized.product( + size=[(10,), + (20, 10), + (5, 8, 10)], + num_out=[20, 10, 5] + ) + def test_Dense1(self, size, num_out): + bm.random.seed() + f = bp.dnn.Linear(10, num_out) + x = bm.random.random(size) + y = f(x) + self.assertTrue(y.shape == size[:-1] + (num_out,)) + bm.clear_buffer_memory() + + @parameterized.product( + size=[(10,), + (20, 10), + (5, 8, 10)], + ) + def test_Identity(self, size): + bm.random.seed() + f = bp.dnn.Identity() + x = bm.random.random(size) + y = f(x) + self.assertTrue(y.shape == size) + bm.clear_buffer_memory() + + def test_AllToAll1(self): + bm.random.seed() + with bm.environment(mode=bm.BatchingMode()): + f = bp.dnn.AllToAll(10, 20, weight=.1, include_self=True) + x = bm.random.random((8, 10)) + y = f(x) + expected = bm.sum(x, axis=1, keepdims=True) * 0.1 + self.assertTrue(bm.allclose(y, expected)) + + with bm.environment(mode=bm.NonBatchingMode()): + f = bp.dnn.AllToAll(10, 20, weight=.1, include_self=True) + x = bm.random.random((10,)) + y = f(x) + expected = bm.sum(x, keepdims=True) * 0.1 + self.assertTrue(bm.allclose(y, expected)) + bm.clear_buffer_memory() + + def test_OneToOne(self): + bm.random.seed() + with bm.environment(mode=bm.BatchingMode()): + f = bp.dnn.OneToOne(10, weight=.1) + x = bm.random.random((8, 10)) + y = f(x) + expected = x * 0.1 + self.assertTrue(bm.allclose(y, expected)) + + with bm.environment(mode=bm.NonBatchingMode()): + f = bp.dnn.OneToOne(10, weight=.1) + x = bm.random.random((10,)) + y = f(x) + expected = x * 0.1 + self.assertTrue(bm.allclose(y, expected)) + bm.clear_buffer_memory() + + @parameterized.product( + conn=[ + # bp.conn.FixedProb(0.1, pre=100, post=100), + bp.conn.GridFour(pre=100, post=100), + bp.conn.GaussianProb(0.1, pre=100, post=100), + ] + ) + def test_MaskedLinear(self, conn): + bm.random.seed() + bm.random.DEFAULT.seed(123) + f = bp.dnn.MaskedLinear(conn, weight=bp.init.XavierNormal(seed=123)) + x = bm.random.random((16, 100)) + y = f(x) + self.assertTrue(y.shape == (16, 100)) + bm.clear_buffer_memory() + + @parameterized.product( + conn=[ + bp.conn.FixedProb(0.1, pre=100, post=100), + bp.conn.GridFour(pre=100, post=100), + bp.conn.GaussianProb(0.1, pre=100, post=100), + ] + ) + def test_CSRLinear(self, conn): + bm.random.seed() + f = bp.dnn.CSRLinear(conn, weight=bp.init.Normal()) + x = bm.random.random((16, 100)) + y = f(x) + self.assertTrue(y.shape == (16, 100)) + + x = bm.random.random((100,)) + y = f(x) + self.assertTrue(y.shape == (100,)) + bm.clear_buffer_memory() + + @parameterized.product( + conn=[ + bp.conn.FixedProb(0.1, pre=100, post=100), + bp.conn.GridFour(pre=100, post=100), + bp.conn.GaussianProb(0.1, pre=100, post=100), + ] + ) + def test_EventCSRLinear(self, conn): + bm.random.seed() + f = bp.layers.EventCSRLinear(conn, weight=bp.init.Normal()) + x = bm.random.random((16, 100)) + y = f(x) + self.assertTrue(y.shape == (16, 100)) + x = bm.random.random((100,)) + y = f(x) + self.assertTrue(y.shape == (100,)) + bm.clear_buffer_memory() + + @parameterized.product( + prob=[0.01, 0.05, 0.5], + weight=[0.01, 0.01], + shape=[(), (10,), (10, 20), (10, 20, 25)] + ) + def test_JitFPHomoLinear(self, prob, weight, shape): + bm.random.seed() + f = bp.dnn.JitFPHomoLinear(100, 200, prob, weight, seed=123) + x = bm.random.random(shape + (100,)) + y = f(x) + self.assertTrue(y.shape == shape + (200,)) + bm.clear_buffer_memory() + + @parameterized.product( + prob=[0.01, 0.05, 0.5], + w_low=[-0.01, -0.01], + w_high=[0.01, 0.01], + shape=[(), (10,), (10, 20), (10, 20, 25)] + ) + def test_JitFPUniformLinear(self, prob, w_low, w_high, shape): + bm.random.seed() + f = bp.dnn.JitFPUniformLinear(100, 200, prob, w_low, w_high, seed=123) + x = bm.random.random(shape + (100,)) + y = f(x) + self.assertTrue(y.shape == shape + (200,)) + bm.clear_buffer_memory() + + @parameterized.product( + prob=[0.01, 0.1, 0.5], + w_mu=[-0.01, -0.01], + w_sigma=[0.01, 0.01], + shape=[(), (10,), (10, 20), (10, 20, 25)] + ) + def test_JitFPNormalLinear(self, prob, w_mu, w_sigma, shape): + bm.random.seed() + f = bp.dnn.JitFPNormalLinear(100, 200, prob, w_mu, w_sigma, seed=123) + x = bm.random.random(shape + (100,)) + y = f(x) + self.assertTrue(y.shape == shape + (200,)) + bm.clear_buffer_memory() + + @parameterized.product( + prob=[0.01, 0.05, 0.5], + weight=[0.01, 0.01], + shape=[(), (10,), (10, 20), (10, 20, 25)] + ) + def test_EventJitFPHomoLinear(self, prob, weight, shape): + bm.random.seed() + f = bp.dnn.EventJitFPHomoLinear(100, 200, prob, weight, seed=123) + y = f(bm.random.random(shape + (100,)) < 0.1) + self.assertTrue(y.shape == shape + (200,)) + + y2 = f(bm.as_jax(bm.random.random(shape + (100,)) < 0.1, dtype=float)) + self.assertTrue(y2.shape == shape + (200,)) + bm.clear_buffer_memory() + + @parameterized.product( + prob=[0.01, 0.05, 0.5], + w_low=[-0.01, -0.01], + w_high=[0.01, 0.01], + shape=[(), (10,), (10, 20), (10, 20, 25)] + ) + def test_EventJitFPUniformLinear(self, prob, w_low, w_high, shape): + bm.random.seed() + f = bp.dnn.EventJitFPUniformLinear(100, 200, prob, w_low, w_high, seed=123) + y = f(bm.random.random(shape + (100,)) < 0.1) + self.assertTrue(y.shape == shape + (200,)) + + y2 = f(bm.as_jax(bm.random.random(shape + (100,)) < 0.1, dtype=float)) + self.assertTrue(y2.shape == shape + (200,)) + bm.clear_buffer_memory() + + @parameterized.product( + prob=[0.01, 0.1, 0.5], + w_mu=[-0.01, -0.01], + w_sigma=[0.01, 0.01], + shape=[(), (10,), (10, 20), (10, 20, 25)] + ) + def test_EventJitFPNormalLinear(self, prob, w_mu, w_sigma, shape): + bm.random.seed() + f = bp.dnn.EventJitFPNormalLinear(100, 200, prob, w_mu, w_sigma, seed=123) + y = f(bm.random.random(shape + (100,)) < 0.1) + self.assertTrue(y.shape == shape + (200,)) + + y2 = f(bm.as_jax(bm.random.random(shape + (100,)) < 0.1, dtype=float)) + self.assertTrue(y2.shape == shape + (200,)) + bm.clear_buffer_memory() + + +if __name__ == '__main__': + absltest.main() diff --git a/brainpy/_src/dnn/tests/test_mode.py b/brainpy/_src/dnn/tests/test_mode.py index 0d754976f..0c0107573 100644 --- a/brainpy/_src/dnn/tests/test_mode.py +++ b/brainpy/_src/dnn/tests/test_mode.py @@ -1,800 +1,805 @@ -import brainpy.math as bm -from absl.testing import parameterized -from absl.testing import absltest -import brainpy as bp - - -class Test_Conv(parameterized.TestCase): - @parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(10), - bm.BatchingMode(), - bm.BatchingMode(10), ] - ) - def test_Conv1d(self, mode): - bm.random.seed() - input = bm.random.randn(10, 50, 3) - layer = bp.dnn.Conv1d(in_channels=3, - out_channels=4, - kernel_size=5, - mode=mode) - output = layer(input) - bm.clear_buffer_memory() - - def test_Conv1d_NonBatching(self): - bm.random.seed() - input = bm.random.randn(50, 3) - layer = bp.dnn.Conv1d(in_channels=3, - out_channels=4, - kernel_size=5, - mode=bm.NonBatchingMode()) - output = layer(input) - bm.clear_buffer_memory() - - @parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(10), - bm.BatchingMode(), - bm.BatchingMode(10), ] - ) - def test_Conv2d(self, mode): - bm.random.seed() - input = bm.random.randn(10, 50, 50, 3) - layer = bp.dnn.Conv2d(in_channels=3, - out_channels=4, - kernel_size=(5, 5), - mode=mode) - output = layer(input) - bm.clear_buffer_memory() - - def test_Conv2_NonBatching(self): - bm.random.seed() - input = bm.random.randn(10, 10, 3) - layer = bp.dnn.Conv2d(in_channels=3, - out_channels=4, - kernel_size=(5, 5), - mode=bm.NonBatchingMode()) - output = layer(input) - bm.clear_buffer_memory() - bm.clear_buffer_memory() - - @parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(10), - bm.BatchingMode(), - bm.BatchingMode(10), ] - ) - def test_Conv3d(self, mode): - bm.random.seed() - input = bm.random.randn(10, 50, 50, 50, 3) - layer = bp.dnn.Conv3d(in_channels=3, - out_channels=4, - kernel_size=(5, 5, 5), - mode=mode) - output = layer(input) - bm.clear_buffer_memory() - - def test_Conv3_NonBatching(self): - bm.random.seed() - input = bm.random.randn(10, 10, 10, 3) - layer = bp.dnn.Conv3d(in_channels=3, - out_channels=4, - kernel_size=(5, 5, 5), - mode=bm.NonBatchingMode()) - output = layer(input) - bm.clear_buffer_memory() - - @parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(10), - bm.BatchingMode(), - bm.BatchingMode(10), ] - ) - def test_ConvTranspose1d(self, mode): - bm.random.seed() - input = bm.random.randn(10, 50, 3) - layer = bp.dnn.ConvTranspose1d(in_channels=3, - out_channels=4, - kernel_size=5, - mode=mode - ) - output = layer(input) - bm.clear_buffer_memory() - - def test_ConvTranspose1d_NonBatching(self): - bm.random.seed() - input = bm.random.randn(10, 3) - layer = bp.dnn.ConvTranspose1d(in_channels=3, - out_channels=4, - kernel_size=5, - mode=bm.NonBatchingMode()) - output = layer(input) - bm.clear_buffer_memory() - - @parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(10), - bm.BatchingMode(), - bm.BatchingMode(10), ] - ) - def test_ConvTranspose2d(self, mode): - bm.random.seed() - input = bm.random.randn(10, 50, 50, 3) - layer = bp.dnn.ConvTranspose2d(in_channels=3, - out_channels=4, - kernel_size=(5, 5), - mode=mode - ) - output = layer(input) - bm.clear_buffer_memory() - - def test_ConvTranspose2d_NonBatching(self): - bm.random.seed() - input = bm.random.randn(10, 10, 3) - layer = bp.dnn.ConvTranspose2d(in_channels=3, - out_channels=4, - kernel_size=(5, 5), - mode=bm.NonBatchingMode()) - output = layer(input) - bm.clear_buffer_memory() - - @parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(10), - bm.BatchingMode(), - bm.BatchingMode(10), ] - ) - def test_ConvTranspose3d(self, mode): - bm.random.seed() - input = bm.random.randn(10, 50, 50, 50, 3) - layer = bp.dnn.ConvTranspose3d(in_channels=3, - out_channels=4, - kernel_size=(5, 5, 5), - mode=mode - ) - output = layer(input) - bm.clear_buffer_memory() - - def test_ConvTranspose3d_NonBatching(self): - bm.random.seed() - input = bm.random.randn(10, 10, 10, 3) - layer = bp.dnn.ConvTranspose3d(in_channels=3, - out_channels=4, - kernel_size=(5, 5, 5), - mode=bm.NonBatchingMode()) - output = layer(input) - bm.clear_buffer_memory() - - -class TestPool(parameterized.TestCase): - - @parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(10), - bm.BatchingMode(), - bm.BatchingMode(10), - bm.NonBatchingMode()] - ) - def test_MaxPool(self, mode): - bm.random.seed() - input = bm.random.randn(10, 5, 5, 4) - layer = bp.dnn.MaxPool(kernel_size=(3, 3), - channel_axis=-1, - mode=mode) - if mode in [bm.NonBatchingMode()]: - for i in input: - output = layer(i) - else: - output = layer(input) - - @parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(10), - bm.BatchingMode(), - bm.BatchingMode(10), - bm.NonBatchingMode()] - ) - def test_MinPool(self, mode): - bm.random.seed() - input = bm.random.randn(10, 5, 5, 4) - layer = bp.dnn.MaxPool(kernel_size=(3, 3), - channel_axis=-1, - mode=mode) - if mode in [bm.NonBatchingMode()]: - for i in input: - output = layer(i) - else: - output = layer(input) - - @parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(10), - bm.BatchingMode(), - bm.BatchingMode(10), - bm.NonBatchingMode()] - ) - def test_AvgPool(self, mode): - bm.random.seed() - input = bm.random.randn(10, 5, 5, 4) - layer = bp.dnn.AvgPool(kernel_size=(3, 3), - channel_axis=-1, - mode=mode) - if mode in [bm.NonBatchingMode()]: - for i in input: - output = layer(i) - else: - output = layer(input) - - @parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(10), - bm.BatchingMode(), - bm.BatchingMode(10), - bm.NonBatchingMode()] - ) - def test_AvgPool1d(self, mode): - bm.random.seed() - input = bm.random.randn(10, 5, 4) - layer = bp.dnn.AvgPool1d(kernel_size=3, - channel_axis=-1, - mode=mode) - if mode in [bm.NonBatchingMode()]: - for i in input: - output = layer(i) - else: - output = layer(input) - - @parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(10), - bm.BatchingMode(), - bm.BatchingMode(10), - bm.NonBatchingMode()] - ) - def test_AvgPool2d(self, mode): - bm.random.seed() - input = bm.random.randn(10, 5, 5, 4) - layer = bp.dnn.AvgPool2d(kernel_size=(3, 3), - channel_axis=-1, - mode=mode) - if mode in [bm.NonBatchingMode()]: - for i in input: - output = layer(i) - else: - output = layer(input) - - @parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(10), - bm.BatchingMode(), - bm.BatchingMode(10), - bm.NonBatchingMode()] - ) - def test_AvgPool3d(self, mode): - bm.random.seed() - input = bm.random.randn(10, 5, 5, 5, 4) - layer = bp.dnn.AvgPool3d(kernel_size=(3, 3, 3), - channel_axis=-1, - mode=mode) - if mode in [bm.NonBatchingMode()]: - for i in input: - output = layer(i) - else: - output = layer(input) - - @parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(10), - bm.BatchingMode(), - bm.BatchingMode(10), - bm.NonBatchingMode()] - ) - def test_MaxPool1d(self, mode): - bm.random.seed() - input = bm.random.randn(10, 5, 4) - layer = bp.dnn.MaxPool1d(kernel_size=3, - channel_axis=-1, - mode=mode) - if mode in [bm.NonBatchingMode()]: - for i in input: - output = layer(i) - else: - output = layer(input) - - @parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(10), - bm.BatchingMode(), - bm.BatchingMode(10), - bm.NonBatchingMode()] - ) - def test_MaxPool2d(self, mode): - bm.random.seed() - input = bm.random.randn(10, 5, 5, 4) - layer = bp.dnn.MaxPool2d(kernel_size=(3, 3), - channel_axis=-1, - mode=mode) - if mode in [bm.NonBatchingMode()]: - for i in input: - output = layer(i) - else: - output = layer(input) - - @parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(10), - bm.BatchingMode(), - bm.BatchingMode(10), - bm.NonBatchingMode()] - ) - def test_MaxPool3d(self, mode): - bm.random.seed() - input = bm.random.randn(10, 5, 5, 5, 4) - layer = bp.dnn.MaxPool3d(kernel_size=(3, 3, 3), - channel_axis=-1, - mode=mode) - if mode in [bm.NonBatchingMode()]: - for i in input: - output = layer(i) - else: - output = layer(input) - - @parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(10), - bm.BatchingMode(), - bm.BatchingMode(10), - bm.NonBatchingMode()] - ) - def test_AdaptiveAvgPool1d(self, mode): - bm.random.seed() - input = bm.random.randn(10, 5, 4) - layer = bp.dnn.AdaptiveAvgPool1d(target_shape=3, - channel_axis=-1, - mode=mode) - if mode in [bm.NonBatchingMode()]: - for i in input: - output = layer(i) - else: - output = layer(input) - - @parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(10), - bm.BatchingMode(), - bm.BatchingMode(10), - bm.NonBatchingMode()] - ) - def test_AdaptiveAvgPool2d(self, mode): - bm.random.seed() - input = bm.random.randn(10, 5, 5, 4) - layer = bp.dnn.AdaptiveAvgPool2d(target_shape=(3, 3), - channel_axis=-1, - mode=mode) - if mode in [bm.NonBatchingMode()]: - for i in input: - output = layer(i) - else: - output = layer(input) - - @parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(10), - bm.BatchingMode(), - bm.BatchingMode(10), - bm.NonBatchingMode()] - ) - def test_AdaptiveAvgPool3d(self, mode): - bm.random.seed() - input = bm.random.randn(10, 5, 5, 5, 4) - layer = bp.dnn.AdaptiveAvgPool3d(target_shape=(3, 3, 3), - channel_axis=-1, - mode=mode) - if mode in [bm.NonBatchingMode()]: - for i in input: - output = layer(i) - else: - output = layer(input) - - @parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(10), - bm.BatchingMode(), - bm.BatchingMode(10), - bm.NonBatchingMode()] - ) - def test_AdaptiveMaxPool1d(self, mode): - bm.random.seed() - input = bm.random.randn(10, 5, 4) - layer = bp.dnn.AdaptiveMaxPool1d(target_shape=3, - channel_axis=-1, - mode=mode) - if mode in [bm.NonBatchingMode()]: - for i in input: - output = layer(i) - else: - output = layer(input) - - @parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(10), - bm.BatchingMode(), - bm.BatchingMode(10), - bm.NonBatchingMode()] - ) - def test_AdaptiveMaxPool2d(self, mode): - bm.random.seed() - input = bm.random.randn(10, 5, 5, 4) - layer = bp.dnn.AdaptiveMaxPool2d(target_shape=(3, 3), - channel_axis=-1, - mode=mode) - if mode in [bm.NonBatchingMode()]: - for i in input: - output = layer(i) - else: - output = layer(input) - - @parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(10), - bm.BatchingMode(), - bm.BatchingMode(10), - bm.NonBatchingMode()] - ) - def test_AdaptiveMaxPool3d(self, mode): - bm.random.seed() - input = bm.random.randn(10, 5, 5, 5, 4) - layer = bp.dnn.AdaptiveMaxPool3d(target_shape=(3, 3, 3), - channel_axis=-1, - mode=mode) - if mode in [bm.NonBatchingMode()]: - for i in input: - output = layer(i) - else: - output = layer(input) - - -class Test_Dropout(parameterized.TestCase): - @parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(10), - bm.BatchingMode(), - bm.BatchingMode(10), - bm.NonBatchingMode()] - ) - def test_Dropout(self, mode): - bp.share.save(fit=False) - bm.random.seed() - input = bm.random.randn(10, 5, 5, 5, 4) - layer = bp.dnn.Dropout(prob=0.2, - mode=mode) - output = layer(input) - - -class Test_function(parameterized.TestCase): - @parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(10), - bm.BatchingMode(), - bm.BatchingMode(10), - bm.NonBatchingMode()] - ) - def test_Flatten(self, mode): - bm.random.seed() - layer = bp.dnn.Flatten(mode=mode) - input = bm.random.randn(10, 5, 5, 5, 4) - output = layer(input) - - -class Test_linear(parameterized.TestCase): - - @parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(10), - bm.BatchingMode(), - bm.BatchingMode(10), - bm.NonBatchingMode()] - ) - def test_linear(self, mode): - bm.random.seed() - input = bm.random.randn(10, 9, 8, 7) - layer = bp.dnn.Linear(num_in=7, - num_out=6, - mode=mode) - output = layer(input) - - @parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(10), - bm.BatchingMode(), - bm.BatchingMode(10), - bm.NonBatchingMode()] - ) - def test_AllToAll(self, mode): - bm.random.seed() - input = bm.random.randn(10, 10) - layer = bp.dnn.AllToAll(num_pre=10, - num_post=20, - weight=0.1, - mode=mode) - if mode in [bm.NonBatchingMode()]: - for i in input: - output = layer(i) - else: - output = layer(input) - - @parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(10), - bm.BatchingMode(), - bm.BatchingMode(10), - bm.NonBatchingMode()] - ) - def test_OneToOne(self, mode): - bm.random.seed() - input = bm.random.randn(10, 10) - layer = bp.dnn.OneToOne(num=10, - weight=0.1, - mode=mode) - output = layer(input) - - @parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(10), - bm.BatchingMode(), - bm.BatchingMode(10), - bm.NonBatchingMode()] - ) - def test_MaskedLinear(self, mode): - bm.random.seed() - input = bm.random.randn(100, 100) - layer = bp.dnn.MaskedLinear(conn=bp.conn.FixedProb(0.1, pre=100, post=100), - weight=0.1, - mode=mode) - output = layer(input) - - @parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(10), - bm.BatchingMode(), - bm.BatchingMode(10), - bm.NonBatchingMode()] - ) - def test_CSRLinear(self, mode): - bm.random.seed() - input = bm.random.randn(100, 100) - layer = bp.dnn.CSRLinear(conn=bp.conn.FixedProb(0.1, pre=100, post=100), - weight=0.1, - mode=mode) - output = layer(input) - - @parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(10), - bm.BatchingMode(), - bm.BatchingMode(10), - bm.NonBatchingMode()] - ) - def test_EventCSRLinear(self, mode): - bm.random.seed() - input = bm.random.randn(100, 100) - layer = bp.dnn.EventCSRLinear(conn=bp.conn.FixedProb(0.1, pre=100, post=100), - weight=0.1, - mode=mode) - output = layer(input) - - @parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(10), - bm.BatchingMode(), - bm.BatchingMode(10), - bm.NonBatchingMode()] - ) - def test_JitFPHomoLinear(self, mode): - bm.random.seed() - layer = bp.dnn.JitFPHomoLinear(num_in=100, - num_out=200, - prob=0.1, - weight=0.01, - seed=100, - mode=mode) - input = bm.random.randn(10, 100) - output = layer(input) - - @parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(10), - bm.BatchingMode(), - bm.BatchingMode(10), - bm.NonBatchingMode()] - ) - def test_JitFPUniformLinear(self, mode): - bm.random.seed() - layer = bp.dnn.JitFPUniformLinear(num_in=100, - num_out=200, - prob=0.1, - w_low=-0.01, - w_high=0.01, - seed=100, - mode=mode) - input = bm.random.randn(10, 100) - output = layer(input) - - @parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(10), - bm.BatchingMode(), - bm.BatchingMode(10), - bm.NonBatchingMode()] - ) - def test_JitFPNormalLinear(self, mode): - bm.random.seed() - layer = bp.dnn.JitFPNormalLinear(num_in=100, - num_out=200, - prob=0.1, - w_mu=-0.01, - w_sigma=0.01, - seed=100, - mode=mode) - input = bm.random.randn(10, 100) - output = layer(input) - - @parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(10), - bm.BatchingMode(), - bm.BatchingMode(10), - bm.NonBatchingMode()] - ) - def test_EventJitFPHomoLinear(self, mode): - bm.random.seed() - layer = bp.dnn.EventJitFPHomoLinear(num_in=100, - num_out=200, - prob=0.1, - weight=0.01, - seed=100, - mode=mode) - input = bm.random.randn(10, 100) - output = layer(input) - - @parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(10), - bm.BatchingMode(), - bm.BatchingMode(10), - bm.NonBatchingMode()] - ) - def test_EventJitFPNormalLinear(self, mode): - bm.random.seed() - layer = bp.dnn.EventJitFPNormalLinear(num_in=100, - num_out=200, - prob=0.1, - w_mu=-0.01, - w_sigma=0.01, - seed=100, - mode=mode) - input = bm.random.randn(10, 100) - output = layer(input) - - @parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(10), - bm.BatchingMode(), - bm.BatchingMode(10), - bm.NonBatchingMode()] - ) - def test_EventJitFPUniformLinear(self, mode): - bm.random.seed() - layer = bp.dnn.EventJitFPUniformLinear(num_in=100, - num_out=200, - prob=0.1, - w_low=-0.01, - w_high=0.01, - seed=100, - mode=mode) - input = bm.random.randn(10, 100) - output = layer(input) - - -class Test_Normalization(parameterized.TestCase): - - @parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(10), - bm.BatchingMode(), - bm.BatchingMode(10)], - fit=[True, False] - ) - def test_BatchNorm1d(self, fit, mode): - bm.random.seed() - bp.share.save(fit=fit) - layer = bp.dnn.BatchNorm1d(num_features=100, - mode=mode, - affine=False) - input = bm.random.randn(10, 5, 100) - output = layer(input) - - @parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(10), - bm.BatchingMode(), - bm.BatchingMode(10)], - fit=[True, False] - ) - def test_BatchNorm2d(self, fit, mode): - bm.random.seed() - bp.share.save(fit=fit) - layer = bp.dnn.BatchNorm2d(num_features=100, - mode=mode, - affine=False) - input = bm.random.randn(10, 5, 6, 100) - output = layer(input) - - @parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(10), - bm.BatchingMode(), - bm.BatchingMode(10)], - fit=[True, False] - ) - def test_BatchNorm3d(self, fit, mode): - bm.random.seed() - bp.share.save(fit=fit) - layer = bp.dnn.BatchNorm3d(num_features=100, - mode=mode, - affine=False) - input = bm.random.randn(10, 5, 6, 7, 100) - output = layer(input) - - @parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(10), - bm.BatchingMode(), - bm.BatchingMode(10), - bm.NonBatchingMode()], - ) - def test_LayerNorm(self, mode): - bm.random.seed() - layer = bp.dnn.LayerNorm(normalized_shape=3, - mode=mode, - elementwise_affine=False - ) - input = bm.random.randn(10, 5, 3) - outout = layer(input) - - @parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(10), - bm.BatchingMode(), - bm.BatchingMode(10), - bm.NonBatchingMode()], - ) - def test_GroupNorm(self, mode): - bm.random.seed() - layer = bp.dnn.GroupNorm(num_groups=2, - num_channels=6, - affine=False, - mode=mode - ) - input = bm.random.randn(20, 10, 10, 6) - output = layer(input) - - @parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(10), - bm.BatchingMode(), - bm.BatchingMode(10), - bm.NonBatchingMode()], - ) - def test_InstanceNorm(self, mode): - bm.random.seed() - layer = bp.dnn.InstanceNorm(num_channels=6, - affine=False, - mode=mode - ) - input = bm.random.randn(20, 10, 10, 6) - output = layer(input) - - -if __name__ == '__main__': - absltest.main() +import pytest +import brainpy.math as bm +from absl.testing import parameterized +from absl.testing import absltest +import brainpy as bp + +from brainpy._src.dependency_check import import_taichi +if import_taichi(error_if_not_found=False) is None: + pytest.skip('no taichi', allow_module_level=True) + + +class Test_Conv(parameterized.TestCase): + @parameterized.product( + mode=[bm.TrainingMode(), + bm.TrainingMode(10), + bm.BatchingMode(), + bm.BatchingMode(10), ] + ) + def test_Conv1d(self, mode): + bm.random.seed() + input = bm.random.randn(10, 50, 3) + layer = bp.dnn.Conv1d(in_channels=3, + out_channels=4, + kernel_size=5, + mode=mode) + output = layer(input) + bm.clear_buffer_memory() + + def test_Conv1d_NonBatching(self): + bm.random.seed() + input = bm.random.randn(50, 3) + layer = bp.dnn.Conv1d(in_channels=3, + out_channels=4, + kernel_size=5, + mode=bm.NonBatchingMode()) + output = layer(input) + bm.clear_buffer_memory() + + @parameterized.product( + mode=[bm.TrainingMode(), + bm.TrainingMode(10), + bm.BatchingMode(), + bm.BatchingMode(10), ] + ) + def test_Conv2d(self, mode): + bm.random.seed() + input = bm.random.randn(10, 50, 50, 3) + layer = bp.dnn.Conv2d(in_channels=3, + out_channels=4, + kernel_size=(5, 5), + mode=mode) + output = layer(input) + bm.clear_buffer_memory() + + def test_Conv2_NonBatching(self): + bm.random.seed() + input = bm.random.randn(10, 10, 3) + layer = bp.dnn.Conv2d(in_channels=3, + out_channels=4, + kernel_size=(5, 5), + mode=bm.NonBatchingMode()) + output = layer(input) + bm.clear_buffer_memory() + bm.clear_buffer_memory() + + @parameterized.product( + mode=[bm.TrainingMode(), + bm.TrainingMode(10), + bm.BatchingMode(), + bm.BatchingMode(10), ] + ) + def test_Conv3d(self, mode): + bm.random.seed() + input = bm.random.randn(10, 50, 50, 50, 3) + layer = bp.dnn.Conv3d(in_channels=3, + out_channels=4, + kernel_size=(5, 5, 5), + mode=mode) + output = layer(input) + bm.clear_buffer_memory() + + def test_Conv3_NonBatching(self): + bm.random.seed() + input = bm.random.randn(10, 10, 10, 3) + layer = bp.dnn.Conv3d(in_channels=3, + out_channels=4, + kernel_size=(5, 5, 5), + mode=bm.NonBatchingMode()) + output = layer(input) + bm.clear_buffer_memory() + + @parameterized.product( + mode=[bm.TrainingMode(), + bm.TrainingMode(10), + bm.BatchingMode(), + bm.BatchingMode(10), ] + ) + def test_ConvTranspose1d(self, mode): + bm.random.seed() + input = bm.random.randn(10, 50, 3) + layer = bp.dnn.ConvTranspose1d(in_channels=3, + out_channels=4, + kernel_size=5, + mode=mode + ) + output = layer(input) + bm.clear_buffer_memory() + + def test_ConvTranspose1d_NonBatching(self): + bm.random.seed() + input = bm.random.randn(10, 3) + layer = bp.dnn.ConvTranspose1d(in_channels=3, + out_channels=4, + kernel_size=5, + mode=bm.NonBatchingMode()) + output = layer(input) + bm.clear_buffer_memory() + + @parameterized.product( + mode=[bm.TrainingMode(), + bm.TrainingMode(10), + bm.BatchingMode(), + bm.BatchingMode(10), ] + ) + def test_ConvTranspose2d(self, mode): + bm.random.seed() + input = bm.random.randn(10, 50, 50, 3) + layer = bp.dnn.ConvTranspose2d(in_channels=3, + out_channels=4, + kernel_size=(5, 5), + mode=mode + ) + output = layer(input) + bm.clear_buffer_memory() + + def test_ConvTranspose2d_NonBatching(self): + bm.random.seed() + input = bm.random.randn(10, 10, 3) + layer = bp.dnn.ConvTranspose2d(in_channels=3, + out_channels=4, + kernel_size=(5, 5), + mode=bm.NonBatchingMode()) + output = layer(input) + bm.clear_buffer_memory() + + @parameterized.product( + mode=[bm.TrainingMode(), + bm.TrainingMode(10), + bm.BatchingMode(), + bm.BatchingMode(10), ] + ) + def test_ConvTranspose3d(self, mode): + bm.random.seed() + input = bm.random.randn(10, 50, 50, 50, 3) + layer = bp.dnn.ConvTranspose3d(in_channels=3, + out_channels=4, + kernel_size=(5, 5, 5), + mode=mode + ) + output = layer(input) + bm.clear_buffer_memory() + + def test_ConvTranspose3d_NonBatching(self): + bm.random.seed() + input = bm.random.randn(10, 10, 10, 3) + layer = bp.dnn.ConvTranspose3d(in_channels=3, + out_channels=4, + kernel_size=(5, 5, 5), + mode=bm.NonBatchingMode()) + output = layer(input) + bm.clear_buffer_memory() + + +class TestPool(parameterized.TestCase): + + @parameterized.product( + mode=[bm.TrainingMode(), + bm.TrainingMode(10), + bm.BatchingMode(), + bm.BatchingMode(10), + bm.NonBatchingMode()] + ) + def test_MaxPool(self, mode): + bm.random.seed() + input = bm.random.randn(10, 5, 5, 4) + layer = bp.dnn.MaxPool(kernel_size=(3, 3), + channel_axis=-1, + mode=mode) + if mode in [bm.NonBatchingMode()]: + for i in input: + output = layer(i) + else: + output = layer(input) + + @parameterized.product( + mode=[bm.TrainingMode(), + bm.TrainingMode(10), + bm.BatchingMode(), + bm.BatchingMode(10), + bm.NonBatchingMode()] + ) + def test_MinPool(self, mode): + bm.random.seed() + input = bm.random.randn(10, 5, 5, 4) + layer = bp.dnn.MaxPool(kernel_size=(3, 3), + channel_axis=-1, + mode=mode) + if mode in [bm.NonBatchingMode()]: + for i in input: + output = layer(i) + else: + output = layer(input) + + @parameterized.product( + mode=[bm.TrainingMode(), + bm.TrainingMode(10), + bm.BatchingMode(), + bm.BatchingMode(10), + bm.NonBatchingMode()] + ) + def test_AvgPool(self, mode): + bm.random.seed() + input = bm.random.randn(10, 5, 5, 4) + layer = bp.dnn.AvgPool(kernel_size=(3, 3), + channel_axis=-1, + mode=mode) + if mode in [bm.NonBatchingMode()]: + for i in input: + output = layer(i) + else: + output = layer(input) + + @parameterized.product( + mode=[bm.TrainingMode(), + bm.TrainingMode(10), + bm.BatchingMode(), + bm.BatchingMode(10), + bm.NonBatchingMode()] + ) + def test_AvgPool1d(self, mode): + bm.random.seed() + input = bm.random.randn(10, 5, 4) + layer = bp.dnn.AvgPool1d(kernel_size=3, + channel_axis=-1, + mode=mode) + if mode in [bm.NonBatchingMode()]: + for i in input: + output = layer(i) + else: + output = layer(input) + + @parameterized.product( + mode=[bm.TrainingMode(), + bm.TrainingMode(10), + bm.BatchingMode(), + bm.BatchingMode(10), + bm.NonBatchingMode()] + ) + def test_AvgPool2d(self, mode): + bm.random.seed() + input = bm.random.randn(10, 5, 5, 4) + layer = bp.dnn.AvgPool2d(kernel_size=(3, 3), + channel_axis=-1, + mode=mode) + if mode in [bm.NonBatchingMode()]: + for i in input: + output = layer(i) + else: + output = layer(input) + + @parameterized.product( + mode=[bm.TrainingMode(), + bm.TrainingMode(10), + bm.BatchingMode(), + bm.BatchingMode(10), + bm.NonBatchingMode()] + ) + def test_AvgPool3d(self, mode): + bm.random.seed() + input = bm.random.randn(10, 5, 5, 5, 4) + layer = bp.dnn.AvgPool3d(kernel_size=(3, 3, 3), + channel_axis=-1, + mode=mode) + if mode in [bm.NonBatchingMode()]: + for i in input: + output = layer(i) + else: + output = layer(input) + + @parameterized.product( + mode=[bm.TrainingMode(), + bm.TrainingMode(10), + bm.BatchingMode(), + bm.BatchingMode(10), + bm.NonBatchingMode()] + ) + def test_MaxPool1d(self, mode): + bm.random.seed() + input = bm.random.randn(10, 5, 4) + layer = bp.dnn.MaxPool1d(kernel_size=3, + channel_axis=-1, + mode=mode) + if mode in [bm.NonBatchingMode()]: + for i in input: + output = layer(i) + else: + output = layer(input) + + @parameterized.product( + mode=[bm.TrainingMode(), + bm.TrainingMode(10), + bm.BatchingMode(), + bm.BatchingMode(10), + bm.NonBatchingMode()] + ) + def test_MaxPool2d(self, mode): + bm.random.seed() + input = bm.random.randn(10, 5, 5, 4) + layer = bp.dnn.MaxPool2d(kernel_size=(3, 3), + channel_axis=-1, + mode=mode) + if mode in [bm.NonBatchingMode()]: + for i in input: + output = layer(i) + else: + output = layer(input) + + @parameterized.product( + mode=[bm.TrainingMode(), + bm.TrainingMode(10), + bm.BatchingMode(), + bm.BatchingMode(10), + bm.NonBatchingMode()] + ) + def test_MaxPool3d(self, mode): + bm.random.seed() + input = bm.random.randn(10, 5, 5, 5, 4) + layer = bp.dnn.MaxPool3d(kernel_size=(3, 3, 3), + channel_axis=-1, + mode=mode) + if mode in [bm.NonBatchingMode()]: + for i in input: + output = layer(i) + else: + output = layer(input) + + @parameterized.product( + mode=[bm.TrainingMode(), + bm.TrainingMode(10), + bm.BatchingMode(), + bm.BatchingMode(10), + bm.NonBatchingMode()] + ) + def test_AdaptiveAvgPool1d(self, mode): + bm.random.seed() + input = bm.random.randn(10, 5, 4) + layer = bp.dnn.AdaptiveAvgPool1d(target_shape=3, + channel_axis=-1, + mode=mode) + if mode in [bm.NonBatchingMode()]: + for i in input: + output = layer(i) + else: + output = layer(input) + + @parameterized.product( + mode=[bm.TrainingMode(), + bm.TrainingMode(10), + bm.BatchingMode(), + bm.BatchingMode(10), + bm.NonBatchingMode()] + ) + def test_AdaptiveAvgPool2d(self, mode): + bm.random.seed() + input = bm.random.randn(10, 5, 5, 4) + layer = bp.dnn.AdaptiveAvgPool2d(target_shape=(3, 3), + channel_axis=-1, + mode=mode) + if mode in [bm.NonBatchingMode()]: + for i in input: + output = layer(i) + else: + output = layer(input) + + @parameterized.product( + mode=[bm.TrainingMode(), + bm.TrainingMode(10), + bm.BatchingMode(), + bm.BatchingMode(10), + bm.NonBatchingMode()] + ) + def test_AdaptiveAvgPool3d(self, mode): + bm.random.seed() + input = bm.random.randn(10, 5, 5, 5, 4) + layer = bp.dnn.AdaptiveAvgPool3d(target_shape=(3, 3, 3), + channel_axis=-1, + mode=mode) + if mode in [bm.NonBatchingMode()]: + for i in input: + output = layer(i) + else: + output = layer(input) + + @parameterized.product( + mode=[bm.TrainingMode(), + bm.TrainingMode(10), + bm.BatchingMode(), + bm.BatchingMode(10), + bm.NonBatchingMode()] + ) + def test_AdaptiveMaxPool1d(self, mode): + bm.random.seed() + input = bm.random.randn(10, 5, 4) + layer = bp.dnn.AdaptiveMaxPool1d(target_shape=3, + channel_axis=-1, + mode=mode) + if mode in [bm.NonBatchingMode()]: + for i in input: + output = layer(i) + else: + output = layer(input) + + @parameterized.product( + mode=[bm.TrainingMode(), + bm.TrainingMode(10), + bm.BatchingMode(), + bm.BatchingMode(10), + bm.NonBatchingMode()] + ) + def test_AdaptiveMaxPool2d(self, mode): + bm.random.seed() + input = bm.random.randn(10, 5, 5, 4) + layer = bp.dnn.AdaptiveMaxPool2d(target_shape=(3, 3), + channel_axis=-1, + mode=mode) + if mode in [bm.NonBatchingMode()]: + for i in input: + output = layer(i) + else: + output = layer(input) + + @parameterized.product( + mode=[bm.TrainingMode(), + bm.TrainingMode(10), + bm.BatchingMode(), + bm.BatchingMode(10), + bm.NonBatchingMode()] + ) + def test_AdaptiveMaxPool3d(self, mode): + bm.random.seed() + input = bm.random.randn(10, 5, 5, 5, 4) + layer = bp.dnn.AdaptiveMaxPool3d(target_shape=(3, 3, 3), + channel_axis=-1, + mode=mode) + if mode in [bm.NonBatchingMode()]: + for i in input: + output = layer(i) + else: + output = layer(input) + + +class Test_Dropout(parameterized.TestCase): + @parameterized.product( + mode=[bm.TrainingMode(), + bm.TrainingMode(10), + bm.BatchingMode(), + bm.BatchingMode(10), + bm.NonBatchingMode()] + ) + def test_Dropout(self, mode): + bp.share.save(fit=False) + bm.random.seed() + input = bm.random.randn(10, 5, 5, 5, 4) + layer = bp.dnn.Dropout(prob=0.2, + mode=mode) + output = layer(input) + + +class Test_function(parameterized.TestCase): + @parameterized.product( + mode=[bm.TrainingMode(), + bm.TrainingMode(10), + bm.BatchingMode(), + bm.BatchingMode(10), + bm.NonBatchingMode()] + ) + def test_Flatten(self, mode): + bm.random.seed() + layer = bp.dnn.Flatten(mode=mode) + input = bm.random.randn(10, 5, 5, 5, 4) + output = layer(input) + + +class Test_linear(parameterized.TestCase): + + @parameterized.product( + mode=[bm.TrainingMode(), + bm.TrainingMode(10), + bm.BatchingMode(), + bm.BatchingMode(10), + bm.NonBatchingMode()] + ) + def test_linear(self, mode): + bm.random.seed() + input = bm.random.randn(10, 9, 8, 7) + layer = bp.dnn.Linear(num_in=7, + num_out=6, + mode=mode) + output = layer(input) + + @parameterized.product( + mode=[bm.TrainingMode(), + bm.TrainingMode(10), + bm.BatchingMode(), + bm.BatchingMode(10), + bm.NonBatchingMode()] + ) + def test_AllToAll(self, mode): + bm.random.seed() + input = bm.random.randn(10, 10) + layer = bp.dnn.AllToAll(num_pre=10, + num_post=20, + weight=0.1, + mode=mode) + if mode in [bm.NonBatchingMode()]: + for i in input: + output = layer(i) + else: + output = layer(input) + + @parameterized.product( + mode=[bm.TrainingMode(), + bm.TrainingMode(10), + bm.BatchingMode(), + bm.BatchingMode(10), + bm.NonBatchingMode()] + ) + def test_OneToOne(self, mode): + bm.random.seed() + input = bm.random.randn(10, 10) + layer = bp.dnn.OneToOne(num=10, + weight=0.1, + mode=mode) + output = layer(input) + + @parameterized.product( + mode=[bm.TrainingMode(), + bm.TrainingMode(10), + bm.BatchingMode(), + bm.BatchingMode(10), + bm.NonBatchingMode()] + ) + def test_MaskedLinear(self, mode): + bm.random.seed() + input = bm.random.randn(100, 100) + layer = bp.dnn.MaskedLinear(conn=bp.conn.FixedProb(0.1, pre=100, post=100), + weight=0.1, + mode=mode) + output = layer(input) + + @parameterized.product( + mode=[bm.TrainingMode(), + bm.TrainingMode(10), + bm.BatchingMode(), + bm.BatchingMode(10), + bm.NonBatchingMode()] + ) + def test_CSRLinear(self, mode): + bm.random.seed() + input = bm.random.randn(100, 100) + layer = bp.dnn.CSRLinear(conn=bp.conn.FixedProb(0.1, pre=100, post=100), + weight=0.1, + mode=mode) + output = layer(input) + + @parameterized.product( + mode=[bm.TrainingMode(), + bm.TrainingMode(10), + bm.BatchingMode(), + bm.BatchingMode(10), + bm.NonBatchingMode()] + ) + def test_EventCSRLinear(self, mode): + bm.random.seed() + input = bm.random.randn(100, 100) + layer = bp.dnn.EventCSRLinear(conn=bp.conn.FixedProb(0.1, pre=100, post=100), + weight=0.1, + mode=mode) + output = layer(input) + + @parameterized.product( + mode=[bm.TrainingMode(), + bm.TrainingMode(10), + bm.BatchingMode(), + bm.BatchingMode(10), + bm.NonBatchingMode()] + ) + def test_JitFPHomoLinear(self, mode): + bm.random.seed() + layer = bp.dnn.JitFPHomoLinear(num_in=100, + num_out=200, + prob=0.1, + weight=0.01, + seed=100, + mode=mode) + input = bm.random.randn(10, 100) + output = layer(input) + + @parameterized.product( + mode=[bm.TrainingMode(), + bm.TrainingMode(10), + bm.BatchingMode(), + bm.BatchingMode(10), + bm.NonBatchingMode()] + ) + def test_JitFPUniformLinear(self, mode): + bm.random.seed() + layer = bp.dnn.JitFPUniformLinear(num_in=100, + num_out=200, + prob=0.1, + w_low=-0.01, + w_high=0.01, + seed=100, + mode=mode) + input = bm.random.randn(10, 100) + output = layer(input) + + @parameterized.product( + mode=[bm.TrainingMode(), + bm.TrainingMode(10), + bm.BatchingMode(), + bm.BatchingMode(10), + bm.NonBatchingMode()] + ) + def test_JitFPNormalLinear(self, mode): + bm.random.seed() + layer = bp.dnn.JitFPNormalLinear(num_in=100, + num_out=200, + prob=0.1, + w_mu=-0.01, + w_sigma=0.01, + seed=100, + mode=mode) + input = bm.random.randn(10, 100) + output = layer(input) + + @parameterized.product( + mode=[bm.TrainingMode(), + bm.TrainingMode(10), + bm.BatchingMode(), + bm.BatchingMode(10), + bm.NonBatchingMode()] + ) + def test_EventJitFPHomoLinear(self, mode): + bm.random.seed() + layer = bp.dnn.EventJitFPHomoLinear(num_in=100, + num_out=200, + prob=0.1, + weight=0.01, + seed=100, + mode=mode) + input = bm.random.randn(10, 100) + output = layer(input) + + @parameterized.product( + mode=[bm.TrainingMode(), + bm.TrainingMode(10), + bm.BatchingMode(), + bm.BatchingMode(10), + bm.NonBatchingMode()] + ) + def test_EventJitFPNormalLinear(self, mode): + bm.random.seed() + layer = bp.dnn.EventJitFPNormalLinear(num_in=100, + num_out=200, + prob=0.1, + w_mu=-0.01, + w_sigma=0.01, + seed=100, + mode=mode) + input = bm.random.randn(10, 100) + output = layer(input) + + @parameterized.product( + mode=[bm.TrainingMode(), + bm.TrainingMode(10), + bm.BatchingMode(), + bm.BatchingMode(10), + bm.NonBatchingMode()] + ) + def test_EventJitFPUniformLinear(self, mode): + bm.random.seed() + layer = bp.dnn.EventJitFPUniformLinear(num_in=100, + num_out=200, + prob=0.1, + w_low=-0.01, + w_high=0.01, + seed=100, + mode=mode) + input = bm.random.randn(10, 100) + output = layer(input) + + +class Test_Normalization(parameterized.TestCase): + + @parameterized.product( + mode=[bm.TrainingMode(), + bm.TrainingMode(10), + bm.BatchingMode(), + bm.BatchingMode(10)], + fit=[True, False] + ) + def test_BatchNorm1d(self, fit, mode): + bm.random.seed() + bp.share.save(fit=fit) + layer = bp.dnn.BatchNorm1d(num_features=100, + mode=mode, + affine=False) + input = bm.random.randn(10, 5, 100) + output = layer(input) + + @parameterized.product( + mode=[bm.TrainingMode(), + bm.TrainingMode(10), + bm.BatchingMode(), + bm.BatchingMode(10)], + fit=[True, False] + ) + def test_BatchNorm2d(self, fit, mode): + bm.random.seed() + bp.share.save(fit=fit) + layer = bp.dnn.BatchNorm2d(num_features=100, + mode=mode, + affine=False) + input = bm.random.randn(10, 5, 6, 100) + output = layer(input) + + @parameterized.product( + mode=[bm.TrainingMode(), + bm.TrainingMode(10), + bm.BatchingMode(), + bm.BatchingMode(10)], + fit=[True, False] + ) + def test_BatchNorm3d(self, fit, mode): + bm.random.seed() + bp.share.save(fit=fit) + layer = bp.dnn.BatchNorm3d(num_features=100, + mode=mode, + affine=False) + input = bm.random.randn(10, 5, 6, 7, 100) + output = layer(input) + + @parameterized.product( + mode=[bm.TrainingMode(), + bm.TrainingMode(10), + bm.BatchingMode(), + bm.BatchingMode(10), + bm.NonBatchingMode()], + ) + def test_LayerNorm(self, mode): + bm.random.seed() + layer = bp.dnn.LayerNorm(normalized_shape=3, + mode=mode, + elementwise_affine=False + ) + input = bm.random.randn(10, 5, 3) + outout = layer(input) + + @parameterized.product( + mode=[bm.TrainingMode(), + bm.TrainingMode(10), + bm.BatchingMode(), + bm.BatchingMode(10), + bm.NonBatchingMode()], + ) + def test_GroupNorm(self, mode): + bm.random.seed() + layer = bp.dnn.GroupNorm(num_groups=2, + num_channels=6, + affine=False, + mode=mode + ) + input = bm.random.randn(20, 10, 10, 6) + output = layer(input) + + @parameterized.product( + mode=[bm.TrainingMode(), + bm.TrainingMode(10), + bm.BatchingMode(), + bm.BatchingMode(10), + bm.NonBatchingMode()], + ) + def test_InstanceNorm(self, mode): + bm.random.seed() + layer = bp.dnn.InstanceNorm(num_channels=6, + affine=False, + mode=mode + ) + input = bm.random.randn(20, 10, 10, 6) + output = layer(input) + + +if __name__ == '__main__': + absltest.main() diff --git a/brainpy/_src/dyn/projections/tests/test_STDP.py b/brainpy/_src/dyn/projections/tests/test_STDP.py index e78ae5048..7ffc4e763 100644 --- a/brainpy/_src/dyn/projections/tests/test_STDP.py +++ b/brainpy/_src/dyn/projections/tests/test_STDP.py @@ -1,13 +1,20 @@ # -*- coding: utf-8 -*- - +import pytest import numpy as np from absl.testing import parameterized import brainpy as bp import brainpy.math as bm +from brainpy._src.dependency_check import import_taichi + +if import_taichi(error_if_not_found=False) is None: + pytest.skip('no taichi', allow_module_level=True) + bm.set_platform('cpu') + + class Test_STDP(parameterized.TestCase): @parameterized.product( @@ -117,4 +124,3 @@ def run(i, I_pre, I_post): # plt.show() bm.clear_buffer_memory() - diff --git a/brainpy/_src/dyn/projections/tests/test_aligns.py b/brainpy/_src/dyn/projections/tests/test_aligns.py index 90500a26f..eec2c9459 100644 --- a/brainpy/_src/dyn/projections/tests/test_aligns.py +++ b/brainpy/_src/dyn/projections/tests/test_aligns.py @@ -1,439 +1,444 @@ -import matplotlib.pyplot as plt -import numpy as np - -import brainpy as bp -import brainpy.math as bm - -neu_pars = dict(V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., - V_initializer=bp.init.Normal(-55., 2.)) - - -def test_ProjAlignPreMg1(): - class EICOBA_PreAlign(bp.DynamicalSystem): - def __init__(self, scale=1., inp=20., delay=None): - super().__init__() - - self.inp = inp - self.E = bp.dyn.LifRefLTC(int(3200 * scale), **neu_pars) - self.I = bp.dyn.LifRefLTC(int(800 * scale), **neu_pars) - - prob = 80 / (4000 * scale) - - self.E2I = bp.dyn.FullProjAlignPreSDMg( - pre=self.E, - syn=bp.dyn.Expon.desc(self.E.varshape, tau=5.), - delay=delay, - comm=bp.dnn.CSRLinear(bp.conn.FixedProb(prob, pre=self.E.num, post=self.I.num), 0.6), - out=bp.dyn.COBA(E=0.), - post=self.I, - ) - self.E2E = bp.dyn.FullProjAlignPreSDMg( - pre=self.E, - syn=bp.dyn.Expon.desc(self.E.varshape, tau=5.), - delay=delay, - comm=bp.dnn.CSRLinear(bp.conn.FixedProb(prob, pre=self.E.num, post=self.E.num), 0.6), - out=bp.dyn.COBA(E=0.), - post=self.E, - ) - self.I2E = bp.dyn.FullProjAlignPreSDMg( - pre=self.I, - syn=bp.dyn.Expon.desc(self.I.varshape, tau=10.), - delay=delay, - comm=bp.dnn.CSRLinear(bp.conn.FixedProb(prob, pre=self.I.num, post=self.E.num), 6.7), - out=bp.dyn.COBA(E=-80.), - post=self.E, - ) - self.I2I = bp.dyn.FullProjAlignPreSDMg( - pre=self.I, - syn=bp.dyn.Expon.desc(self.I.varshape, tau=10.), - delay=delay, - comm=bp.dnn.CSRLinear(bp.conn.FixedProb(prob, pre=self.I.num, post=self.I.num), 6.7), - out=bp.dyn.COBA(E=-80.), - post=self.I, - ) - - def update(self): - self.E2I() - self.I2I() - self.I2E() - self.E2E() - self.E(self.inp) - self.I(self.inp) - return self.E.spike.value - - net = EICOBA_PreAlign(0.5) - indices = np.arange(400) - spks = bm.for_loop(net.step_run, indices) - bp.visualize.raster_plot(indices * bm.dt, spks, show=True) - - net = EICOBA_PreAlign(0.5, delay=1.) - indices = np.arange(400) - spks = bm.for_loop(net.step_run, indices) - bp.visualize.raster_plot(indices * bm.dt, spks, show=True) - - plt.close() - bm.clear_buffer_memory() - - -def test_ProjAlignPostMg2(): - class EICOBA_PostAlign(bp.DynamicalSystem): - def __init__(self, scale, inp=20., ltc=True, delay=None): - super().__init__() - self.inp = inp - - if ltc: - self.E = bp.dyn.LifRefLTC(int(3200 * scale), **neu_pars) - self.I = bp.dyn.LifRefLTC(int(800 * scale), **neu_pars) - else: - self.E = bp.dyn.LifRef(int(3200 * scale), **neu_pars) - self.I = bp.dyn.LifRef(int(800 * scale), **neu_pars) - - prob = 80 / (4000 * scale) - - self.E2E = bp.dyn.FullProjAlignPostMg( - pre=self.E, - delay=delay, - comm=bp.dnn.EventCSRLinear(bp.conn.FixedProb(prob, pre=self.E.num, post=self.E.num), 0.6), - syn=bp.dyn.Expon.desc(self.E.varshape, tau=5.), - out=bp.dyn.COBA.desc(E=0.), - post=self.E, - ) - self.E2I = bp.dyn.FullProjAlignPostMg( - pre=self.E, - delay=delay, - comm=bp.dnn.EventCSRLinear(bp.conn.FixedProb(prob, pre=self.E.num, post=self.I.num), 0.6), - syn=bp.dyn.Expon.desc(self.I.varshape, tau=5.), - out=bp.dyn.COBA.desc(E=0.), - post=self.I, - ) - self.I2E = bp.dyn.FullProjAlignPostMg( - pre=self.I, - delay=delay, - comm=bp.dnn.EventCSRLinear(bp.conn.FixedProb(prob, pre=self.I.num, post=self.E.num), 6.7), - syn=bp.dyn.Expon.desc(self.E.varshape, tau=10.), - out=bp.dyn.COBA.desc(E=-80.), - post=self.E, - ) - self.I2I = bp.dyn.FullProjAlignPostMg( - pre=self.I, - delay=delay, - comm=bp.dnn.EventCSRLinear(bp.conn.FixedProb(prob, pre=self.I.num, post=self.I.num), 6.7), - syn=bp.dyn.Expon.desc(self.I.varshape, tau=10.), - out=bp.dyn.COBA.desc(E=-80.), - post=self.I, - ) - - def update(self): - self.E2I() - self.I2I() - self.I2E() - self.E2E() - self.E(self.inp) - self.I(self.inp) - return self.E.spike.value - - net = EICOBA_PostAlign(0.5) - indices = np.arange(400) - spks = bm.for_loop(net.step_run, indices) - bp.visualize.raster_plot(indices * bm.dt, spks, show=True) - - net = EICOBA_PostAlign(0.5, delay=1.) - indices = np.arange(400) - spks = bm.for_loop(net.step_run, indices) - bp.visualize.raster_plot(indices * bm.dt, spks, show=True) - - net = EICOBA_PostAlign(0.5, ltc=False) - indices = np.arange(400) - spks = bm.for_loop(net.step_run, indices) - bp.visualize.raster_plot(indices * bm.dt, spks, show=True) - - plt.close() - bm.clear_buffer_memory() - - -def test_ProjAlignPost1(): - class EINet(bp.DynSysGroup): - def __init__(self, scale=1.): - super().__init__() - num = int(4000 * scale) - self.num_exc = int(3200 * scale) - self.num_inh = num - self.num_exc - prob = 80 / num - - self.N = bp.dyn.LifRefLTC(num, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., - V_initializer=bp.init.Normal(-55., 2.)) - self.delay = bp.VarDelay(self.N.spike, entries={'I': None}) - self.E = bp.dyn.HalfProjAlignPost(comm=bp.dnn.EventJitFPHomoLinear(self.num_exc, num, prob=prob, weight=0.6), - syn=bp.dyn.Expon(size=num, tau=5.), - out=bp.dyn.COBA(E=0.), - post=self.N) - self.I = bp.dyn.HalfProjAlignPost(comm=bp.dnn.EventJitFPHomoLinear(self.num_inh, num, prob=prob, weight=6.7), - syn=bp.dyn.Expon(size=num, tau=10.), - out=bp.dyn.COBA(E=-80.), - post=self.N) - - def update(self, input): - spk = self.delay.at('I') - self.E(spk[:self.num_exc]) - self.I(spk[self.num_exc:]) - self.delay(self.N(input)) - return self.N.spike.value - - model = EINet(0.5) - indices = bm.arange(400) - spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices) - bp.visualize.raster_plot(indices, spks, show=True) - bm.clear_buffer_memory() - plt.close() - - -def test_ProjAlignPost2(): - class EINet(bp.DynSysGroup): - def __init__(self, scale, delay=None): - super().__init__() - ne, ni = int(3200 * scale), int(800 * scale) - p = 80 / (ne + ni) - - self.E = bp.dyn.LifRefLTC(ne, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., - V_initializer=bp.init.Normal(-55., 2.)) - self.I = bp.dyn.LifRefLTC(ni, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., - V_initializer=bp.init.Normal(-55., 2.)) - self.E2E = bp.dyn.FullProjAlignPost(pre=self.E, - delay=delay, - comm=bp.dnn.EventJitFPHomoLinear(ne, ne, prob=p, weight=0.6), - syn=bp.dyn.Expon(size=ne, tau=5.), - out=bp.dyn.COBA(E=0.), - post=self.E) - self.E2I = bp.dyn.FullProjAlignPost(pre=self.E, - delay=delay, - comm=bp.dnn.EventJitFPHomoLinear(ne, ni, prob=p, weight=0.6), - syn=bp.dyn.Expon(size=ni, tau=5.), - out=bp.dyn.COBA(E=0.), - post=self.I) - self.I2E = bp.dyn.FullProjAlignPost(pre=self.I, - delay=delay, - comm=bp.dnn.EventJitFPHomoLinear(ni, ne, prob=p, weight=6.7), - syn=bp.dyn.Expon(size=ne, tau=10.), - out=bp.dyn.COBA(E=-80.), - post=self.E) - self.I2I = bp.dyn.FullProjAlignPost(pre=self.I, - delay=delay, - comm=bp.dnn.EventJitFPHomoLinear(ni, ni, prob=p, weight=6.7), - syn=bp.dyn.Expon(size=ni, tau=10.), - out=bp.dyn.COBA(E=-80.), - post=self.I) - - def update(self, inp): - self.E2E() - self.E2I() - self.I2E() - self.I2I() - self.E(inp) - self.I(inp) - return self.E.spike - - model = EINet(0.5, delay=1.) - indices = bm.arange(400) - spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices) - bp.visualize.raster_plot(indices, spks, show=True) - - model = EINet(0.5, delay=None) - indices = bm.arange(400) - spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices) - bp.visualize.raster_plot(indices, spks, show=True) - - bm.clear_buffer_memory() - plt.close() - - -def test_VanillaProj(): - class EINet(bp.DynSysGroup): - def __init__(self, scale=0.5): - super().__init__() - num = int(4000 * scale) - self.ne = int(3200 * scale) - self.ni = num - self.ne - p = 80 / num - - self.N = bp.dyn.LifRefLTC(num, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., - V_initializer=bp.init.Normal(-55., 2.)) - self.delay = bp.VarDelay(self.N.spike, entries={'I': None}) - self.syn1 = bp.dyn.Expon(size=self.ne, tau=5.) - self.syn2 = bp.dyn.Expon(size=self.ni, tau=10.) - self.E = bp.dyn.VanillaProj(comm=bp.dnn.JitFPHomoLinear(self.ne, num, prob=p, weight=0.6), - out=bp.dyn.COBA(E=0.), - post=self.N) - self.I = bp.dyn.VanillaProj(comm=bp.dnn.JitFPHomoLinear(self.ni, num, prob=p, weight=6.7), - out=bp.dyn.COBA(E=-80.), - post=self.N) - - def update(self, input): - spk = self.delay.at('I') - self.E(self.syn1(spk[:self.ne])) - self.I(self.syn2(spk[self.ne:])) - self.delay(self.N(input)) - return self.N.spike.value - - model = EINet() - indices = bm.arange(400) - spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices) - bp.visualize.raster_plot(indices, spks, show=True) - bm.clear_buffer_memory() - plt.close() - - -def test_ProjAlignPreMg1_v2(): - class EINet(bp.DynSysGroup): - def __init__(self, scale=1., delay=None): - super().__init__() - ne, ni = int(3200 * scale), int(800 * scale) - p = 80 / (4000 * scale) - self.E = bp.dyn.LifRefLTC(ne, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., - V_initializer=bp.init.Normal(-55., 2.)) - self.I = bp.dyn.LifRefLTC(ni, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., - V_initializer=bp.init.Normal(-55., 2.)) - self.E2E = bp.dyn.FullProjAlignPreSDMg(pre=self.E, - syn=bp.dyn.Expon.desc(size=ne, tau=5.), - delay=delay, - comm=bp.dnn.JitFPHomoLinear(ne, ne, prob=p, weight=0.6), - out=bp.dyn.COBA(E=0.), - post=self.E) - self.E2I = bp.dyn.FullProjAlignPreSDMg(pre=self.E, - syn=bp.dyn.Expon.desc(size=ne, tau=5.), - delay=delay, - comm=bp.dnn.JitFPHomoLinear(ne, ni, prob=p, weight=0.6), - out=bp.dyn.COBA(E=0.), - post=self.I) - self.I2E = bp.dyn.FullProjAlignPreSDMg(pre=self.I, - syn=bp.dyn.Expon.desc(size=ni, tau=10.), - delay=delay, - comm=bp.dnn.JitFPHomoLinear(ni, ne, prob=p, weight=6.7), - out=bp.dyn.COBA(E=-80.), - post=self.E) - self.I2I = bp.dyn.FullProjAlignPreSDMg(pre=self.I, - syn=bp.dyn.Expon.desc(size=ni, tau=10.), - delay=delay, - comm=bp.dnn.JitFPHomoLinear(ni, ni, prob=p, weight=6.7), - out=bp.dyn.COBA(E=-80.), - post=self.I) - - def update(self, inp): - self.E2E() - self.E2I() - self.I2E() - self.I2I() - self.E(inp) - self.I(inp) - return self.E.spike - - model = EINet() - indices = bm.arange(400) - spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices) - bp.visualize.raster_plot(indices, spks, show=True) - - model = EINet(delay=1.) - indices = bm.arange(400) - spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices) - bp.visualize.raster_plot(indices, spks, show=True) - - bm.clear_buffer_memory() - plt.close() - - -def test_ProjAlignPreMg2(): - class EINet(bp.DynSysGroup): - def __init__(self, scale=1., delay=None): - super().__init__() - ne, ni = int(3200 * scale), int(800 * scale) - p = 80 / (4000 * scale) - self.E = bp.dyn.LifRefLTC(ne, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., - V_initializer=bp.init.Normal(-55., 2.)) - self.I = bp.dyn.LifRefLTC(ni, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., - V_initializer=bp.init.Normal(-55., 2.)) - self.E2E = bp.dyn.FullProjAlignPreDSMg(pre=self.E, - delay=delay, - syn=bp.dyn.Expon.desc(size=ne, tau=5.), - comm=bp.dnn.JitFPHomoLinear(ne, ne, prob=p, weight=0.6), - out=bp.dyn.COBA(E=0.), - post=self.E) - self.E2I = bp.dyn.FullProjAlignPreDSMg(pre=self.E, - delay=delay, - syn=bp.dyn.Expon.desc(size=ne, tau=5.), - comm=bp.dnn.JitFPHomoLinear(ne, ni, prob=p, weight=0.6), - out=bp.dyn.COBA(E=0.), - post=self.I) - self.I2E = bp.dyn.FullProjAlignPreDSMg(pre=self.I, - delay=delay, - syn=bp.dyn.Expon.desc(size=ni, tau=10.), - comm=bp.dnn.JitFPHomoLinear(ni, ne, prob=p, weight=6.7), - out=bp.dyn.COBA(E=-80.), - post=self.E) - self.I2I = bp.dyn.FullProjAlignPreDSMg(pre=self.I, - delay=delay, - syn=bp.dyn.Expon.desc(size=ni, tau=10.), - comm=bp.dnn.JitFPHomoLinear(ni, ni, prob=p, weight=6.7), - out=bp.dyn.COBA(E=-80.), - post=self.I) - - def update(self, inp): - self.E2E() - self.E2I() - self.I2E() - self.I2I() - self.E(inp) - self.I(inp) - return self.E.spike - - model = EINet(scale=0.2, delay=None) - indices = bm.arange(400) - spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices) - bp.visualize.raster_plot(indices, spks, show=True) - - model = EINet(scale=0.2, delay=1.) - indices = bm.arange(400) - spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices) - bp.visualize.raster_plot(indices, spks, show=True) - - bm.clear_buffer_memory() - plt.close() - - -def test_vanalla_proj_v2(): - class EINet(bp.DynSysGroup): - def __init__(self, scale=1.): - super().__init__() - num = int(4000 * scale) - self.ne = int(3200 * scale) - self.ni = num - self.ne - p = 80 / num - - self.N = bp.dyn.LifRefLTC(num, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., - V_initializer=bp.init.Normal(-55., 1.)) - self.delay = bp.VarDelay(self.N.spike, entries={'delay': 2}) - self.syn1 = bp.dyn.Expon(size=self.ne, tau=5.) - self.syn2 = bp.dyn.Expon(size=self.ni, tau=10.) - self.E = bp.dyn.VanillaProj( - comm=bp.dnn.CSRLinear(bp.conn.FixedProb(p, pre=self.ne, post=num), weight=0.6), - out=bp.dyn.COBA(E=0.), - post=self.N - ) - self.I = bp.dyn.VanillaProj( - comm=bp.dnn.CSRLinear(bp.conn.FixedProb(p, pre=self.ni, post=num), weight=6.7), - out=bp.dyn.COBA(E=-80.), - post=self.N - ) - - def update(self, input): - spk = self.delay.at('delay') - self.E(self.syn1(spk[:self.ne])) - self.I(self.syn2(spk[self.ne:])) - self.delay(self.N(input)) - return self.N.spike.value - - model = EINet() - indices = bm.arange(400) - spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices, progress_bar=True) - bp.visualize.raster_plot(indices, spks, show=True) - plt.close() - bm.clear_buffer_memory() - +import pytest +import matplotlib.pyplot as plt +import numpy as np + +import brainpy as bp +import brainpy.math as bm + +from brainpy._src.dependency_check import import_taichi + +if import_taichi(error_if_not_found=False) is None: + pytest.skip('no taichi', allow_module_level=True) + +neu_pars = dict(V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., + V_initializer=bp.init.Normal(-55., 2.)) + + +def test_ProjAlignPreMg1(): + class EICOBA_PreAlign(bp.DynamicalSystem): + def __init__(self, scale=1., inp=20., delay=None): + super().__init__() + + self.inp = inp + self.E = bp.dyn.LifRefLTC(int(3200 * scale), **neu_pars) + self.I = bp.dyn.LifRefLTC(int(800 * scale), **neu_pars) + + prob = 80 / (4000 * scale) + + self.E2I = bp.dyn.FullProjAlignPreSDMg( + pre=self.E, + syn=bp.dyn.Expon.desc(self.E.varshape, tau=5.), + delay=delay, + comm=bp.dnn.CSRLinear(bp.conn.FixedProb(prob, pre=self.E.num, post=self.I.num), 0.6), + out=bp.dyn.COBA(E=0.), + post=self.I, + ) + self.E2E = bp.dyn.FullProjAlignPreSDMg( + pre=self.E, + syn=bp.dyn.Expon.desc(self.E.varshape, tau=5.), + delay=delay, + comm=bp.dnn.CSRLinear(bp.conn.FixedProb(prob, pre=self.E.num, post=self.E.num), 0.6), + out=bp.dyn.COBA(E=0.), + post=self.E, + ) + self.I2E = bp.dyn.FullProjAlignPreSDMg( + pre=self.I, + syn=bp.dyn.Expon.desc(self.I.varshape, tau=10.), + delay=delay, + comm=bp.dnn.CSRLinear(bp.conn.FixedProb(prob, pre=self.I.num, post=self.E.num), 6.7), + out=bp.dyn.COBA(E=-80.), + post=self.E, + ) + self.I2I = bp.dyn.FullProjAlignPreSDMg( + pre=self.I, + syn=bp.dyn.Expon.desc(self.I.varshape, tau=10.), + delay=delay, + comm=bp.dnn.CSRLinear(bp.conn.FixedProb(prob, pre=self.I.num, post=self.I.num), 6.7), + out=bp.dyn.COBA(E=-80.), + post=self.I, + ) + + def update(self): + self.E2I() + self.I2I() + self.I2E() + self.E2E() + self.E(self.inp) + self.I(self.inp) + return self.E.spike.value + + net = EICOBA_PreAlign(0.5) + indices = np.arange(400) + spks = bm.for_loop(net.step_run, indices) + bp.visualize.raster_plot(indices * bm.dt, spks, show=True) + + net = EICOBA_PreAlign(0.5, delay=1.) + indices = np.arange(400) + spks = bm.for_loop(net.step_run, indices) + bp.visualize.raster_plot(indices * bm.dt, spks, show=True) + + plt.close() + bm.clear_buffer_memory() + + +def test_ProjAlignPostMg2(): + class EICOBA_PostAlign(bp.DynamicalSystem): + def __init__(self, scale, inp=20., ltc=True, delay=None): + super().__init__() + self.inp = inp + + if ltc: + self.E = bp.dyn.LifRefLTC(int(3200 * scale), **neu_pars) + self.I = bp.dyn.LifRefLTC(int(800 * scale), **neu_pars) + else: + self.E = bp.dyn.LifRef(int(3200 * scale), **neu_pars) + self.I = bp.dyn.LifRef(int(800 * scale), **neu_pars) + + prob = 80 / (4000 * scale) + + self.E2E = bp.dyn.FullProjAlignPostMg( + pre=self.E, + delay=delay, + comm=bp.dnn.EventCSRLinear(bp.conn.FixedProb(prob, pre=self.E.num, post=self.E.num), 0.6), + syn=bp.dyn.Expon.desc(self.E.varshape, tau=5.), + out=bp.dyn.COBA.desc(E=0.), + post=self.E, + ) + self.E2I = bp.dyn.FullProjAlignPostMg( + pre=self.E, + delay=delay, + comm=bp.dnn.EventCSRLinear(bp.conn.FixedProb(prob, pre=self.E.num, post=self.I.num), 0.6), + syn=bp.dyn.Expon.desc(self.I.varshape, tau=5.), + out=bp.dyn.COBA.desc(E=0.), + post=self.I, + ) + self.I2E = bp.dyn.FullProjAlignPostMg( + pre=self.I, + delay=delay, + comm=bp.dnn.EventCSRLinear(bp.conn.FixedProb(prob, pre=self.I.num, post=self.E.num), 6.7), + syn=bp.dyn.Expon.desc(self.E.varshape, tau=10.), + out=bp.dyn.COBA.desc(E=-80.), + post=self.E, + ) + self.I2I = bp.dyn.FullProjAlignPostMg( + pre=self.I, + delay=delay, + comm=bp.dnn.EventCSRLinear(bp.conn.FixedProb(prob, pre=self.I.num, post=self.I.num), 6.7), + syn=bp.dyn.Expon.desc(self.I.varshape, tau=10.), + out=bp.dyn.COBA.desc(E=-80.), + post=self.I, + ) + + def update(self): + self.E2I() + self.I2I() + self.I2E() + self.E2E() + self.E(self.inp) + self.I(self.inp) + return self.E.spike.value + + net = EICOBA_PostAlign(0.5) + indices = np.arange(400) + spks = bm.for_loop(net.step_run, indices) + bp.visualize.raster_plot(indices * bm.dt, spks, show=True) + + net = EICOBA_PostAlign(0.5, delay=1.) + indices = np.arange(400) + spks = bm.for_loop(net.step_run, indices) + bp.visualize.raster_plot(indices * bm.dt, spks, show=True) + + net = EICOBA_PostAlign(0.5, ltc=False) + indices = np.arange(400) + spks = bm.for_loop(net.step_run, indices) + bp.visualize.raster_plot(indices * bm.dt, spks, show=True) + + plt.close() + bm.clear_buffer_memory() + + +def test_ProjAlignPost1(): + class EINet(bp.DynSysGroup): + def __init__(self, scale=1.): + super().__init__() + num = int(4000 * scale) + self.num_exc = int(3200 * scale) + self.num_inh = num - self.num_exc + prob = 80 / num + + self.N = bp.dyn.LifRefLTC(num, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., + V_initializer=bp.init.Normal(-55., 2.)) + self.delay = bp.VarDelay(self.N.spike, entries={'I': None}) + self.E = bp.dyn.HalfProjAlignPost(comm=bp.dnn.EventJitFPHomoLinear(self.num_exc, num, prob=prob, weight=0.6), + syn=bp.dyn.Expon(size=num, tau=5.), + out=bp.dyn.COBA(E=0.), + post=self.N) + self.I = bp.dyn.HalfProjAlignPost(comm=bp.dnn.EventJitFPHomoLinear(self.num_inh, num, prob=prob, weight=6.7), + syn=bp.dyn.Expon(size=num, tau=10.), + out=bp.dyn.COBA(E=-80.), + post=self.N) + + def update(self, input): + spk = self.delay.at('I') + self.E(spk[:self.num_exc]) + self.I(spk[self.num_exc:]) + self.delay(self.N(input)) + return self.N.spike.value + + model = EINet(0.5) + indices = bm.arange(400) + spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices) + bp.visualize.raster_plot(indices, spks, show=True) + bm.clear_buffer_memory() + plt.close() + + +def test_ProjAlignPost2(): + class EINet(bp.DynSysGroup): + def __init__(self, scale, delay=None): + super().__init__() + ne, ni = int(3200 * scale), int(800 * scale) + p = 80 / (ne + ni) + + self.E = bp.dyn.LifRefLTC(ne, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., + V_initializer=bp.init.Normal(-55., 2.)) + self.I = bp.dyn.LifRefLTC(ni, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., + V_initializer=bp.init.Normal(-55., 2.)) + self.E2E = bp.dyn.FullProjAlignPost(pre=self.E, + delay=delay, + comm=bp.dnn.EventJitFPHomoLinear(ne, ne, prob=p, weight=0.6), + syn=bp.dyn.Expon(size=ne, tau=5.), + out=bp.dyn.COBA(E=0.), + post=self.E) + self.E2I = bp.dyn.FullProjAlignPost(pre=self.E, + delay=delay, + comm=bp.dnn.EventJitFPHomoLinear(ne, ni, prob=p, weight=0.6), + syn=bp.dyn.Expon(size=ni, tau=5.), + out=bp.dyn.COBA(E=0.), + post=self.I) + self.I2E = bp.dyn.FullProjAlignPost(pre=self.I, + delay=delay, + comm=bp.dnn.EventJitFPHomoLinear(ni, ne, prob=p, weight=6.7), + syn=bp.dyn.Expon(size=ne, tau=10.), + out=bp.dyn.COBA(E=-80.), + post=self.E) + self.I2I = bp.dyn.FullProjAlignPost(pre=self.I, + delay=delay, + comm=bp.dnn.EventJitFPHomoLinear(ni, ni, prob=p, weight=6.7), + syn=bp.dyn.Expon(size=ni, tau=10.), + out=bp.dyn.COBA(E=-80.), + post=self.I) + + def update(self, inp): + self.E2E() + self.E2I() + self.I2E() + self.I2I() + self.E(inp) + self.I(inp) + return self.E.spike + + model = EINet(0.5, delay=1.) + indices = bm.arange(400) + spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices) + bp.visualize.raster_plot(indices, spks, show=True) + + model = EINet(0.5, delay=None) + indices = bm.arange(400) + spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices) + bp.visualize.raster_plot(indices, spks, show=True) + + bm.clear_buffer_memory() + plt.close() + + +def test_VanillaProj(): + class EINet(bp.DynSysGroup): + def __init__(self, scale=0.5): + super().__init__() + num = int(4000 * scale) + self.ne = int(3200 * scale) + self.ni = num - self.ne + p = 80 / num + + self.N = bp.dyn.LifRefLTC(num, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., + V_initializer=bp.init.Normal(-55., 2.)) + self.delay = bp.VarDelay(self.N.spike, entries={'I': None}) + self.syn1 = bp.dyn.Expon(size=self.ne, tau=5.) + self.syn2 = bp.dyn.Expon(size=self.ni, tau=10.) + self.E = bp.dyn.VanillaProj(comm=bp.dnn.JitFPHomoLinear(self.ne, num, prob=p, weight=0.6), + out=bp.dyn.COBA(E=0.), + post=self.N) + self.I = bp.dyn.VanillaProj(comm=bp.dnn.JitFPHomoLinear(self.ni, num, prob=p, weight=6.7), + out=bp.dyn.COBA(E=-80.), + post=self.N) + + def update(self, input): + spk = self.delay.at('I') + self.E(self.syn1(spk[:self.ne])) + self.I(self.syn2(spk[self.ne:])) + self.delay(self.N(input)) + return self.N.spike.value + + model = EINet() + indices = bm.arange(400) + spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices) + bp.visualize.raster_plot(indices, spks, show=True) + bm.clear_buffer_memory() + plt.close() + + +def test_ProjAlignPreMg1_v2(): + class EINet(bp.DynSysGroup): + def __init__(self, scale=1., delay=None): + super().__init__() + ne, ni = int(3200 * scale), int(800 * scale) + p = 80 / (4000 * scale) + self.E = bp.dyn.LifRefLTC(ne, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., + V_initializer=bp.init.Normal(-55., 2.)) + self.I = bp.dyn.LifRefLTC(ni, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., + V_initializer=bp.init.Normal(-55., 2.)) + self.E2E = bp.dyn.FullProjAlignPreSDMg(pre=self.E, + syn=bp.dyn.Expon.desc(size=ne, tau=5.), + delay=delay, + comm=bp.dnn.JitFPHomoLinear(ne, ne, prob=p, weight=0.6), + out=bp.dyn.COBA(E=0.), + post=self.E) + self.E2I = bp.dyn.FullProjAlignPreSDMg(pre=self.E, + syn=bp.dyn.Expon.desc(size=ne, tau=5.), + delay=delay, + comm=bp.dnn.JitFPHomoLinear(ne, ni, prob=p, weight=0.6), + out=bp.dyn.COBA(E=0.), + post=self.I) + self.I2E = bp.dyn.FullProjAlignPreSDMg(pre=self.I, + syn=bp.dyn.Expon.desc(size=ni, tau=10.), + delay=delay, + comm=bp.dnn.JitFPHomoLinear(ni, ne, prob=p, weight=6.7), + out=bp.dyn.COBA(E=-80.), + post=self.E) + self.I2I = bp.dyn.FullProjAlignPreSDMg(pre=self.I, + syn=bp.dyn.Expon.desc(size=ni, tau=10.), + delay=delay, + comm=bp.dnn.JitFPHomoLinear(ni, ni, prob=p, weight=6.7), + out=bp.dyn.COBA(E=-80.), + post=self.I) + + def update(self, inp): + self.E2E() + self.E2I() + self.I2E() + self.I2I() + self.E(inp) + self.I(inp) + return self.E.spike + + model = EINet() + indices = bm.arange(400) + spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices) + bp.visualize.raster_plot(indices, spks, show=True) + + model = EINet(delay=1.) + indices = bm.arange(400) + spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices) + bp.visualize.raster_plot(indices, spks, show=True) + + bm.clear_buffer_memory() + plt.close() + + +def test_ProjAlignPreMg2(): + class EINet(bp.DynSysGroup): + def __init__(self, scale=1., delay=None): + super().__init__() + ne, ni = int(3200 * scale), int(800 * scale) + p = 80 / (4000 * scale) + self.E = bp.dyn.LifRefLTC(ne, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., + V_initializer=bp.init.Normal(-55., 2.)) + self.I = bp.dyn.LifRefLTC(ni, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., + V_initializer=bp.init.Normal(-55., 2.)) + self.E2E = bp.dyn.FullProjAlignPreDSMg(pre=self.E, + delay=delay, + syn=bp.dyn.Expon.desc(size=ne, tau=5.), + comm=bp.dnn.JitFPHomoLinear(ne, ne, prob=p, weight=0.6), + out=bp.dyn.COBA(E=0.), + post=self.E) + self.E2I = bp.dyn.FullProjAlignPreDSMg(pre=self.E, + delay=delay, + syn=bp.dyn.Expon.desc(size=ne, tau=5.), + comm=bp.dnn.JitFPHomoLinear(ne, ni, prob=p, weight=0.6), + out=bp.dyn.COBA(E=0.), + post=self.I) + self.I2E = bp.dyn.FullProjAlignPreDSMg(pre=self.I, + delay=delay, + syn=bp.dyn.Expon.desc(size=ni, tau=10.), + comm=bp.dnn.JitFPHomoLinear(ni, ne, prob=p, weight=6.7), + out=bp.dyn.COBA(E=-80.), + post=self.E) + self.I2I = bp.dyn.FullProjAlignPreDSMg(pre=self.I, + delay=delay, + syn=bp.dyn.Expon.desc(size=ni, tau=10.), + comm=bp.dnn.JitFPHomoLinear(ni, ni, prob=p, weight=6.7), + out=bp.dyn.COBA(E=-80.), + post=self.I) + + def update(self, inp): + self.E2E() + self.E2I() + self.I2E() + self.I2I() + self.E(inp) + self.I(inp) + return self.E.spike + + model = EINet(scale=0.2, delay=None) + indices = bm.arange(400) + spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices) + bp.visualize.raster_plot(indices, spks, show=True) + + model = EINet(scale=0.2, delay=1.) + indices = bm.arange(400) + spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices) + bp.visualize.raster_plot(indices, spks, show=True) + + bm.clear_buffer_memory() + plt.close() + + +def test_vanalla_proj_v2(): + class EINet(bp.DynSysGroup): + def __init__(self, scale=1.): + super().__init__() + num = int(4000 * scale) + self.ne = int(3200 * scale) + self.ni = num - self.ne + p = 80 / num + + self.N = bp.dyn.LifRefLTC(num, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., + V_initializer=bp.init.Normal(-55., 1.)) + self.delay = bp.VarDelay(self.N.spike, entries={'delay': 2}) + self.syn1 = bp.dyn.Expon(size=self.ne, tau=5.) + self.syn2 = bp.dyn.Expon(size=self.ni, tau=10.) + self.E = bp.dyn.VanillaProj( + comm=bp.dnn.CSRLinear(bp.conn.FixedProb(p, pre=self.ne, post=num), weight=0.6), + out=bp.dyn.COBA(E=0.), + post=self.N + ) + self.I = bp.dyn.VanillaProj( + comm=bp.dnn.CSRLinear(bp.conn.FixedProb(p, pre=self.ni, post=num), weight=6.7), + out=bp.dyn.COBA(E=-80.), + post=self.N + ) + + def update(self, input): + spk = self.delay.at('delay') + self.E(self.syn1(spk[:self.ne])) + self.I(self.syn2(spk[self.ne:])) + self.delay(self.N(input)) + return self.N.spike.value + + model = EINet() + indices = bm.arange(400) + spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices, progress_bar=True) + bp.visualize.raster_plot(indices, spks, show=True) + plt.close() + bm.clear_buffer_memory() diff --git a/brainpy/_src/dynold/synapses/tests/test_abstract_synapses.py b/brainpy/_src/dynold/synapses/tests/test_abstract_synapses.py index badb60832..c3936f685 100644 --- a/brainpy/_src/dynold/synapses/tests/test_abstract_synapses.py +++ b/brainpy/_src/dynold/synapses/tests/test_abstract_synapses.py @@ -1,126 +1,130 @@ -# -*- coding: utf-8 -*- - - -from absl.testing import parameterized - -import brainpy as bp -import brainpy.math as bm -from brainpy._src.dynold.synapses import abstract_models - - -class Test_Abstract_Synapse(parameterized.TestCase): - @parameterized.product( - name=['Exponential', 'DualExponential', 'Alpha', 'NMDA'], - stp=[None, bp.synplast.STD(), bp.synplast.STP()], - mode=[bm.nonbatching_mode, bm.BatchingMode(5), bm.TrainingMode(5)] - ) - def test_all2all_synapse(self, name, stp, mode): - bm.random.seed() - with bm.environment(mode=mode): - pre_neu = bp.neurons.LIF(5) - post_neu = bp.neurons.LIF(5) - syn_model = getattr(bp.synapses, name) - syn = syn_model(pre_neu, post_neu, conn=bp.conn.All2All(), stp=stp) - net = bp.Network(pre=pre_neu, syn=syn, post=post_neu) - - # 运行模拟 - runner = bp.DSRunner(net, - monitors=['pre.V', 'syn.g', 'post.V'], - inputs=('pre.input', 35.)) - runner(10.) - - expected_shape = (100, 5) - if isinstance(mode, bm.BatchingMode): - expected_shape = (mode.batch_size, ) + expected_shape - self.assertTupleEqual(runner.mon['pre.V'].shape, expected_shape) - self.assertTupleEqual(runner.mon['syn.g'].shape, expected_shape) - self.assertTupleEqual(runner.mon['post.V'].shape, expected_shape) - bm.clear_buffer_memory() - - @parameterized.product( - name=['Exponential', 'DualExponential', 'Alpha', 'NMDA'], - stp=[None, bp.synplast.STD(), bp.synplast.STP()], - mode=[bm.nonbatching_mode, bm.BatchingMode(5), bm.TrainingMode(5)] - ) - def test_one2one_synapse(self, name, stp, mode): - bm.random.seed() - with bm.environment(mode=mode): - pre_neu = bp.neurons.LIF(5) - post_neu = bp.neurons.LIF(5) - syn_model = getattr(abstract_models, name) - syn = syn_model(pre_neu, post_neu, conn=bp.conn.One2One(), stp=stp) - net = bp.Network(pre=pre_neu, syn=syn, post=post_neu) - - # 运行模拟 - runner = bp.DSRunner(net, - monitors=['pre.V', 'syn.g', 'post.V'], - inputs=('pre.input', 35.)) - runner(10.) - - expected_shape = (100, 5) - if isinstance(mode, bm.BatchingMode): - expected_shape = (mode.batch_size, ) + expected_shape - self.assertTupleEqual(runner.mon['pre.V'].shape, expected_shape) - self.assertTupleEqual(runner.mon['syn.g'].shape, expected_shape) - self.assertTupleEqual(runner.mon['post.V'].shape, expected_shape) - bm.clear_buffer_memory() - - @parameterized.product( - comp_type=['sparse', 'dense'], - name=['Exponential', 'DualExponential', 'Alpha', 'NMDA'], - stp=[None, bp.synplast.STD(), bp.synplast.STP()], - mode=[bm.nonbatching_mode, bm.BatchingMode(5), bm.TrainingMode(5)] - ) - def test_sparse_synapse(self, comp_type, name, stp, mode): - bm.random.seed() - with bm.environment(mode=mode): - pre_neu = bp.neurons.LIF(5) - post_neu = bp.neurons.LIF(5) - syn_model = getattr(abstract_models, name) - syn = syn_model(pre_neu, post_neu, conn=bp.conn.FixedProb(0.1), comp_method=comp_type, stp=stp) - net = bp.Network(pre=pre_neu, syn=syn, post=post_neu) - - # 运行模拟 - runner = bp.DSRunner(net, - monitors=['pre.V', 'syn.g', 'post.V'], - inputs=('pre.input', 35.)) - runner(10.) - - expected_shape = (100, 5) - if isinstance(mode, bm.BatchingMode): - expected_shape = (mode.batch_size, ) + expected_shape - self.assertTupleEqual(runner.mon['pre.V'].shape, expected_shape) - self.assertTupleEqual(runner.mon['syn.g'].shape, expected_shape) - self.assertTupleEqual(runner.mon['post.V'].shape, expected_shape) - bm.clear_buffer_memory() - - @parameterized.product( - post_ref_key=[None, 'refractory'], - stp=[None, bp.synplast.STD(), bp.synplast.STP()], - mode=[bm.nonbatching_mode, bm.BatchingMode(5), bm.TrainingMode(5)] - ) - def test_delta_synapse(self, post_ref_key, stp, mode): - bm.random.seed() - with bm.environment(mode=mode): - pre_neu = bp.neurons.LIF(5, ref_var=True) - post_neu = bp.neurons.LIF(3, ref_var=True) - syn = bp.synapses.Delta(pre_neu, post_neu, - conn=bp.conn.All2All(), - post_ref_key=post_ref_key, - stp=stp, ) - net = bp.Network(pre=pre_neu, syn=syn, post=post_neu) - - # 运行模拟 - runner = bp.DSRunner(net, - monitors=['pre.V', 'post.V'], - inputs=('pre.input', 35.)) - runner(10.) - - pre_expected_shape = (100, 5) - post_expected_shape = (100, 3) - if isinstance(mode, bm.BatchingMode): - pre_expected_shape = (mode.batch_size,) + pre_expected_shape - post_expected_shape = (mode.batch_size,) + post_expected_shape - self.assertTupleEqual(runner.mon['pre.V'].shape, pre_expected_shape) - self.assertTupleEqual(runner.mon['post.V'].shape, post_expected_shape) - bm.clear_buffer_memory() +# -*- coding: utf-8 -*- + +import pytest +from absl.testing import parameterized + +import brainpy as bp +import brainpy.math as bm +from brainpy._src.dynold.synapses import abstract_models +from brainpy._src.dependency_check import import_taichi + +if import_taichi(error_if_not_found=False) is None: + pytest.skip('no taichi', allow_module_level=True) + + +class Test_Abstract_Synapse(parameterized.TestCase): + @parameterized.product( + name=['Exponential', 'DualExponential', 'Alpha', 'NMDA'], + stp=[None, bp.synplast.STD(), bp.synplast.STP()], + mode=[bm.nonbatching_mode, bm.BatchingMode(5), bm.TrainingMode(5)] + ) + def test_all2all_synapse(self, name, stp, mode): + bm.random.seed() + with bm.environment(mode=mode): + pre_neu = bp.neurons.LIF(5) + post_neu = bp.neurons.LIF(5) + syn_model = getattr(bp.synapses, name) + syn = syn_model(pre_neu, post_neu, conn=bp.conn.All2All(), stp=stp) + net = bp.Network(pre=pre_neu, syn=syn, post=post_neu) + + # 运行模拟 + runner = bp.DSRunner(net, + monitors=['pre.V', 'syn.g', 'post.V'], + inputs=('pre.input', 35.)) + runner(10.) + + expected_shape = (100, 5) + if isinstance(mode, bm.BatchingMode): + expected_shape = (mode.batch_size,) + expected_shape + self.assertTupleEqual(runner.mon['pre.V'].shape, expected_shape) + self.assertTupleEqual(runner.mon['syn.g'].shape, expected_shape) + self.assertTupleEqual(runner.mon['post.V'].shape, expected_shape) + bm.clear_buffer_memory() + + @parameterized.product( + name=['Exponential', 'DualExponential', 'Alpha', 'NMDA'], + stp=[None, bp.synplast.STD(), bp.synplast.STP()], + mode=[bm.nonbatching_mode, bm.BatchingMode(5), bm.TrainingMode(5)] + ) + def test_one2one_synapse(self, name, stp, mode): + bm.random.seed() + with bm.environment(mode=mode): + pre_neu = bp.neurons.LIF(5) + post_neu = bp.neurons.LIF(5) + syn_model = getattr(abstract_models, name) + syn = syn_model(pre_neu, post_neu, conn=bp.conn.One2One(), stp=stp) + net = bp.Network(pre=pre_neu, syn=syn, post=post_neu) + + # 运行模拟 + runner = bp.DSRunner(net, + monitors=['pre.V', 'syn.g', 'post.V'], + inputs=('pre.input', 35.)) + runner(10.) + + expected_shape = (100, 5) + if isinstance(mode, bm.BatchingMode): + expected_shape = (mode.batch_size,) + expected_shape + self.assertTupleEqual(runner.mon['pre.V'].shape, expected_shape) + self.assertTupleEqual(runner.mon['syn.g'].shape, expected_shape) + self.assertTupleEqual(runner.mon['post.V'].shape, expected_shape) + bm.clear_buffer_memory() + + @parameterized.product( + comp_type=['sparse', 'dense'], + name=['Exponential', 'DualExponential', 'Alpha', 'NMDA'], + stp=[None, bp.synplast.STD(), bp.synplast.STP()], + mode=[bm.nonbatching_mode, bm.BatchingMode(5), bm.TrainingMode(5)] + ) + def test_sparse_synapse(self, comp_type, name, stp, mode): + bm.random.seed() + with bm.environment(mode=mode): + pre_neu = bp.neurons.LIF(5) + post_neu = bp.neurons.LIF(5) + syn_model = getattr(abstract_models, name) + syn = syn_model(pre_neu, post_neu, conn=bp.conn.FixedProb(0.1), comp_method=comp_type, stp=stp) + net = bp.Network(pre=pre_neu, syn=syn, post=post_neu) + + # 运行模拟 + runner = bp.DSRunner(net, + monitors=['pre.V', 'syn.g', 'post.V'], + inputs=('pre.input', 35.)) + runner(10.) + + expected_shape = (100, 5) + if isinstance(mode, bm.BatchingMode): + expected_shape = (mode.batch_size,) + expected_shape + self.assertTupleEqual(runner.mon['pre.V'].shape, expected_shape) + self.assertTupleEqual(runner.mon['syn.g'].shape, expected_shape) + self.assertTupleEqual(runner.mon['post.V'].shape, expected_shape) + bm.clear_buffer_memory() + + @parameterized.product( + post_ref_key=[None, 'refractory'], + stp=[None, bp.synplast.STD(), bp.synplast.STP()], + mode=[bm.nonbatching_mode, bm.BatchingMode(5), bm.TrainingMode(5)] + ) + def test_delta_synapse(self, post_ref_key, stp, mode): + bm.random.seed() + with bm.environment(mode=mode): + pre_neu = bp.neurons.LIF(5, ref_var=True) + post_neu = bp.neurons.LIF(3, ref_var=True) + syn = bp.synapses.Delta(pre_neu, post_neu, + conn=bp.conn.All2All(), + post_ref_key=post_ref_key, + stp=stp, ) + net = bp.Network(pre=pre_neu, syn=syn, post=post_neu) + + # 运行模拟 + runner = bp.DSRunner(net, + monitors=['pre.V', 'post.V'], + inputs=('pre.input', 35.)) + runner(10.) + + pre_expected_shape = (100, 5) + post_expected_shape = (100, 3) + if isinstance(mode, bm.BatchingMode): + pre_expected_shape = (mode.batch_size,) + pre_expected_shape + post_expected_shape = (mode.batch_size,) + post_expected_shape + self.assertTupleEqual(runner.mon['pre.V'].shape, pre_expected_shape) + self.assertTupleEqual(runner.mon['post.V'].shape, post_expected_shape) + bm.clear_buffer_memory() diff --git a/brainpy/_src/dynold/synapses/tests/test_biological_synapses.py b/brainpy/_src/dynold/synapses/tests/test_biological_synapses.py index 395868092..01a315261 100644 --- a/brainpy/_src/dynold/synapses/tests/test_biological_synapses.py +++ b/brainpy/_src/dynold/synapses/tests/test_biological_synapses.py @@ -1,103 +1,108 @@ -# -*- coding: utf-8 -*- - - -from absl.testing import parameterized - -import brainpy as bp -import brainpy.math as bm - -biological_models = [ - bp.synapses.AMPA, - bp.synapses.GABAa, - bp.synapses.BioNMDA, -] - - -class Test_Biological_Synapse(parameterized.TestCase): - @parameterized.product( - synapse=biological_models, - delay_step=[None, 5, 1], - mode=[bm.NonBatchingMode(), bm.BatchingMode(5)], - stp=[None, bp.synplast.STP(), bp.synplast.STD()] - ) - def test_all2all_synapse(self, synapse, delay_step, mode, stp): - bm.random.seed() - with bm.environment(mode=mode): - pre_neu = bp.neurons.LIF(5) - post_neu = bp.neurons.LIF(5) - syn = synapse(pre_neu, post_neu, conn=bp.conn.All2All(), delay_step=delay_step, stp=stp) - net = bp.Network(pre=pre_neu, syn=syn, post=post_neu) - - # 运行模拟 - runner = bp.DSRunner(net, - monitors=['pre.V', 'syn.g', 'post.V'], - inputs=('pre.input', 35.)) - runner(10.) - - expected_shape = (100, 5) - if isinstance(mode, bm.BatchingMode): - expected_shape = (mode.batch_size,) + expected_shape - - self.assertTupleEqual(runner.mon['pre.V'].shape, expected_shape) - self.assertTupleEqual(runner.mon['syn.g'].shape, expected_shape) - self.assertTupleEqual(runner.mon['post.V'].shape, expected_shape) - bm.clear_buffer_memory() - - @parameterized.product( - synapse=biological_models, - delay_step=[None, 10, 1], - mode=[bm.NonBatchingMode(), bm.BatchingMode(5), ], - stp=[None, bp.synplast.STP(), bp.synplast.STD()] - ) - def test_one2one_synapse(self, synapse, delay_step, mode, stp): - bm.random.seed() - with bm.environment(mode=mode): - pre_neu = bp.neurons.LIF(5) - post_neu = bp.neurons.LIF(5) - syn = synapse(pre_neu, post_neu, conn=bp.conn.One2One(), delay_step=delay_step, stp=stp) - net = bp.Network(pre=pre_neu, syn=syn, post=post_neu) - - # 运行模拟 - runner = bp.DSRunner(net, - monitors=['pre.V', 'syn.g', 'post.V'], - inputs=('pre.input', 35.)) - runner(10.) - - expected_shape = (100, 5) - if isinstance(mode, bm.BatchingMode): - expected_shape = (mode.batch_size,) + expected_shape - self.assertTupleEqual(runner.mon['pre.V'].shape, expected_shape) - self.assertTupleEqual(runner.mon['syn.g'].shape, expected_shape) - self.assertTupleEqual(runner.mon['post.V'].shape, expected_shape) - bm.clear_buffer_memory() - - @parameterized.product( - synapse=biological_models, - comp_method=['sparse', 'dense'], - delay_step=[None, 10, 1], - mode=[bm.NonBatchingMode(), bm.BatchingMode(5)], - stp=[None, bp.synplast.STP(), bp.synplast.STD()] - ) - def test_sparse_synapse(self, synapse, comp_method, delay_step, mode, stp): - bm.random.seed() - with bm.environment(mode=mode): - pre_neu = bp.neurons.LIF(10) - post_neu = bp.neurons.LIF(10) - syn = synapse(pre_neu, post_neu, conn=bp.conn.FixedProb(0.5), - comp_method=comp_method, delay_step=delay_step, - stp=stp) - net = bp.Network(pre=pre_neu, syn=syn, post=post_neu) - - # 运行模拟 - runner = bp.DSRunner(net, - monitors=['pre.V', 'syn.g', 'post.V'], - inputs=('pre.input', 35.)) - runner(10.) - - expected_shape = (100, 10) - if isinstance(mode, bm.BatchingMode): - expected_shape = (mode.batch_size,) + expected_shape - self.assertTupleEqual(runner.mon['pre.V'].shape, expected_shape) - self.assertTupleEqual(runner.mon['syn.g'].shape, expected_shape) - self.assertTupleEqual(runner.mon['post.V'].shape, expected_shape) - bm.clear_buffer_memory() +# -*- coding: utf-8 -*- + +import pytest +from absl.testing import parameterized + +import brainpy as bp +import brainpy.math as bm + +from brainpy._src.dependency_check import import_taichi + +if import_taichi(error_if_not_found=False) is None: + pytest.skip('no taichi', allow_module_level=True) + +biological_models = [ + bp.synapses.AMPA, + bp.synapses.GABAa, + bp.synapses.BioNMDA, +] + + +class Test_Biological_Synapse(parameterized.TestCase): + @parameterized.product( + synapse=biological_models, + delay_step=[None, 5, 1], + mode=[bm.NonBatchingMode(), bm.BatchingMode(5)], + stp=[None, bp.synplast.STP(), bp.synplast.STD()] + ) + def test_all2all_synapse(self, synapse, delay_step, mode, stp): + bm.random.seed() + with bm.environment(mode=mode): + pre_neu = bp.neurons.LIF(5) + post_neu = bp.neurons.LIF(5) + syn = synapse(pre_neu, post_neu, conn=bp.conn.All2All(), delay_step=delay_step, stp=stp) + net = bp.Network(pre=pre_neu, syn=syn, post=post_neu) + + # 运行模拟 + runner = bp.DSRunner(net, + monitors=['pre.V', 'syn.g', 'post.V'], + inputs=('pre.input', 35.)) + runner(10.) + + expected_shape = (100, 5) + if isinstance(mode, bm.BatchingMode): + expected_shape = (mode.batch_size,) + expected_shape + + self.assertTupleEqual(runner.mon['pre.V'].shape, expected_shape) + self.assertTupleEqual(runner.mon['syn.g'].shape, expected_shape) + self.assertTupleEqual(runner.mon['post.V'].shape, expected_shape) + bm.clear_buffer_memory() + + @parameterized.product( + synapse=biological_models, + delay_step=[None, 10, 1], + mode=[bm.NonBatchingMode(), bm.BatchingMode(5), ], + stp=[None, bp.synplast.STP(), bp.synplast.STD()] + ) + def test_one2one_synapse(self, synapse, delay_step, mode, stp): + bm.random.seed() + with bm.environment(mode=mode): + pre_neu = bp.neurons.LIF(5) + post_neu = bp.neurons.LIF(5) + syn = synapse(pre_neu, post_neu, conn=bp.conn.One2One(), delay_step=delay_step, stp=stp) + net = bp.Network(pre=pre_neu, syn=syn, post=post_neu) + + # 运行模拟 + runner = bp.DSRunner(net, + monitors=['pre.V', 'syn.g', 'post.V'], + inputs=('pre.input', 35.)) + runner(10.) + + expected_shape = (100, 5) + if isinstance(mode, bm.BatchingMode): + expected_shape = (mode.batch_size,) + expected_shape + self.assertTupleEqual(runner.mon['pre.V'].shape, expected_shape) + self.assertTupleEqual(runner.mon['syn.g'].shape, expected_shape) + self.assertTupleEqual(runner.mon['post.V'].shape, expected_shape) + bm.clear_buffer_memory() + + @parameterized.product( + synapse=biological_models, + comp_method=['sparse', 'dense'], + delay_step=[None, 10, 1], + mode=[bm.NonBatchingMode(), bm.BatchingMode(5)], + stp=[None, bp.synplast.STP(), bp.synplast.STD()] + ) + def test_sparse_synapse(self, synapse, comp_method, delay_step, mode, stp): + bm.random.seed() + with bm.environment(mode=mode): + pre_neu = bp.neurons.LIF(10) + post_neu = bp.neurons.LIF(10) + syn = synapse(pre_neu, post_neu, conn=bp.conn.FixedProb(0.5), + comp_method=comp_method, delay_step=delay_step, + stp=stp) + net = bp.Network(pre=pre_neu, syn=syn, post=post_neu) + + # 运行模拟 + runner = bp.DSRunner(net, + monitors=['pre.V', 'syn.g', 'post.V'], + inputs=('pre.input', 35.)) + runner(10.) + + expected_shape = (100, 10) + if isinstance(mode, bm.BatchingMode): + expected_shape = (mode.batch_size,) + expected_shape + self.assertTupleEqual(runner.mon['pre.V'].shape, expected_shape) + self.assertTupleEqual(runner.mon['syn.g'].shape, expected_shape) + self.assertTupleEqual(runner.mon['post.V'].shape, expected_shape) + bm.clear_buffer_memory() diff --git a/brainpy/_src/math/event/_csr_matvec.py b/brainpy/_src/math/event/_csr_matvec.py index 1571ea922..bac809388 100644 --- a/brainpy/_src/math/event/_csr_matvec.py +++ b/brainpy/_src/math/event/_csr_matvec.py @@ -17,7 +17,7 @@ import numpy as np from jax.interpreters import ad -from brainpy._src.dependency_check import import_taichi +from brainpy._src.dependency_check import import_taichi, check_taichi_func from brainpy._src.math.interoperability import as_jax from brainpy._src.math.op_register import XLACustomOp from brainpy._src.math.sparse._csr_mv import raw_csrmv_taichi as normal_csrmv_taichi @@ -30,7 +30,7 @@ ti = import_taichi(error_if_not_found=False) - +@check_taichi_func def csrmv( data: Union[float, jax.Array], indices: jax.Array, diff --git a/brainpy/_src/math/event/tests/test_event_csrmv.py b/brainpy/_src/math/event/tests/test_event_csrmv.py index 1641c9db9..67e09d0a4 100644 --- a/brainpy/_src/math/event/tests/test_event_csrmv.py +++ b/brainpy/_src/math/event/tests/test_event_csrmv.py @@ -12,7 +12,7 @@ from brainpy._src.dependency_check import import_taichi -if import_taichi() is None: +if import_taichi(error_if_not_found=False) is None: pytest.skip('no taichi', allow_module_level=True) diff --git a/brainpy/_src/math/jitconn/_event_matvec.py b/brainpy/_src/math/jitconn/_event_matvec.py index f389c3773..b2de30b01 100644 --- a/brainpy/_src/math/jitconn/_event_matvec.py +++ b/brainpy/_src/math/jitconn/_event_matvec.py @@ -6,7 +6,7 @@ import numpy as np from jax import numpy as jnp -from brainpy._src.dependency_check import import_taichi +from brainpy._src.dependency_check import import_taichi, check_taichi_func from brainpy._src.math.interoperability import as_jax from brainpy._src.math.jitconn._matvec import (mv_prob_homo, mv_prob_uniform, @@ -30,7 +30,7 @@ 'event_mv_prob_normal', ] - +@check_taichi_func def event_mv_prob_homo( events: jax.Array, weight: float, @@ -49,7 +49,7 @@ def event_mv_prob_homo( event_mv_prob_homo.__doc__ = mv_prob_homo.__doc__ - +@check_taichi_func def event_mv_prob_uniform( events: jax.Array, w_low: float, @@ -69,7 +69,7 @@ def event_mv_prob_uniform( event_mv_prob_uniform.__doc__ = mv_prob_uniform.__doc__ - +@check_taichi_func def event_mv_prob_normal( events: jax.Array, w_mu: float, diff --git a/brainpy/_src/math/jitconn/_matvec.py b/brainpy/_src/math/jitconn/_matvec.py index 4b8fe004a..894294c79 100644 --- a/brainpy/_src/math/jitconn/_matvec.py +++ b/brainpy/_src/math/jitconn/_matvec.py @@ -8,7 +8,7 @@ from jax import numpy as jnp from jax.interpreters import ad -from brainpy._src.dependency_check import import_taichi +from brainpy._src.dependency_check import import_taichi, check_taichi_func from brainpy._src.math.interoperability import as_jax from brainpy._src.math.ndarray import Array, _get_dtype from brainpy._src.math.op_register import XLACustomOp @@ -23,6 +23,48 @@ ] +def _general_checking(vector, clen, seed, shape, outdim_parallel, transpose, *weights): + if vector.ndim != 1: + raise ValueError('vector should be a 1D vector.') + if len(shape) != 2: + raise ValueError('shape should be a length-2 tuple.') + if seed.ndim != 1: + raise ValueError('seed must be a 1D scalar.') + if clen.ndim != 1: + raise ValueError('conn_prob must be a 1D scalar.') + + assert _get_dtype(clen) in [jnp.int16, jnp.int32, jnp.int64, jnp.uint16, jnp.uint32, jnp.uint64] + assert _get_dtype(seed) in [jnp.int16, jnp.int32, jnp.int64, jnp.uint16, jnp.uint32, jnp.uint64] + + for weight in weights: + if weight.ndim != 1: + raise ValueError('weight must be a 1D scalar.') + assert _get_dtype(weight) in [jnp.float16, jnp.float32, jnp.float64], '"weight" must be float valued.' + + if not isinstance(outdim_parallel, bool): + raise ValueError('outdim_parallel must be boolean value.') + if not isinstance(transpose, bool): + raise ValueError('transpose must be boolean value.') + + if transpose: + out_shape = (shape[1],) + if vector.shape[0] != shape[0]: + raise ValueError(f'Shape mismatch, vec {vector.shape} @ mat {shape}.') + shape = _reverse(shape) + else: + if vector.shape[0] != shape[1]: + raise ValueError(f'Shape mismatch, mat {shape} @ vec ({vector.shape[0]},).') + out_shape = (shape[0],) + + return shape, out_shape + + +def _non_event_checking(vector, clen, seed, shape, outdim_parallel, transpose, *weights): + assert _get_dtype(vector) in [jnp.float16, jnp.float32, jnp.float64] + return _general_checking(vector, clen, seed, shape, outdim_parallel, transpose, *weights) + + +@check_taichi_func def mv_prob_homo( vector: Union[Array, jax.Array], weight: float, @@ -85,6 +127,7 @@ def mv_prob_homo( outdim_parallel=outdim_parallel) +@check_taichi_func def mv_prob_uniform( vector: jax.Array, w_low: float, @@ -150,6 +193,7 @@ def mv_prob_uniform( outdim_parallel=outdim_parallel) +@check_taichi_func def mv_prob_normal( vector: jax.Array, w_mu: float, @@ -456,6 +500,150 @@ def mv_prob_normal_taichi( transpose=transpose, outdim_parallel=outdim_parallel)[0] +def raw_mv_prob_homo( + vector: jax.Array, + weight: jax.Array, # vector with size 1 + clen: jax.Array, # vector with size 1 + seed: jax.Array, # vector with size 1 + *, + shape: Tuple[int, int], + transpose: bool = False, + outdim_parallel: bool = True, +) -> jax.Array: + mat_shape, out_shape = _non_event_checking(vector, clen, seed, shape, outdim_parallel, transpose, weight) + + if outdim_parallel: + prim = _mv_prob_homo_outdim_parallel_p + else: + prim = _mv_prob_homo_p + + return prim(vector, + weight, + clen, + seed, + outs=[jax.ShapeDtypeStruct(shape=out_shape, dtype=vector.dtype)], + shape=mat_shape, + transpose=transpose, + outdim_parallel=outdim_parallel) + + +def raw_mv_prob_uniform( + vector: jax.Array, + w_low: jax.Array, + w_high: jax.Array, + conn_len: jax.Array, + seed: jax.Array, + *, + shape: Tuple[int, int], + transpose: bool = False, + outdim_parallel: bool = True, +) -> jax.Array: + mat_shape, out_shape = _non_event_checking(vector, conn_len, seed, shape, outdim_parallel, transpose, w_low, w_high) + + if outdim_parallel: + prim = _mv_prob_uniform_outdim_parallel_p + else: + prim = _mv_prob_uniform_p + + return prim(vector, + w_low, + w_high, + conn_len, + seed, + outs=[jax.ShapeDtypeStruct(shape=out_shape, dtype=vector.dtype)], + shape=mat_shape, + transpose=transpose, + outdim_parallel=outdim_parallel) + + +def raw_mv_prob_normal( + vector: jax.Array, + w_mu: jax.Array, + w_sigma: jax.Array, + conn_len: jax.Array, + seed: jax.Array, + *, + shape: Tuple[int, int], + transpose: bool = False, + outdim_parallel: bool = True, +) -> jax.Array: + mat_shape, out_shape = _non_event_checking(vector, conn_len, seed, shape, outdim_parallel, transpose, w_mu, w_sigma) + + if outdim_parallel: + prim = _mv_prob_normal_outdim_parallel_p + else: + prim = _mv_prob_normal_p + + return prim(vector, + w_mu, + w_sigma, + conn_len, + seed, + outs=[jax.ShapeDtypeStruct(shape=out_shape, dtype=vector.dtype)], + shape=mat_shape, + transpose=transpose, + outdim_parallel=outdim_parallel) + + +def _mv_prob_homo_transpose( + ct, vector, weight, clen, seed, *, outs, shape, transpose, outdim_parallel +): + shape = _reverse(shape) if transpose else shape + if ad.is_undefined_primal(vector): + if type(ct) is ad.Zero: + return ad.Zero(vector), weight, clen, seed + else: + dv = raw_mv_prob_homo(ct[0], weight, clen, seed, shape=shape, + transpose=not transpose, outdim_parallel=not outdim_parallel)[0] + return dv, weight, clen, seed + elif ad.is_undefined_primal(weight): + if type(ct) is ad.Zero: + return vector, ad.Zero(weight), clen, seed + else: + row = raw_mv_prob_homo(ct[0], jnp.ones(1, dtype=ct[0].dtype), clen, seed, + shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)[0] + dw = jnp.sum(row * vector, keepdims=True) + return vector, dw, clen, seed + else: + assert type(clen) is not ad.UndefinedPrimal, 'Cannot differentiate through clen.' + assert type(seed) is not ad.UndefinedPrimal, 'Cannot differentiate through seed.' + + +def _mv_prob_uniform_transpose( + ct, vector, w_low, w_high, clen, seed, *, outs, shape, transpose, outdim_parallel +): + shape = _reverse(shape) if transpose else shape + if ad.is_undefined_primal(vector): + if type(ct) is ad.Zero: + return ad.Zero(vector), w_low, w_high, clen, seed + else: + dv = raw_mv_prob_uniform(ct[0], w_low, w_high, clen, seed, shape=shape, + transpose=not transpose, outdim_parallel=not outdim_parallel)[0] + return dv, w_low, w_high, clen, seed + else: + assert type(w_low) is not ad.UndefinedPrimal, 'Cannot differentiate through w_low.' + assert type(w_high) is not ad.UndefinedPrimal, 'Cannot differentiate through w_high.' + assert type(clen) is not ad.UndefinedPrimal, 'Cannot differentiate through clen.' + assert type(seed) is not ad.UndefinedPrimal, 'Cannot differentiate through seed.' + + +def _mv_prob_normal_transpose( + ct, vector, w_mu, w_sigma, clen, seed, *, outs, shape, transpose, outdim_parallel +): + shape = _reverse(shape) if transpose else shape + if ad.is_undefined_primal(vector): + if type(ct) is ad.Zero: + return ad.Zero(vector), w_mu, w_sigma, clen, seed + else: + dv = raw_mv_prob_normal(ct[0], w_mu, w_sigma, clen, seed, shape=shape, + transpose=not transpose, outdim_parallel=not outdim_parallel)[0] + return dv, w_mu, w_sigma, clen, seed + else: + assert type(w_mu) is not ad.UndefinedPrimal, 'Cannot differentiate through w_mu.' + assert type(w_sigma) is not ad.UndefinedPrimal, 'Cannot differentiate through w_sigma.' + assert type(clen) is not ad.UndefinedPrimal, 'Cannot differentiate through clen.' + assert type(seed) is not ad.UndefinedPrimal, 'Cannot differentiate through seed.' + def _reverse(shape): return shape[::-1] @@ -463,6 +651,7 @@ def _reverse(shape): if ti is not None: from brainpy._src.math.tifunc import (lfsr88_key, lfsr88_random_integers, lfsr88_uniform, lfsr88_normal) + @ti.kernel def _mv_prob_homo_cpu( vector: ti.types.ndarray(ndim=1), @@ -575,105 +764,18 @@ def _mv_prob_homo_outdim_parallel_gpu( def _mv_prob_homo_jvp_vector(v_dot, vector, weight, clen, seed, *, outs, shape, transpose, outdim_parallel): shape = _reverse(shape) if transpose else shape - return raw_mv_prob_homo(v_dot, weight, clen, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) + return raw_mv_prob_homo(v_dot, weight, clen, seed, shape=shape, transpose=transpose, + outdim_parallel=outdim_parallel) def _mv_prob_homo_jvp_weight(w_dot, vector, weight, clen, seed, *, outs, shape, transpose, outdim_parallel): shape = _reverse(shape) if transpose else shape - return raw_mv_prob_homo(vector, w_dot, clen, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) + return raw_mv_prob_homo(vector, w_dot, clen, seed, shape=shape, transpose=transpose, + outdim_parallel=outdim_parallel) - def _mv_prob_homo_transpose( - ct, vector, weight, clen, seed, *, outs, shape, transpose, outdim_parallel - ): - shape = _reverse(shape) if transpose else shape - if ad.is_undefined_primal(vector): - if type(ct) is ad.Zero: - return ad.Zero(vector), weight, clen, seed - else: - dv = raw_mv_prob_homo(ct[0], weight, clen, seed, shape=shape, - transpose=not transpose, outdim_parallel=not outdim_parallel)[0] - return dv, weight, clen, seed - elif ad.is_undefined_primal(weight): - if type(ct) is ad.Zero: - return vector, ad.Zero(weight), clen, seed - else: - row = raw_mv_prob_homo(ct[0], jnp.ones(1, dtype=ct[0].dtype), clen, seed, - shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)[0] - dw = jnp.sum(row * vector, keepdims=True) - return vector, dw, clen, seed - else: - assert type(clen) is not ad.UndefinedPrimal, 'Cannot differentiate through clen.' - assert type(seed) is not ad.UndefinedPrimal, 'Cannot differentiate through seed.' - - - def _general_checking(vector, clen, seed, shape, outdim_parallel, transpose, *weights): - if vector.ndim != 1: - raise ValueError('vector should be a 1D vector.') - if len(shape) != 2: - raise ValueError('shape should be a length-2 tuple.') - if seed.ndim != 1: - raise ValueError('seed must be a 1D scalar.') - if clen.ndim != 1: - raise ValueError('conn_prob must be a 1D scalar.') - - assert _get_dtype(clen) in [jnp.int16, jnp.int32, jnp.int64, jnp.uint16, jnp.uint32, jnp.uint64] - assert _get_dtype(seed) in [jnp.int16, jnp.int32, jnp.int64, jnp.uint16, jnp.uint32, jnp.uint64] - - for weight in weights: - if weight.ndim != 1: - raise ValueError('weight must be a 1D scalar.') - assert _get_dtype(weight) in [jnp.float16, jnp.float32, jnp.float64], '"weight" must be float valued.' - - if not isinstance(outdim_parallel, bool): - raise ValueError('outdim_parallel must be boolean value.') - if not isinstance(transpose, bool): - raise ValueError('transpose must be boolean value.') - - if transpose: - out_shape = (shape[1],) - if vector.shape[0] != shape[0]: - raise ValueError(f'Shape mismatch, vec {vector.shape} @ mat {shape}.') - shape = _reverse(shape) - else: - if vector.shape[0] != shape[1]: - raise ValueError(f'Shape mismatch, mat {shape} @ vec ({vector.shape[0]},).') - out_shape = (shape[0],) - - return shape, out_shape - def _non_event_checking(vector, clen, seed, shape, outdim_parallel, transpose, *weights): - assert _get_dtype(vector) in [jnp.float16, jnp.float32, jnp.float64] - return _general_checking(vector, clen, seed, shape, outdim_parallel, transpose, *weights) - - - def raw_mv_prob_homo( - vector: jax.Array, - weight: jax.Array, # vector with size 1 - clen: jax.Array, # vector with size 1 - seed: jax.Array, # vector with size 1 - *, - shape: Tuple[int, int], - transpose: bool = False, - outdim_parallel: bool = True, - ) -> jax.Array: - mat_shape, out_shape = _non_event_checking(vector, clen, seed, shape, outdim_parallel, transpose, weight) - - if outdim_parallel: - prim = _mv_prob_homo_outdim_parallel_p - else: - prim = _mv_prob_homo_p - - return prim(vector, - weight, - clen, - seed, - outs=[jax.ShapeDtypeStruct(shape=out_shape, dtype=vector.dtype)], - shape=mat_shape, - transpose=transpose, - outdim_parallel=outdim_parallel) - def _define_mv_prob_homo_prim(cpu_kernel, gpu_kernel): prim = XLACustomOp(cpu_kernel=cpu_kernel, gpu_kernel=gpu_kernel) @@ -834,51 +936,7 @@ def _mv_prob_uniform_jvp_whigh(w_dot, vector, w_low, w_high, clen, seed, *, transpose=transpose, outdim_parallel=outdim_parallel) - def _mv_prob_uniform_transpose( - ct, vector, w_low, w_high, clen, seed, *, outs, shape, transpose, outdim_parallel - ): - shape = _reverse(shape) if transpose else shape - if ad.is_undefined_primal(vector): - if type(ct) is ad.Zero: - return ad.Zero(vector), w_low, w_high, clen, seed - else: - dv = raw_mv_prob_uniform(ct[0], w_low, w_high, clen, seed, shape=shape, - transpose=not transpose, outdim_parallel=not outdim_parallel)[0] - return dv, w_low, w_high, clen, seed - else: - assert type(w_low) is not ad.UndefinedPrimal, 'Cannot differentiate through w_low.' - assert type(w_high) is not ad.UndefinedPrimal, 'Cannot differentiate through w_high.' - assert type(clen) is not ad.UndefinedPrimal, 'Cannot differentiate through clen.' - assert type(seed) is not ad.UndefinedPrimal, 'Cannot differentiate through seed.' - - - def raw_mv_prob_uniform( - vector: jax.Array, - w_low: jax.Array, - w_high: jax.Array, - conn_len: jax.Array, - seed: jax.Array, - *, - shape: Tuple[int, int], - transpose: bool = False, - outdim_parallel: bool = True, - ) -> jax.Array: - mat_shape, out_shape = _non_event_checking(vector, conn_len, seed, shape, outdim_parallel, transpose, w_low, w_high) - - if outdim_parallel: - prim = _mv_prob_uniform_outdim_parallel_p - else: - prim = _mv_prob_uniform_p - return prim(vector, - w_low, - w_high, - conn_len, - seed, - outs=[jax.ShapeDtypeStruct(shape=out_shape, dtype=vector.dtype)], - shape=mat_shape, - transpose=transpose, - outdim_parallel=outdim_parallel) def _define_mv_prob_uniform_prim(cpu_kernel, gpu_kernel): @@ -1045,51 +1103,7 @@ def _mv_prob_normal_jvp_w_sigma(w_dot, vector, w_mu, w_sigma, clen, seed, *, out transpose=transpose, outdim_parallel=outdim_parallel) - def _mv_prob_normal_transpose( - ct, vector, w_mu, w_sigma, clen, seed, *, outs, shape, transpose, outdim_parallel - ): - shape = _reverse(shape) if transpose else shape - if ad.is_undefined_primal(vector): - if type(ct) is ad.Zero: - return ad.Zero(vector), w_mu, w_sigma, clen, seed - else: - dv = raw_mv_prob_normal(ct[0], w_mu, w_sigma, clen, seed, shape=shape, - transpose=not transpose, outdim_parallel=not outdim_parallel)[0] - return dv, w_mu, w_sigma, clen, seed - else: - assert type(w_mu) is not ad.UndefinedPrimal, 'Cannot differentiate through w_mu.' - assert type(w_sigma) is not ad.UndefinedPrimal, 'Cannot differentiate through w_sigma.' - assert type(clen) is not ad.UndefinedPrimal, 'Cannot differentiate through clen.' - assert type(seed) is not ad.UndefinedPrimal, 'Cannot differentiate through seed.' - - - def raw_mv_prob_normal( - vector: jax.Array, - w_mu: jax.Array, - w_sigma: jax.Array, - conn_len: jax.Array, - seed: jax.Array, - *, - shape: Tuple[int, int], - transpose: bool = False, - outdim_parallel: bool = True, - ) -> jax.Array: - mat_shape, out_shape = _non_event_checking(vector, conn_len, seed, shape, outdim_parallel, transpose, w_mu, w_sigma) - - if outdim_parallel: - prim = _mv_prob_normal_outdim_parallel_p - else: - prim = _mv_prob_normal_p - - return prim(vector, - w_mu, - w_sigma, - conn_len, - seed, - outs=[jax.ShapeDtypeStruct(shape=out_shape, dtype=vector.dtype)], - shape=mat_shape, - transpose=transpose, - outdim_parallel=outdim_parallel) + def _define_mv_prob_normal_prim(cpu_kernel, gpu_kernel): diff --git a/brainpy/_src/math/jitconn/tests/test_event_matvec.py b/brainpy/_src/math/jitconn/tests/test_event_matvec.py index 034885ae9..d8e086540 100644 --- a/brainpy/_src/math/jitconn/tests/test_event_matvec.py +++ b/brainpy/_src/math/jitconn/tests/test_event_matvec.py @@ -9,7 +9,7 @@ import brainpy.math as bm from brainpy._src.dependency_check import import_taichi -if import_taichi() is None: +if import_taichi(error_if_not_found=False) is None: pytest.skip('no taichi', allow_module_level=True) diff --git a/brainpy/_src/math/jitconn/tests/test_matvec.py b/brainpy/_src/math/jitconn/tests/test_matvec.py index caee4efbe..8a0ae444d 100644 --- a/brainpy/_src/math/jitconn/tests/test_matvec.py +++ b/brainpy/_src/math/jitconn/tests/test_matvec.py @@ -9,7 +9,7 @@ import brainpy.math as bm from brainpy._src.dependency_check import import_taichi -if import_taichi() is None: +if import_taichi(error_if_not_found=False) is None: pytest.skip('no taichi', allow_module_level=True) shapes = [(100, 200), (10, 1000), (2, 1000), (1000, 10), (1000, 2)] diff --git a/brainpy/_src/math/op_register/base.py b/brainpy/_src/math/op_register/base.py index 1824ac911..ead0cf00e 100644 --- a/brainpy/_src/math/op_register/base.py +++ b/brainpy/_src/math/op_register/base.py @@ -4,8 +4,9 @@ import jax import numpy as np from jax.interpreters import xla, batching, ad, mlir -from numba.core.dispatcher import Dispatcher + +from brainpy._src.dependency_check import import_numba, check_numba_class, check_taichi_class from brainpy._src.math.ndarray import Array from brainpy._src.math.object_transform.base import BrainPyObject @@ -20,6 +21,10 @@ from .utils import register_general_batching from brainpy._src.math.op_register.ad_support import defjvp +numba = import_numba(error_if_not_found=False) +if numba is not None: + from numba.core.dispatcher import Dispatcher + __all__ = [ 'XLACustomOp', ] @@ -35,7 +40,8 @@ def shape(self) -> Tuple[int, ...]: def dtype(self) -> np.dtype: ... - +@check_numba_class +@check_taichi_class class XLACustomOp(BrainPyObject): """Creating a XLA custom call operator. diff --git a/brainpy/_src/math/op_register/numba_approach/__init__.py b/brainpy/_src/math/op_register/numba_approach/__init__.py index 13d4f66e7..2af5637b4 100644 --- a/brainpy/_src/math/op_register/numba_approach/__init__.py +++ b/brainpy/_src/math/op_register/numba_approach/__init__.py @@ -8,198 +8,197 @@ from jax.interpreters import xla, batching, ad from jax.tree_util import tree_map -from brainpy._src.dependency_check import import_numba_else_None +from brainpy._src.dependency_check import import_numba, check_numba_func, check_numba_class from brainpy._src.math.ndarray import Array from brainpy._src.math.object_transform.base import BrainPyObject +numba = import_numba(error_if_not_found=False) -numba = import_numba_else_None() +from .cpu_translation import _cpu_translation, compile_cpu_signature_with_numba if numba is not None: - from .cpu_translation import _cpu_translation, compile_cpu_signature_with_numba from numba.core.dispatcher import Dispatcher - __all__ = [ - 'CustomOpByNumba', - 'register_op_with_numba', - 'compile_cpu_signature_with_numba', - ] - - - class CustomOpByNumba(BrainPyObject): - """Creating a XLA custom call operator with Numba JIT on CPU backend. - - Parameters - ---------- - name: str - The name of operator. - eval_shape: callable - The function to evaluate the shape and dtype of the output according to the input. - This function should receive the abstract information of inputs, and return the - abstract information of the outputs. For example: - - >>> def eval_shape(inp1_info, inp2_info, inp3_info, ...): - >>> return out1_info, out2_info - con_compute: callable - The function to make the concrete computation. This function receives inputs, - and returns outputs. For example: - - >>> def con_compute(inp1, inp2, inp3, ..., out1, out2, ...): - >>> pass - """ - - def __init__( - self, - eval_shape: Callable = None, - con_compute: Callable = None, - name: str = None, - batching_translation: Callable = None, - jvp_translation: Callable = None, - transpose_translation: Callable = None, - multiple_results: bool = True, - ): - super().__init__(name=name) - - # abstract evaluation function - if eval_shape is None: - raise ValueError('Must provide "eval_shape" for abstract evaluation.') - - # cpu function - cpu_func = con_compute - - # register OP - self.op = register_op_with_numba( - self.name, - cpu_func=cpu_func, - out_shapes=eval_shape, - batching_translation=batching_translation, - jvp_translation=jvp_translation, - transpose_translation=transpose_translation, - multiple_results=multiple_results, - ) - - def __call__(self, *args, **kwargs): - args = tree_map(lambda a: a.value if isinstance(a, Array) else a, - args, is_leaf=lambda a: isinstance(a, Array)) - kwargs = tree_map(lambda a: a.value if isinstance(a, Array) else a, - kwargs, is_leaf=lambda a: isinstance(a, Array)) - res = self.op.bind(*args, **kwargs) - return res - - - def register_op_with_numba( - op_name: str, - cpu_func: Callable, - out_shapes: Union[Callable, jax.core.ShapedArray, Sequence[jax.core.ShapedArray]], - gpu_func_translation: Callable = None, +__all__ = [ + 'CustomOpByNumba', + 'register_op_with_numba', + 'compile_cpu_signature_with_numba', +] + + +@check_numba_class +class CustomOpByNumba(BrainPyObject): + """Creating a XLA custom call operator with Numba JIT on CPU backend. + + Parameters + ---------- + name: str + The name of operator. + eval_shape: callable + The function to evaluate the shape and dtype of the output according to the input. + This function should receive the abstract information of inputs, and return the + abstract information of the outputs. For example: + + >>> def eval_shape(inp1_info, inp2_info, inp3_info, ...): + >>> return out1_info, out2_info + con_compute: callable + The function to make the concrete computation. This function receives inputs, + and returns outputs. For example: + + >>> def con_compute(inp1, inp2, inp3, ..., out1, out2, ...): + >>> pass + """ + + def __init__( + self, + eval_shape: Callable = None, + con_compute: Callable = None, + name: str = None, batching_translation: Callable = None, jvp_translation: Callable = None, transpose_translation: Callable = None, - multiple_results: bool = False, + multiple_results: bool = True, ): - """ - Converting the numba-jitted function in a Jax/XLA compatible primitive. - - Parameters - ---------- - op_name: str - Name of the operators. - - cpu_func: Callable - A callable numba-jitted function or pure function (can be lambda function) running on CPU. - - out_shapes: Callable, ShapedArray, Sequence[ShapedArray], default = None - Outputs shapes of target function. `out_shapes` can be a `ShapedArray` or - a sequence of `ShapedArray`. If it is a function, it takes as input the argument - shapes and dtypes and should return correct output shapes of `ShapedArray`. - - gpu_func_translation: Callable - A callable cuda-jitted kernel running on GPU. - - batching_translation: Callable - The batching translation for the primitive. - - jvp_translation: Callable - The forward autodiff translation rule. - - transpose_translation: Callable - The backward autodiff translation rule. - - multiple_results: bool - Whether the primitive returns multiple results. Default is False. - - Returns - ------- - op: core.Primitive - A JAX Primitive object. - """ - - if jax.__version__ > '0.4.23': - raise RuntimeError(f'{CustomOpByNumba.__name__} and {register_op_with_numba.__name__} are ' - f'only supported in JAX version <= 0.4.23. \n' - f'However, you can use brainpy.math.XLACustomOp to create a custom op with numba syntax. ' - f'For more information, please refer to the documentation: ' - f'https://brainpy.readthedocs.io/en/latest/tutorial_advanced/operator_custom_with_taichi.html.') - - if out_shapes is None: - raise RuntimeError('out_shapes cannot be None. It can be a `ShapedArray` or ' - 'a sequence of `ShapedArray`. If it is a function, it takes as input the argument ' - 'shapes and dtypes and should return correct output shapes of `ShapedArray`.') - - prim = jax.core.Primitive(op_name) - prim.multiple_results = multiple_results - - # user defined function - if not isinstance(cpu_func, Dispatcher): - cpu_func = numba.jit(fastmath=True, nopython=True)(cpu_func) - - # output shape evaluation function - def abs_eval_rule(*input_shapes, **info): - if callable(out_shapes): - shapes = out_shapes(*input_shapes, **info) - else: - shapes = out_shapes - - if isinstance(shapes, jax.core.ShapedArray): - assert not multiple_results, "multiple_results is True, while the abstract evaluation returns only one data." - elif isinstance(shapes, (tuple, list)): - assert multiple_results, "multiple_results is False, while the abstract evaluation returns multiple data." - for elem in shapes: - if not isinstance(elem, jax.core.ShapedArray): - raise ValueError(f'Elements in "out_shapes" must be instances of ' - f'jax.abstract_arrays.ShapedArray, but we got ' - f'{type(elem)}: {elem}') - else: - raise ValueError(f'Unknown type {type(shapes)}, only ' - f'supports function, ShapedArray or ' - f'list/tuple of ShapedArray.') - return shapes + super().__init__(name=name) - # cpu function - prim.def_abstract_eval(abs_eval_rule) - prim.def_impl(partial(xla.apply_primitive, prim)) - xla.backend_specific_translations['cpu'][prim] = partial(_cpu_translation, - cpu_func, - abs_eval_rule, - multiple_results) - - # gpu function - if gpu_func_translation is not None: - xla.backend_specific_translations['gpu'][prim] = gpu_func_translation - - # batching - if batching_translation is not None: - batching.primitive_batchers[prim] = batching_translation - - # jvp - if jvp_translation is not None: - ad.primitive_jvps[prim] = jvp_translation + # abstract evaluation function + if eval_shape is None: + raise ValueError('Must provide "eval_shape" for abstract evaluation.') - # transpose - if transpose_translation is not None: - ad.primitive_transposes[prim] = transpose_translation - - return prim - -else: - __all__ = [] + # cpu function + cpu_func = con_compute + + # register OP + self.op = register_op_with_numba( + self.name, + cpu_func=cpu_func, + out_shapes=eval_shape, + batching_translation=batching_translation, + jvp_translation=jvp_translation, + transpose_translation=transpose_translation, + multiple_results=multiple_results, + ) + + def __call__(self, *args, **kwargs): + args = tree_map(lambda a: a.value if isinstance(a, Array) else a, + args, is_leaf=lambda a: isinstance(a, Array)) + kwargs = tree_map(lambda a: a.value if isinstance(a, Array) else a, + kwargs, is_leaf=lambda a: isinstance(a, Array)) + res = self.op.bind(*args, **kwargs) + return res + + +@check_numba_func +def register_op_with_numba( + op_name: str, + cpu_func: Callable, + out_shapes: Union[Callable, jax.core.ShapedArray, Sequence[jax.core.ShapedArray]], + gpu_func_translation: Callable = None, + batching_translation: Callable = None, + jvp_translation: Callable = None, + transpose_translation: Callable = None, + multiple_results: bool = False, +): + """ + Converting the numba-jitted function in a Jax/XLA compatible primitive. + + Parameters + ---------- + op_name: str + Name of the operators. + + cpu_func: Callable + A callable numba-jitted function or pure function (can be lambda function) running on CPU. + + out_shapes: Callable, ShapedArray, Sequence[ShapedArray], default = None + Outputs shapes of target function. `out_shapes` can be a `ShapedArray` or + a sequence of `ShapedArray`. If it is a function, it takes as input the argument + shapes and dtypes and should return correct output shapes of `ShapedArray`. + + gpu_func_translation: Callable + A callable cuda-jitted kernel running on GPU. + + batching_translation: Callable + The batching translation for the primitive. + + jvp_translation: Callable + The forward autodiff translation rule. + + transpose_translation: Callable + The backward autodiff translation rule. + + multiple_results: bool + Whether the primitive returns multiple results. Default is False. + + Returns + ------- + op: core.Primitive + A JAX Primitive object. + """ + + if jax.__version__ > '0.4.23': + raise RuntimeError(f'{CustomOpByNumba.__name__} and {register_op_with_numba.__name__} are ' + f'only supported in JAX version <= 0.4.23. \n' + f'However, you can use brainpy.math.XLACustomOp to create a custom op with numba syntax. ' + f'For more information, please refer to the documentation: ' + f'https://brainpy.readthedocs.io/en/latest/tutorial_advanced/operator_custom_with_taichi.html.') + + if out_shapes is None: + raise RuntimeError('out_shapes cannot be None. It can be a `ShapedArray` or ' + 'a sequence of `ShapedArray`. If it is a function, it takes as input the argument ' + 'shapes and dtypes and should return correct output shapes of `ShapedArray`.') + + prim = jax.core.Primitive(op_name) + prim.multiple_results = multiple_results + + # user defined function + if not isinstance(cpu_func, Dispatcher): + cpu_func = numba.jit(fastmath=True, nopython=True)(cpu_func) + + # output shape evaluation function + def abs_eval_rule(*input_shapes, **info): + if callable(out_shapes): + shapes = out_shapes(*input_shapes, **info) + else: + shapes = out_shapes + + if isinstance(shapes, jax.core.ShapedArray): + assert not multiple_results, "multiple_results is True, while the abstract evaluation returns only one data." + elif isinstance(shapes, (tuple, list)): + assert multiple_results, "multiple_results is False, while the abstract evaluation returns multiple data." + for elem in shapes: + if not isinstance(elem, jax.core.ShapedArray): + raise ValueError(f'Elements in "out_shapes" must be instances of ' + f'jax.abstract_arrays.ShapedArray, but we got ' + f'{type(elem)}: {elem}') + else: + raise ValueError(f'Unknown type {type(shapes)}, only ' + f'supports function, ShapedArray or ' + f'list/tuple of ShapedArray.') + return shapes + + # cpu function + prim.def_abstract_eval(abs_eval_rule) + prim.def_impl(partial(xla.apply_primitive, prim)) + xla.backend_specific_translations['cpu'][prim] = partial(_cpu_translation, + cpu_func, + abs_eval_rule, + multiple_results) + + # gpu function + if gpu_func_translation is not None: + xla.backend_specific_translations['gpu'][prim] = gpu_func_translation + + # batching + if batching_translation is not None: + batching.primitive_batchers[prim] = batching_translation + + # jvp + if jvp_translation is not None: + ad.primitive_jvps[prim] = jvp_translation + + # transpose + if transpose_translation is not None: + ad.primitive_transposes[prim] = transpose_translation + + return prim diff --git a/brainpy/_src/math/op_register/numba_approach/cpu_translation.py b/brainpy/_src/math/op_register/numba_approach/cpu_translation.py index 02f74a237..759ecc50c 100644 --- a/brainpy/_src/math/op_register/numba_approach/cpu_translation.py +++ b/brainpy/_src/math/op_register/numba_approach/cpu_translation.py @@ -6,150 +6,148 @@ from jax.core import ShapedArray from jax.lib import xla_client -from brainpy._src.dependency_check import import_numba_else_None +from brainpy._src.dependency_check import import_numba, check_numba_func -numba = import_numba_else_None() +numba = import_numba(error_if_not_found=False) + +__all__ = [ + '_cpu_translation', + 'compile_cpu_signature_with_numba', +] if numba is not None: from numba import types, carray, cfunc - __all__ = [ - '_cpu_translation', - 'compile_cpu_signature_with_numba', - ] - - ctypes.pythonapi.PyCapsule_New.argtypes = [ - ctypes.c_void_p, # void* pointer - ctypes.c_char_p, # const char *name - ctypes.c_void_p, # PyCapsule_Destructor destructor - ] - ctypes.pythonapi.PyCapsule_New.restype = ctypes.py_object - - - def _cpu_translation(func, abs_eval_fn, multiple_results, c, *inputs, **info): - target_name, inputs, input_shapes, xla_output_shapes = \ - compile_cpu_signature_with_numba(c, func, abs_eval_fn, multiple_results, inputs, info) - return xla_client.ops.CustomCallWithLayout( - c, - target_name, - operands=inputs, - operand_shapes_with_layout=input_shapes, - shape_with_layout=xla_output_shapes, - ) - - - def _cpu_signature( - func, - input_dtypes, - input_shapes, - output_dtypes, - output_shapes, - multiple_results: bool, - debug: bool = False - ): - code_scope = dict( - func_to_call=func, - input_shapes=input_shapes, - input_dtypes=input_dtypes, - output_shapes=output_shapes, - output_dtypes=output_dtypes, - carray=carray, - ) - - # inputs - if len(input_shapes) > 1: - args_in = [ - f'carray(input_ptrs[{i}], input_shapes[{i}], dtype=input_dtypes[{i}]),' - for i in range(len(input_shapes)) - ] - args_in = '(\n ' + "\n ".join(args_in) + '\n )' - else: - args_in = 'carray(input_ptrs[0], input_shapes[0], dtype=input_dtypes[0])' - - # outputs - if multiple_results: - args_out = [ - f'carray(output_ptrs[{i}], output_shapes[{i}], dtype=output_dtypes[{i}]),' - for i in range(len(output_shapes)) - ] - args_out = '(\n ' + "\n ".join(args_out) + '\n )' - else: - args_out = 'carray(output_ptrs, output_shapes[0], dtype=output_dtypes[0])' - - # function body - code_string = ''' - def xla_cpu_custom_call_target(output_ptrs, input_ptrs): - args_out = {args_out} - args_in = {args_in} - func_to_call(args_out, args_in) - '''.format(args_in=args_in, - args_out=args_out) - if debug: print(code_string) - exec(compile(code_string.strip(), '', 'exec'), code_scope) - - new_f = code_scope['xla_cpu_custom_call_target'] - if multiple_results: - xla_c_rule = cfunc(types.void(types.CPointer(types.voidptr), - types.CPointer(types.voidptr)))(new_f) - else: - xla_c_rule = cfunc(types.void(types.voidptr, types.CPointer(types.voidptr)))(new_f) - target_name = xla_c_rule.native_name.encode("ascii") - capsule = ctypes.pythonapi.PyCapsule_New( - xla_c_rule.address, # A CFFI pointer to a function - b"xla._CUSTOM_CALL_TARGET", # A binary string - None # PyCapsule object run at destruction - ) - xla_client.register_custom_call_target(target_name, capsule, "cpu") - return target_name - - - def compile_cpu_signature_with_numba( - c, - func, - abs_eval_fn, - multiple_results, - inputs: tuple, - description: dict = None, - ): - input_layouts = [c.get_shape(arg) for arg in inputs] - info_inputs = [] - if description is None: description = dict() - for v in description.values(): - if isinstance(v, (int, float)): - input_layouts.append(xla_client.Shape.array_shape(dtypes.canonicalize_dtype(type(v)), (), ())) - info_inputs.append(xla_client.ops.ConstantLiteral(c, v)) - elif isinstance(v, (tuple, list)): - v = jnp.asarray(v) - input_layouts.append(xla_client.Shape.array_shape(v.dtype, v.shape, tuple(range(len(v.shape) - 1, -1, -1)))) - info_inputs.append(xla_client.ops.Constant(c, v)) - else: - raise TypeError - input_layouts = tuple(input_layouts) - input_dtypes = tuple(shape.element_type() for shape in input_layouts) - input_dimensions = tuple(shape.dimensions() for shape in input_layouts) - output_abstract_arrays = abs_eval_fn(*tuple(ShapedArray(shape.dimensions(), shape.element_type()) - for shape in input_layouts[:len(inputs)]), - **description) - if isinstance(output_abstract_arrays, ShapedArray): - output_abstract_arrays = (output_abstract_arrays,) - assert not multiple_results +ctypes.pythonapi.PyCapsule_New.argtypes = [ + ctypes.c_void_p, # void* pointer + ctypes.c_char_p, # const char *name + ctypes.c_void_p, # PyCapsule_Destructor destructor +] +ctypes.pythonapi.PyCapsule_New.restype = ctypes.py_object + +@check_numba_func +def _cpu_translation(func, abs_eval_fn, multiple_results, c, *inputs, **info): + target_name, inputs, input_shapes, xla_output_shapes = \ + compile_cpu_signature_with_numba(c, func, abs_eval_fn, multiple_results, inputs, info) + return xla_client.ops.CustomCallWithLayout( + c, + target_name, + operands=inputs, + operand_shapes_with_layout=input_shapes, + shape_with_layout=xla_output_shapes, + ) + + +def _cpu_signature( + func, + input_dtypes, + input_shapes, + output_dtypes, + output_shapes, + multiple_results: bool, + debug: bool = False +): + code_scope = dict( + func_to_call=func, + input_shapes=input_shapes, + input_dtypes=input_dtypes, + output_shapes=output_shapes, + output_dtypes=output_dtypes, + carray=carray, + ) + + # inputs + if len(input_shapes) > 1: + args_in = [ + f'carray(input_ptrs[{i}], input_shapes[{i}], dtype=input_dtypes[{i}]),' + for i in range(len(input_shapes)) + ] + args_in = '(\n ' + "\n ".join(args_in) + '\n )' + else: + args_in = 'carray(input_ptrs[0], input_shapes[0], dtype=input_dtypes[0])' + + # outputs + if multiple_results: + args_out = [ + f'carray(output_ptrs[{i}], output_shapes[{i}], dtype=output_dtypes[{i}]),' + for i in range(len(output_shapes)) + ] + args_out = '(\n ' + "\n ".join(args_out) + '\n )' + else: + args_out = 'carray(output_ptrs, output_shapes[0], dtype=output_dtypes[0])' + + # function body + code_string = ''' +def xla_cpu_custom_call_target(output_ptrs, input_ptrs): + args_out = {args_out} + args_in = {args_in} + func_to_call(args_out, args_in) + '''.format(args_in=args_in, + args_out=args_out) + if debug: print(code_string) + exec(compile(code_string.strip(), '', 'exec'), code_scope) + + new_f = code_scope['xla_cpu_custom_call_target'] + if multiple_results: + xla_c_rule = cfunc(types.void(types.CPointer(types.voidptr), + types.CPointer(types.voidptr)))(new_f) + else: + xla_c_rule = cfunc(types.void(types.voidptr, types.CPointer(types.voidptr)))(new_f) + target_name = xla_c_rule.native_name.encode("ascii") + capsule = ctypes.pythonapi.PyCapsule_New( + xla_c_rule.address, # A CFFI pointer to a function + b"xla._CUSTOM_CALL_TARGET", # A binary string + None # PyCapsule object run at destruction + ) + xla_client.register_custom_call_target(target_name, capsule, "cpu") + return target_name + +@check_numba_func +def compile_cpu_signature_with_numba( + c, + func, + abs_eval_fn, + multiple_results, + inputs: tuple, + description: dict = None, +): + input_layouts = [c.get_shape(arg) for arg in inputs] + info_inputs = [] + if description is None: description = dict() + for v in description.values(): + if isinstance(v, (int, float)): + input_layouts.append(xla_client.Shape.array_shape(dtypes.canonicalize_dtype(type(v)), (), ())) + info_inputs.append(xla_client.ops.ConstantLiteral(c, v)) + elif isinstance(v, (tuple, list)): + v = jnp.asarray(v) + input_layouts.append(xla_client.Shape.array_shape(v.dtype, v.shape, tuple(range(len(v.shape) - 1, -1, -1)))) + info_inputs.append(xla_client.ops.Constant(c, v)) else: - assert multiple_results - output_shapes = tuple(array.shape for array in output_abstract_arrays) - output_dtypes = tuple(array.dtype for array in output_abstract_arrays) - output_layouts = map(lambda shape: range(len(shape) - 1, -1, -1), output_shapes) - target_name = _cpu_signature(func, - input_dtypes, - input_dimensions, - output_dtypes, - output_shapes, - multiple_results, - debug=False) - output_layouts = [xla_client.Shape.array_shape(*arg) - for arg in zip(output_dtypes, output_shapes, output_layouts)] - output_layouts = (xla_client.Shape.tuple_shape(output_layouts) - if multiple_results else - output_layouts[0]) - return target_name, tuple(inputs) + tuple(info_inputs), input_layouts, output_layouts -else: - __all__ = [] + raise TypeError + input_layouts = tuple(input_layouts) + input_dtypes = tuple(shape.element_type() for shape in input_layouts) + input_dimensions = tuple(shape.dimensions() for shape in input_layouts) + output_abstract_arrays = abs_eval_fn(*tuple(ShapedArray(shape.dimensions(), shape.element_type()) + for shape in input_layouts[:len(inputs)]), + **description) + if isinstance(output_abstract_arrays, ShapedArray): + output_abstract_arrays = (output_abstract_arrays,) + assert not multiple_results + else: + assert multiple_results + output_shapes = tuple(array.shape for array in output_abstract_arrays) + output_dtypes = tuple(array.dtype for array in output_abstract_arrays) + output_layouts = map(lambda shape: range(len(shape) - 1, -1, -1), output_shapes) + target_name = _cpu_signature(func, + input_dtypes, + input_dimensions, + output_dtypes, + output_shapes, + multiple_results, + debug=False) + output_layouts = [xla_client.Shape.array_shape(*arg) + for arg in zip(output_dtypes, output_shapes, output_layouts)] + output_layouts = (xla_client.Shape.tuple_shape(output_layouts) + if multiple_results else + output_layouts[0]) + return target_name, tuple(inputs) + tuple(info_inputs), input_layouts, output_layouts diff --git a/brainpy/_src/math/op_register/numba_based.py b/brainpy/_src/math/op_register/numba_based.py index fd7a289ed..8c56e52aa 100644 --- a/brainpy/_src/math/op_register/numba_based.py +++ b/brainpy/_src/math/op_register/numba_based.py @@ -6,16 +6,19 @@ from jax.interpreters import xla, mlir from jax.lib import xla_client from jaxlib.hlo_helpers import custom_call -from numba import types, carray, cfunc from .utils import _shape_to_layout +from brainpy._src.dependency_check import import_numba, check_numba_func + +numba = import_numba(error_if_not_found=False) +if numba is not None: + from numba import types, carray, cfunc __all__ = [ 'register_numba_xla_cpu_translation_rule', 'register_numba_mlir_cpu_translation_rule', ] - # [void* pointer, # const char *name, # PyCapsule_Destructor destructor] @@ -102,6 +105,7 @@ def _numba_xla_cpu_translation_rule(kernel, debug: bool, c, *ins, **kwargs): ) +@check_numba_func def register_numba_xla_cpu_translation_rule(primitive, cpu_kernel, debug=False): # do not support after jax >= 0.4.24 xla.backend_specific_translations['cpu'][primitive] = partial(_numba_xla_cpu_translation_rule, @@ -166,6 +170,7 @@ def numba_cpu_custom_call_target(output_ptrs, input_ptrs): ).results +@check_numba_func def register_numba_mlir_cpu_translation_rule(primitive, cpu_kernel, debug=False): rule = partial(_numba_mlir_cpu_translation_rule, cpu_kernel, debug) mlir.register_lowering(primitive, rule, platform='cpu') diff --git a/brainpy/_src/math/op_register/tests/test_ad_support.py b/brainpy/_src/math/op_register/tests/test_ad_support.py index 24f010a12..2c9f09724 100644 --- a/brainpy/_src/math/op_register/tests/test_ad_support.py +++ b/brainpy/_src/math/op_register/tests/test_ad_support.py @@ -1,13 +1,18 @@ +import pytest from typing import Tuple import jax -import numba from jax import core from jax import numpy as jnp from jax.interpreters import ad import brainpy as bp import brainpy.math as bm +from brainpy._src.dependency_check import import_numba + +numba = import_numba(error_if_not_found=False) +if numba is None: + pytest.skip('no numba', allow_module_level=True) bm.set_platform('cpu') diff --git a/brainpy/_src/math/op_register/tests/test_numba_based.py b/brainpy/_src/math/op_register/tests/test_numba_based.py index 968155ef9..dc093f624 100644 --- a/brainpy/_src/math/op_register/tests/test_numba_based.py +++ b/brainpy/_src/math/op_register/tests/test_numba_based.py @@ -1,6 +1,11 @@ +import pytest import jax.core import brainpy.math as bm -import numba + +from brainpy._src.dependency_check import import_numba +numba = import_numba(error_if_not_found=False) +if numba is None: + pytest.skip('no numba', allow_module_level=True) bm.set_platform('cpu') diff --git a/brainpy/_src/math/op_register/tests/test_taichi_based.py b/brainpy/_src/math/op_register/tests/test_taichi_based.py index 03023754c..4db38fbcb 100644 --- a/brainpy/_src/math/op_register/tests/test_taichi_based.py +++ b/brainpy/_src/math/op_register/tests/test_taichi_based.py @@ -1,9 +1,14 @@ +import pytest import jax import jax.numpy as jnp -import taichi as ti import brainpy.math as bm +from brainpy._src.dependency_check import import_taichi +ti = import_taichi(error_if_not_found=False) +if ti is None: + pytest.skip('no taichi', allow_module_level=True) + bm.set_platform('cpu') diff --git a/brainpy/_src/math/op_register/tests/test_taichi_clean_cache.py b/brainpy/_src/math/op_register/tests/test_taichi_clean_cache.py index 1bebcdafe..5f6b3a292 100644 --- a/brainpy/_src/math/op_register/tests/test_taichi_clean_cache.py +++ b/brainpy/_src/math/op_register/tests/test_taichi_clean_cache.py @@ -1,54 +1,58 @@ -import brainpy.math as bm -import jax -import jax.numpy as jnp -import platform -import pytest -import taichi - -if not platform.platform().startswith('Windows'): - pytest.skip(allow_module_level=True) - -@taichi.func -def get_weight(weight: taichi.types.ndarray(ndim=1)) -> taichi.f32: - return weight[0] - - -@taichi.func -def update_output(out: taichi.types.ndarray(ndim=1), index: taichi.i32, weight_val: taichi.f32): - out[index] += weight_val - -@taichi.kernel -def event_ell_cpu(indices: taichi.types.ndarray(ndim=2), - vector: taichi.types.ndarray(ndim=1), - weight: taichi.types.ndarray(ndim=1), - out: taichi.types.ndarray(ndim=1)): - weight_val = get_weight(weight) - num_rows, num_cols = indices.shape - taichi.loop_config(serialize=True) - for i in range(num_rows): - if vector[i]: - for j in range(num_cols): - update_output(out, indices[i, j], weight_val) - -prim = bm.XLACustomOp(cpu_kernel=event_ell_cpu) - -def test_taichi_clean_cache(): - s = 1000 - indices = bm.random.randint(0, s, (s, 1000)) - vector = bm.random.rand(s) < 0.1 - weight = bm.array([1.0]) - - out = prim(indices, vector, weight, outs=[jax.ShapeDtypeStruct((s,), dtype=jnp.float32)]) - - out = prim(indices, vector, weight, outs=[jax.ShapeDtypeStruct((s,), dtype=jnp.float32)]) - - print(out) - bm.clear_buffer_memory() - - print('kernels: ', bm.check_kernels_count()) - - bm.clean_caches() - - print('kernels: ', bm.check_kernels_count()) - +import brainpy.math as bm +import jax +import jax.numpy as jnp +import platform +import pytest + +from brainpy._src.dependency_check import import_taichi +ti = import_taichi(error_if_not_found=False) +if ti is None: + pytest.skip('no taichi', allow_module_level=True) + +if not platform.platform().startswith('Windows'): + pytest.skip(allow_module_level=True) + +@taichi.func +def get_weight(weight: taichi.types.ndarray(ndim=1)) -> taichi.f32: + return weight[0] + + +@taichi.func +def update_output(out: taichi.types.ndarray(ndim=1), index: taichi.i32, weight_val: taichi.f32): + out[index] += weight_val + +@taichi.kernel +def event_ell_cpu(indices: taichi.types.ndarray(ndim=2), + vector: taichi.types.ndarray(ndim=1), + weight: taichi.types.ndarray(ndim=1), + out: taichi.types.ndarray(ndim=1)): + weight_val = get_weight(weight) + num_rows, num_cols = indices.shape + taichi.loop_config(serialize=True) + for i in range(num_rows): + if vector[i]: + for j in range(num_cols): + update_output(out, indices[i, j], weight_val) + +prim = bm.XLACustomOp(cpu_kernel=event_ell_cpu) + +def test_taichi_clean_cache(): + s = 1000 + indices = bm.random.randint(0, s, (s, 1000)) + vector = bm.random.rand(s) < 0.1 + weight = bm.array([1.0]) + + out = prim(indices, vector, weight, outs=[jax.ShapeDtypeStruct((s,), dtype=jnp.float32)]) + + out = prim(indices, vector, weight, outs=[jax.ShapeDtypeStruct((s,), dtype=jnp.float32)]) + + print(out) + bm.clear_buffer_memory() + + print('kernels: ', bm.check_kernels_count()) + + bm.clean_caches() + + print('kernels: ', bm.check_kernels_count()) + # test_taichi_clean_cache() \ No newline at end of file diff --git a/brainpy/_src/math/sparse/_bsr_mm.py b/brainpy/_src/math/sparse/_bsr_mm.py index 6f9d5378c..43ccac6c8 100644 --- a/brainpy/_src/math/sparse/_bsr_mm.py +++ b/brainpy/_src/math/sparse/_bsr_mm.py @@ -4,19 +4,19 @@ from typing import Tuple import jax.lax -import numba import numpy as np from jax import numpy as jnp from jax.core import Primitive, ShapedArray from jax.interpreters import ad, xla from jax.lib import xla_client -from brainpy._src.dependency_check import import_brainpylib_gpu_ops +from brainpy._src.dependency_check import import_brainpylib_gpu_ops, import_numba, check_numba_func from brainpy._src.math.interoperability import as_jax from brainpy._src.math.op_register import (compile_cpu_signature_with_numba, register_general_batching) from brainpy.errors import GPUOperatorNotFound +numba = import_numba(error_if_not_found=False) __all__ = [ 'bcsrmm', ] @@ -216,6 +216,7 @@ def blocksparse_matmat_multiply(dense_a, raise Exception('Invalid device: ', device) +@check_numba_func def bcsrmm( A_data: jax.Array, B_data: jax.Array, @@ -264,52 +265,53 @@ def bcsrmm( raise ValueError -@numba.njit(fastmath=True, parallel=True, nogil=True) -def _bcsrmm_cutlass_imp_transpose(outs, ins): # dense(m, k) @ bcsr(n, k) -> dense(n, m) - res_val = outs[0] - # B_data: (num_block, block_size_k, block_size_n) - A_data, B_data, B_indices, B_inptr, m, k, n, block_size_k, block_size_n = ins - block_size_k = block_size_k[()] - block_size_n = block_size_n[()] - n_block = n // block_size_n - - for ni in numba.prange(n_block): - C_tmp = np.zeros((block_size_n, m), dtype=A_data.dtype) - start, end = B_inptr[ni], B_inptr[ni + 1] - ns = ni * block_size_n - ne = ns + block_size_n - for i in range(start, end): - ki = B_indices[i, 0] - ks = ki * block_size_k - ke = ki + block_size_k - bi = B_indices[i, 1] - C_tmp += np.matmul(B_data[bi], A_data[:, ks: ke].T) - res_val[ns: ne] = C_tmp - return res_val - - -@numba.njit(fastmath=True, parallel=True, nogil=True) -def _bcsrmm_cutlass_imp2(outs, ins): # dense(m, k) @ bcsr(k, n) -> dense(n, m) - res_val = outs[0] - # B_data: (num_block, block_size_n, block_size_k) - A_data, B_data, B_indices, B_inptr, m, k, n, block_size_k, block_size_n = ins - block_size_k = block_size_k[()] - block_size_n = block_size_n[()] - n_block = n // block_size_n - - for ni in numba.prange(n_block): - C_tmp = np.zeros((block_size_n, m), dtype=A_data.dtype) - start, end = B_inptr[ni], B_inptr[ni + 1] - ns = ni * block_size_n - ne = ns + block_size_n - for i in range(start, end): - ki = B_indices[i, 0] - ks = ki * block_size_k - ke = ki + block_size_k - bi = B_indices[i, 1] - C_tmp += np.matmul(B_data[bi], A_data[:, ks: ke].T) - res_val[ns: ne] = C_tmp - return res_val +if numba is not None: + @numba.njit(fastmath=True, parallel=True, nogil=True) + def _bcsrmm_cutlass_imp_transpose(outs, ins): # dense(m, k) @ bcsr(n, k) -> dense(n, m) + res_val = outs[0] + # B_data: (num_block, block_size_k, block_size_n) + A_data, B_data, B_indices, B_inptr, m, k, n, block_size_k, block_size_n = ins + block_size_k = block_size_k[()] + block_size_n = block_size_n[()] + n_block = n // block_size_n + + for ni in numba.prange(n_block): + C_tmp = np.zeros((block_size_n, m), dtype=A_data.dtype) + start, end = B_inptr[ni], B_inptr[ni + 1] + ns = ni * block_size_n + ne = ns + block_size_n + for i in range(start, end): + ki = B_indices[i, 0] + ks = ki * block_size_k + ke = ki + block_size_k + bi = B_indices[i, 1] + C_tmp += np.matmul(B_data[bi], A_data[:, ks: ke].T) + res_val[ns: ne] = C_tmp + return res_val + + + @numba.njit(fastmath=True, parallel=True, nogil=True) + def _bcsrmm_cutlass_imp2(outs, ins): # dense(m, k) @ bcsr(k, n) -> dense(n, m) + res_val = outs[0] + # B_data: (num_block, block_size_n, block_size_k) + A_data, B_data, B_indices, B_inptr, m, k, n, block_size_k, block_size_n = ins + block_size_k = block_size_k[()] + block_size_n = block_size_n[()] + n_block = n // block_size_n + + for ni in numba.prange(n_block): + C_tmp = np.zeros((block_size_n, m), dtype=A_data.dtype) + start, end = B_inptr[ni], B_inptr[ni + 1] + ns = ni * block_size_n + ne = ns + block_size_n + for i in range(start, end): + ki = B_indices[i, 0] + ks = ki * block_size_k + ke = ki + block_size_k + bi = B_indices[i, 1] + C_tmp += np.matmul(B_data[bi], A_data[:, ks: ke].T) + res_val[ns: ne] = C_tmp + return res_val def _bcsrmm_cutlass_abstract( diff --git a/brainpy/_src/math/sparse/_csr_mv.py b/brainpy/_src/math/sparse/_csr_mv.py index 5fdb83443..dd25ef3d4 100644 --- a/brainpy/_src/math/sparse/_csr_mv.py +++ b/brainpy/_src/math/sparse/_csr_mv.py @@ -9,7 +9,7 @@ from jax.experimental.sparse import csr from jax.interpreters import ad -from brainpy._src.dependency_check import import_taichi +from brainpy._src.dependency_check import import_taichi, check_taichi_func from brainpy._src.math.interoperability import as_jax from brainpy._src.math.ndarray import Array from brainpy._src.math.op_register import (register_general_batching, @@ -23,7 +23,7 @@ 'csrmv', ] - +@check_taichi_func def csrmv( data: Union[float, jnp.ndarray, Array], indices: Union[jnp.ndarray, Array], diff --git a/brainpy/_src/math/sparse/tests/test_csrmv.py b/brainpy/_src/math/sparse/tests/test_csrmv.py index 418a52d35..ec448e658 100644 --- a/brainpy/_src/math/sparse/tests/test_csrmv.py +++ b/brainpy/_src/math/sparse/tests/test_csrmv.py @@ -11,7 +11,7 @@ from brainpy._src.dependency_check import import_taichi -if import_taichi() is None: +if import_taichi(error_if_not_found=False) is None: pytest.skip('no taichi', allow_module_level=True) seed = 1234 diff --git a/brainpy/_src/math/tests/test_tifunc.py b/brainpy/_src/math/tests/test_tifunc.py index 6823ebabd..db6e7debc 100644 --- a/brainpy/_src/math/tests/test_tifunc.py +++ b/brainpy/_src/math/tests/test_tifunc.py @@ -1,122 +1,124 @@ -# -*- coding: utf-8 -*- - -import jax -import jax.numpy as jnp -import pytest - -pytestmark = pytest.mark.skip(reason="Skipped due to MacOS limitation, manual execution required for testing.") -import brainpy.math as bm -import taichi as ti -import matplotlib.pyplot as plt -import os - - -bm.set_platform('cpu') - - -def test_taichi_random(): - @ti.kernel - def test_taichi_lfsr88(seed: ti.types.ndarray(ndim=1, dtype=ti.u32), - out: ti.types.ndarray(ndim=1, dtype=ti.f32)): - key = bm.tifunc.lfsr88_key(seed[0]) - for i in range(out.shape[0]): - key, result = bm.tifunc.lfsr88_rand(key) - out[i] = result - - @ti.kernel - def test_taichi_lcg_rand(seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - for i in range(out.shape[0]): - out[i] = bm.tifunc.taichi_lcg_rand(seed) - - @ti.kernel - def test_taichi_uniform_int_distribution(seed: ti.types.ndarray(ndim=1), - low_high: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - key = bm.tifunc.lfsr88_key(seed[0]) - low = low_high[0] - high = low_high[1] - for i in range(out.shape[0]): - key, out[i] = bm.tifunc.lfsr88_randint(key, low, high) - - @ti.kernel - def test_taichi_uniform_real_distribution(seed: ti.types.ndarray(ndim=1), - low_high: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - key = bm.tifunc.lfsr88_key(seed[0]) - low = low_high[0] - high = low_high[1] - for i in range(out.shape[0]): - key, out[i] = bm.tifunc.lfsr88_uniform(key, low, high) - - @ti.kernel - def test_taichi_normal_distribution(seed: ti.types.ndarray(ndim=1), - mu_sigma: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - key = bm.tifunc.lfsr88_key(seed[0]) - mu = mu_sigma[0] - sigma = mu_sigma[1] - - for i in range(out.shape[0]): - key, out[i] = bm.tifunc.lfsr88_normal(key, mu, sigma) - - n = 100000 - seed = jnp.array([1234, ], dtype=jnp.uint32) - low_high = jnp.array([0, 10]) - mu_sigma = jnp.array([0, 1]) - - prim_lfsr88 = bm.XLACustomOp(cpu_kernel=test_taichi_lfsr88, - gpu_kernel=test_taichi_lfsr88) - - - prim_lcg_rand = bm.XLACustomOp(cpu_kernel=test_taichi_lcg_rand, - gpu_kernel=test_taichi_lcg_rand) - prim_uniform_int_distribution = bm.XLACustomOp(cpu_kernel=test_taichi_uniform_int_distribution, - gpu_kernel=test_taichi_uniform_int_distribution) - prim_uniform_real_distribution = bm.XLACustomOp(cpu_kernel=test_taichi_uniform_real_distribution, - gpu_kernel=test_taichi_uniform_real_distribution) - prim_normal_distribution = bm.XLACustomOp(cpu_kernel=test_taichi_normal_distribution, - gpu_kernel=test_taichi_normal_distribution) - - file_path = os.path.dirname(os.path.abspath(__file__)) - - out = prim_lfsr88(seed, outs=[jax.ShapeDtypeStruct((n,), jnp.float32)]) - # show the distribution of out - plt.hist(out, bins=100) - plt.title("LFSR88 random number generator") - plt.savefig(file_path + "/lfsr88.png") - plt.close() - - out = prim_lcg_rand(seed, - outs=[jax.ShapeDtypeStruct((n,), jnp.float32)]) - # show the distribution of out - plt.hist(out, bins=100) - plt.title("LCG random number generator") - plt.savefig(file_path + "/lcg_rand.png") - plt.close() - - out = prim_uniform_int_distribution(seed, low_high, - outs=[jax.ShapeDtypeStruct((n,), jnp.int32)]) - # show the distribution of out - plt.hist(out, bins=10) - plt.title("Uniform int distribution (0, 10)") - plt.savefig(file_path + "/uniform_int_distribution.png") - plt.close() - - out = prim_uniform_real_distribution(seed, low_high, - outs=[jax.ShapeDtypeStruct((n,), jnp.float32)]) - # show the distribution of out - plt.hist(out, bins=100) - plt.title("Uniform real distribution (0, 10)") - plt.savefig(file_path + "/uniform_real_distribution.png") - plt.close() - - out = prim_normal_distribution(seed, mu_sigma, - outs=[jax.ShapeDtypeStruct((n,), jnp.float32)]) - # show the distribution of out - plt.title("Normal distribution mu=0, sigma=1") - plt.hist(out, bins=100) - plt.savefig(file_path + "/normal_distribution.png") - - -# TODO; test default types +# -*- coding: utf-8 -*- + +import jax +import jax.numpy as jnp +import pytest + +pytestmark = pytest.mark.skip(reason="Skipped due to MacOS limitation, manual execution required for testing.") +import brainpy.math as bm +import matplotlib.pyplot as plt +import os + +from brainpy._src.dependency_check import import_taichi + +ti = import_taichi(error_if_not_found=False) +if ti is None: + pytest.skip('no taichi', allow_module_level=True) + +bm.set_platform('cpu') + + +def test_taichi_random(): + @ti.kernel + def test_taichi_lfsr88(seed: ti.types.ndarray(ndim=1, dtype=ti.u32), + out: ti.types.ndarray(ndim=1, dtype=ti.f32)): + key = bm.tifunc.lfsr88_key(seed[0]) + for i in range(out.shape[0]): + key, result = bm.tifunc.lfsr88_rand(key) + out[i] = result + + @ti.kernel + def test_taichi_lcg_rand(seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1)): + for i in range(out.shape[0]): + out[i] = bm.tifunc.taichi_lcg_rand(seed) + + @ti.kernel + def test_taichi_uniform_int_distribution(seed: ti.types.ndarray(ndim=1), + low_high: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1)): + key = bm.tifunc.lfsr88_key(seed[0]) + low = low_high[0] + high = low_high[1] + for i in range(out.shape[0]): + key, out[i] = bm.tifunc.lfsr88_randint(key, low, high) + + @ti.kernel + def test_taichi_uniform_real_distribution(seed: ti.types.ndarray(ndim=1), + low_high: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1)): + key = bm.tifunc.lfsr88_key(seed[0]) + low = low_high[0] + high = low_high[1] + for i in range(out.shape[0]): + key, out[i] = bm.tifunc.lfsr88_uniform(key, low, high) + + @ti.kernel + def test_taichi_normal_distribution(seed: ti.types.ndarray(ndim=1), + mu_sigma: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1)): + key = bm.tifunc.lfsr88_key(seed[0]) + mu = mu_sigma[0] + sigma = mu_sigma[1] + + for i in range(out.shape[0]): + key, out[i] = bm.tifunc.lfsr88_normal(key, mu, sigma) + + n = 100000 + seed = jnp.array([1234, ], dtype=jnp.uint32) + low_high = jnp.array([0, 10]) + mu_sigma = jnp.array([0, 1]) + + prim_lfsr88 = bm.XLACustomOp(cpu_kernel=test_taichi_lfsr88, + gpu_kernel=test_taichi_lfsr88) + + prim_lcg_rand = bm.XLACustomOp(cpu_kernel=test_taichi_lcg_rand, + gpu_kernel=test_taichi_lcg_rand) + prim_uniform_int_distribution = bm.XLACustomOp(cpu_kernel=test_taichi_uniform_int_distribution, + gpu_kernel=test_taichi_uniform_int_distribution) + prim_uniform_real_distribution = bm.XLACustomOp(cpu_kernel=test_taichi_uniform_real_distribution, + gpu_kernel=test_taichi_uniform_real_distribution) + prim_normal_distribution = bm.XLACustomOp(cpu_kernel=test_taichi_normal_distribution, + gpu_kernel=test_taichi_normal_distribution) + + file_path = os.path.dirname(os.path.abspath(__file__)) + + out = prim_lfsr88(seed, outs=[jax.ShapeDtypeStruct((n,), jnp.float32)]) + # show the distribution of out + plt.hist(out, bins=100) + plt.title("LFSR88 random number generator") + plt.savefig(file_path + "/lfsr88.png") + plt.close() + + out = prim_lcg_rand(seed, + outs=[jax.ShapeDtypeStruct((n,), jnp.float32)]) + # show the distribution of out + plt.hist(out, bins=100) + plt.title("LCG random number generator") + plt.savefig(file_path + "/lcg_rand.png") + plt.close() + + out = prim_uniform_int_distribution(seed, low_high, + outs=[jax.ShapeDtypeStruct((n,), jnp.int32)]) + # show the distribution of out + plt.hist(out, bins=10) + plt.title("Uniform int distribution (0, 10)") + plt.savefig(file_path + "/uniform_int_distribution.png") + plt.close() + + out = prim_uniform_real_distribution(seed, low_high, + outs=[jax.ShapeDtypeStruct((n,), jnp.float32)]) + # show the distribution of out + plt.hist(out, bins=100) + plt.title("Uniform real distribution (0, 10)") + plt.savefig(file_path + "/uniform_real_distribution.png") + plt.close() + + out = prim_normal_distribution(seed, mu_sigma, + outs=[jax.ShapeDtypeStruct((n,), jnp.float32)]) + # show the distribution of out + plt.title("Normal distribution mu=0, sigma=1") + plt.hist(out, bins=100) + plt.savefig(file_path + "/normal_distribution.png") + +# TODO; test default types diff --git a/brainpy/_src/math/tifunc.py b/brainpy/_src/math/tifunc.py index c54f4d6f7..9cfd39e1a 100644 --- a/brainpy/_src/math/tifunc.py +++ b/brainpy/_src/math/tifunc.py @@ -1,51 +1,27 @@ -from brainpy._src.dependency_check import import_taichi +from brainpy._src.dependency_check import import_taichi, raise_taichi_not_found from . import defaults ti = import_taichi(error_if_not_found=False) -if ti is not None: - - __all__ = [ - # taichi function for other utilities - 'warp_reduce_sum', - - # taichi functions for random number generator with LFSR88 algorithm - 'lfsr88_key', 'lfsr88_next_key', 'lfsr88_normal', 'lfsr88_randn', - 'lfsr88_random_integers', 'lfsr88_randint', 'lfsr88_uniform', 'lfsr88_rand', - - # taichi functions for random number generator with LFSR113 algorithm - 'lfsr113_key', 'lfsr113_next_key', 'lfsr113_normal', 'lfsr113_randn', - 'lfsr113_random_integers', 'lfsr113_randint', 'lfsr113_uniform', 'lfsr113_rand', - ] - +__all__ = [ + # taichi function for other utilities + 'warp_reduce_sum', - @ti.func - def _lcg_rand(state: ti.types.ndarray(ndim=1)): - # LCG constants - state[0] = ti.u32(1664525) * state[0] + ti.u32(1013904223) - return state[0] + # taichi functions for random number generator with LFSR88 algorithm + 'lfsr88_key', 'lfsr88_next_key', 'lfsr88_normal', 'lfsr88_randn', + 'lfsr88_random_integers', 'lfsr88_randint', 'lfsr88_uniform', 'lfsr88_rand', + # taichi functions for random number generator with LFSR113 algorithm + 'lfsr113_key', 'lfsr113_next_key', 'lfsr113_normal', 'lfsr113_randn', + 'lfsr113_random_integers', 'lfsr113_randint', 'lfsr113_uniform', 'lfsr113_rand', +] - @ti.func - def taichi_lcg_rand(seed: ti.types.ndarray(ndim=1)): - """ - Generate a random number using the Taichi LCG algorithm. - - Parameters: - seed (ti.types.ndarray): The seed value for the random number generator. - - Returns: - float: A random number between 0 and 1. - """ - - return float(_lcg_rand(seed)) / ti.u32(2 ** 32 - 1) - +if ti is not None: ############################################# # Random Number Generator: LFSR88 algorithm # ############################################# - @ti.func def lfsr88_key(seed: ti.u32) -> ti.types.vector(4, ti.u32): """Initialize the random key of LFSR88 algorithm (Combined LFSR random number generator by L'Ecuyer). @@ -186,7 +162,6 @@ def lfsr88_rand(key: ti.types.vector(4, ti.u32)): # Random Number Generator: LFSR113 algorithm # ############################################## - @ti.func def lfsr113_key(seed: ti.u32) -> ti.types.vector(4, ti.u32): """Initialize the random key of LFSR113 algorithm (Combined LFSR random number generator by L'Ecuyer). @@ -333,7 +308,6 @@ def lfsr113_rand(key: ti.types.vector(4, ti.u32)): # Reductions: warp reduce # ########################### - @ti.func def warp_reduce_sum_all(val): """ @@ -367,4 +341,5 @@ def warp_reduce_sum(val): else: - __all__ = [] + for func in __all__: + globals()[func] = raise_taichi_not_found \ No newline at end of file diff --git a/brainpy/_src/tests/test_dyn_runner.py b/brainpy/_src/tests/test_dyn_runner.py index dd6865e64..6f2411ee8 100644 --- a/brainpy/_src/tests/test_dyn_runner.py +++ b/brainpy/_src/tests/test_dyn_runner.py @@ -1,134 +1,133 @@ -# -*- coding: utf-8 -*- - - -import unittest -import brainpy as bp -import brainpy.math as bm - - -class TestDSRunner(unittest.TestCase): - def test1(self): - class ExampleDS(bp.DynamicalSystem): - def __init__(self): - super(ExampleDS, self).__init__() - self.i = bm.Variable(bm.zeros(1)) - - def update(self): - self.i += 1 - - ds = ExampleDS() - runner = bp.DSRunner(ds, dt=1., monitors=['i'], progress_bar=False) - runner.run(100.) - - def test_t_and_dt(self): - class ExampleDS(bp.DynamicalSystem): - def __init__(self): - super(ExampleDS, self).__init__() - self.i = bm.Variable(bm.zeros(1)) - - def update(self): - self.i += 1 * bp.share['dt'] - - runner = bp.DSRunner(ExampleDS(), dt=1., monitors=['i'], progress_bar=False) - runner.run(100.) - - def test_DSView(self): - class EINet(bp.Network): - def __init__(self, scale=1.0, method='exp_auto'): - super(EINet, self).__init__() - - # network size - num_exc = int(800 * scale) - num_inh = int(200 * scale) - - # neurons - pars = dict(V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5.) - self.E = bp.neurons.LIF(num_exc, **pars, method=method) - self.I = bp.neurons.LIF(num_inh, **pars, method=method) - self.E.V[:] = bm.random.randn(num_exc) * 2 - 55. - self.I.V[:] = bm.random.randn(num_inh) * 2 - 55. - - # synapses - we = 0.6 / scale # excitatory synaptic weight (voltage) - wi = 6.7 / scale # inhibitory synaptic weight - self.E2E = bp.synapses.Exponential(self.E, self.E[:100], bp.conn.FixedProb(0.02), - output=bp.synouts.COBA(E=0.), g_max=we, - tau=5., method=method) - self.E2I = bp.synapses.Exponential(self.E, self.I[:100], bp.conn.FixedProb(0.02), - output=bp.synouts.COBA(E=0.), g_max=we, - tau=5., method=method) - self.I2E = bp.synapses.Exponential(self.I, self.E[:100], bp.conn.FixedProb(0.02), - output=bp.synouts.COBA(E=-80.), g_max=wi, - tau=10., method=method) - self.I2I = bp.synapses.Exponential(self.I, self.I[:100], bp.conn.FixedProb(0.02), - output=bp.synouts.COBA(E=-80.), g_max=wi, - tau=10., method=method) - - bm.random.seed() - - net = EINet(scale=1., method='exp_auto') - # with JIT - runner = bp.DSRunner(net, monitors={'E.spike': net.E.spike}, - inputs=[(net.E.input, 20.), (net.I.input, 20.)]).run(1.) - - # without JIT - runner = bp.DSRunner(net, monitors={'E.spike': net.E.spike}, - inputs=[(net.E.input, 20.), (net.I.input, 20.)], jit=False).run(0.2) - - - -class TestMemoryEfficient(unittest.TestCase): - pass - - - - - - -# class TestMonitor(TestCase): -# def test_1d_array(self): -# try1 = TryGroup(monitors=['a']) -# try1.a = np.ones(1) -# try1.run(100.) -# -# assert np.ndim(try1.mon.a) == 2 and np.shape(try1.mon.a)[1] == 1 -# assert np.allclose(np.arange(2, 1002).reshape((-1, 1)), try1.mon.a) -# -# def test_2d_array(): -# set(dt=0.1) -# try1 = TryGroup(monitors=['a']) -# try1.a = np.ones((2, 2)) -# try1.run(100.) -# -# assert np.ndim(try1.mon.a) == 2 and np.shape(try1.mon.a)[1] == 4 -# series = np.arange(2, 1002).reshape((-1, 1)) -# series = np.repeat(series, 4, axis=1) -# assert np.allclose(series, try1.mon.a) -# -# def test_monitor_with_every(): -# set(dt=0.1) -# -# # try1: 2d array -# try1 = TryGroup(monitors=Monitor(variables=['a'], every=[1.])) -# try1.run(100.) -# assert np.ndim(try1.mon.a) == 2 and np.shape(try1.mon.a)[1] == 4 -# series = np.arange(2, 1002, 1. / 0.1).reshape((-1, 1)) -# series = np.repeat(series, 4, axis=1) -# assert np.allclose(series, try1.mon.a) -# -# # try2: 1d array -# try2 = TryGroup(monitors=Monitor(variables=['a'], every=[1.])) -# try2.a = np.array([1., 1.]) -# try2.run(100.) -# assert np.ndim(try2.mon.a) == 2 and np.shape(try2.mon.a)[1] == 2 -# series = np.arange(2, 1002, 1. / 0.1).reshape((-1, 1)) -# series = np.repeat(series, 2, axis=1) -# assert np.allclose(series, try2.mon.a) -# -# # try2: scalar -# try3 = TryGroup(monitors=Monitor(variables=['a'], every=[1.])) -# try3.a = 1. -# try3.run(100.) -# assert np.ndim(try3.mon.a) == 2 and np.shape(try3.mon.a)[1] == 1 -# series = np.arange(2, 1002, 1. / 0.1).reshape((-1, 1)) -# assert np.allclose(series, try3.mon.a) +# -*- coding: utf-8 -*- + +import pytest +import unittest +import brainpy as bp +import brainpy.math as bm + +from brainpy._src.dependency_check import import_taichi + +if import_taichi(error_if_not_found=False) is None: + pytest.skip('no taichi', allow_module_level=True) + + +class TestDSRunner(unittest.TestCase): + def test1(self): + class ExampleDS(bp.DynamicalSystem): + def __init__(self): + super(ExampleDS, self).__init__() + self.i = bm.Variable(bm.zeros(1)) + + def update(self): + self.i += 1 + + ds = ExampleDS() + runner = bp.DSRunner(ds, dt=1., monitors=['i'], progress_bar=False) + runner.run(100.) + + def test_t_and_dt(self): + class ExampleDS(bp.DynamicalSystem): + def __init__(self): + super(ExampleDS, self).__init__() + self.i = bm.Variable(bm.zeros(1)) + + def update(self): + self.i += 1 * bp.share['dt'] + + runner = bp.DSRunner(ExampleDS(), dt=1., monitors=['i'], progress_bar=False) + runner.run(100.) + + def test_DSView(self): + class EINet(bp.Network): + def __init__(self, scale=1.0, method='exp_auto'): + super(EINet, self).__init__() + + # network size + num_exc = int(800 * scale) + num_inh = int(200 * scale) + + # neurons + pars = dict(V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5.) + self.E = bp.neurons.LIF(num_exc, **pars, method=method) + self.I = bp.neurons.LIF(num_inh, **pars, method=method) + self.E.V[:] = bm.random.randn(num_exc) * 2 - 55. + self.I.V[:] = bm.random.randn(num_inh) * 2 - 55. + + # synapses + we = 0.6 / scale # excitatory synaptic weight (voltage) + wi = 6.7 / scale # inhibitory synaptic weight + self.E2E = bp.synapses.Exponential(self.E, self.E[:100], bp.conn.FixedProb(0.02), + output=bp.synouts.COBA(E=0.), g_max=we, + tau=5., method=method) + self.E2I = bp.synapses.Exponential(self.E, self.I[:100], bp.conn.FixedProb(0.02), + output=bp.synouts.COBA(E=0.), g_max=we, + tau=5., method=method) + self.I2E = bp.synapses.Exponential(self.I, self.E[:100], bp.conn.FixedProb(0.02), + output=bp.synouts.COBA(E=-80.), g_max=wi, + tau=10., method=method) + self.I2I = bp.synapses.Exponential(self.I, self.I[:100], bp.conn.FixedProb(0.02), + output=bp.synouts.COBA(E=-80.), g_max=wi, + tau=10., method=method) + + bm.random.seed() + + net = EINet(scale=1., method='exp_auto') + # with JIT + runner = bp.DSRunner(net, monitors={'E.spike': net.E.spike}, + inputs=[(net.E.input, 20.), (net.I.input, 20.)]).run(1.) + + # without JIT + runner = bp.DSRunner(net, monitors={'E.spike': net.E.spike}, + inputs=[(net.E.input, 20.), (net.I.input, 20.)], jit=False).run(0.2) + + +class TestMemoryEfficient(unittest.TestCase): + pass + +# class TestMonitor(TestCase): +# def test_1d_array(self): +# try1 = TryGroup(monitors=['a']) +# try1.a = np.ones(1) +# try1.run(100.) +# +# assert np.ndim(try1.mon.a) == 2 and np.shape(try1.mon.a)[1] == 1 +# assert np.allclose(np.arange(2, 1002).reshape((-1, 1)), try1.mon.a) +# +# def test_2d_array(): +# set(dt=0.1) +# try1 = TryGroup(monitors=['a']) +# try1.a = np.ones((2, 2)) +# try1.run(100.) +# +# assert np.ndim(try1.mon.a) == 2 and np.shape(try1.mon.a)[1] == 4 +# series = np.arange(2, 1002).reshape((-1, 1)) +# series = np.repeat(series, 4, axis=1) +# assert np.allclose(series, try1.mon.a) +# +# def test_monitor_with_every(): +# set(dt=0.1) +# +# # try1: 2d array +# try1 = TryGroup(monitors=Monitor(variables=['a'], every=[1.])) +# try1.run(100.) +# assert np.ndim(try1.mon.a) == 2 and np.shape(try1.mon.a)[1] == 4 +# series = np.arange(2, 1002, 1. / 0.1).reshape((-1, 1)) +# series = np.repeat(series, 4, axis=1) +# assert np.allclose(series, try1.mon.a) +# +# # try2: 1d array +# try2 = TryGroup(monitors=Monitor(variables=['a'], every=[1.])) +# try2.a = np.array([1., 1.]) +# try2.run(100.) +# assert np.ndim(try2.mon.a) == 2 and np.shape(try2.mon.a)[1] == 2 +# series = np.arange(2, 1002, 1. / 0.1).reshape((-1, 1)) +# series = np.repeat(series, 2, axis=1) +# assert np.allclose(series, try2.mon.a) +# +# # try2: scalar +# try3 = TryGroup(monitors=Monitor(variables=['a'], every=[1.])) +# try3.a = 1. +# try3.run(100.) +# assert np.ndim(try3.mon.a) == 2 and np.shape(try3.mon.a)[1] == 1 +# series = np.arange(2, 1002, 1. / 0.1).reshape((-1, 1)) +# assert np.allclose(series, try3.mon.a) diff --git a/brainpy/math/__init__.py b/brainpy/math/__init__.py index feaa10093..9a64f9f25 100644 --- a/brainpy/math/__init__.py +++ b/brainpy/math/__init__.py @@ -53,6 +53,10 @@ from brainpy._src.math import defaults from brainpy._src.deprecations import deprecation_getattr +from brainpy._src.dependency_check import import_taichi, import_numba + +import_taichi(error_if_not_found=False) +import_numba(error_if_not_found=False) __deprecations = { "sparse_matmul": ("brainpy.math.sparse_matmul is deprecated. Use brainpy.math.sparse.seg_matmul instead.", diff --git a/brainpy/math/tifunc.py b/brainpy/math/tifunc.py index e345c6835..bea49c220 100644 --- a/brainpy/math/tifunc.py +++ b/brainpy/math/tifunc.py @@ -1,7 +1,6 @@ # -*- coding: utf-8 -*- from brainpy._src.math.tifunc import ( - taichi_lcg_rand, # warp reduction primitives warp_reduce_sum, From ce43e4e8b1f06ba0a343bf1c514ace6ad78b749a Mon Sep 17 00:00:00 2001 From: He Sichao <1310722434@qq.com> Date: Fri, 23 Feb 2024 13:57:24 +0800 Subject: [PATCH 07/16] Update test_taichi_clean_cache.py --- .../tests/test_taichi_clean_cache.py | 20 +++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/brainpy/_src/math/op_register/tests/test_taichi_clean_cache.py b/brainpy/_src/math/op_register/tests/test_taichi_clean_cache.py index 5f6b3a292..51c964b29 100644 --- a/brainpy/_src/math/op_register/tests/test_taichi_clean_cache.py +++ b/brainpy/_src/math/op_register/tests/test_taichi_clean_cache.py @@ -12,23 +12,23 @@ if not platform.platform().startswith('Windows'): pytest.skip(allow_module_level=True) -@taichi.func -def get_weight(weight: taichi.types.ndarray(ndim=1)) -> taichi.f32: +@ti.func +def get_weight(weight: ti.types.ndarray(ndim=1)) -> ti.f32: return weight[0] -@taichi.func -def update_output(out: taichi.types.ndarray(ndim=1), index: taichi.i32, weight_val: taichi.f32): +@ti.func +def update_output(out: ti.types.ndarray(ndim=1), index: ti.i32, weight_val: ti.f32): out[index] += weight_val -@taichi.kernel -def event_ell_cpu(indices: taichi.types.ndarray(ndim=2), - vector: taichi.types.ndarray(ndim=1), - weight: taichi.types.ndarray(ndim=1), - out: taichi.types.ndarray(ndim=1)): +@ti.kernel +def event_ell_cpu(indices: ti.types.ndarray(ndim=2), + vector: ti.types.ndarray(ndim=1), + weight: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1)): weight_val = get_weight(weight) num_rows, num_cols = indices.shape - taichi.loop_config(serialize=True) + ti.loop_config(serialize=True) for i in range(num_rows): if vector[i]: for j in range(num_cols): From 86cc09614cb2af3ec5773dad060d061b08bba828 Mon Sep 17 00:00:00 2001 From: He Sichao <1310722434@qq.com> Date: Fri, 23 Feb 2024 14:04:18 +0800 Subject: [PATCH 08/16] Update CI and remove taichi, numba from requirements --- .github/workflows/CI.yml | 62 ++++++++++++++++++++++++++++++++++++++++ requirements-dev.txt | 2 -- 2 files changed, 62 insertions(+), 2 deletions(-) diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 84aa028e3..b82507108 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -50,6 +50,37 @@ jobs: cd brainpy pytest _src/ + test_linux_with_taichi_numba: + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + python-version: [ "3.9", "3.10", "3.11"] + + steps: + - uses: actions/checkout@v4 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + - name: Install dependencies + run: | + python -m pip install --upgrade pip + python -m pip install flake8 pytest taichi numba + if [ -f requirements-dev.txt ]; then pip install -r requirements-dev.txt; fi + pip uninstall brainpy -y + python setup.py install + - name: Lint with flake8 + run: | + # stop the build if there are Python syntax errors or undefined names + flake8 brainpy/ --count --select=E9,F63,F7,F82 --show-source --statistics + # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide + flake8 brainpy/ --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics + - name: Test with pytest + run: | + cd brainpy + pytest _src/ + # test_linux_py37: # runs-on: ubuntu-latest @@ -116,6 +147,37 @@ jobs: cd brainpy pytest _src/ + test_macos_with_taichi_numba: + runs-on: macos-latest + strategy: + fail-fast: false + matrix: + python-version: ["3.9", "3.10", "3.11"] + + steps: + - uses: actions/checkout@v4 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + - name: Install dependencies + run: | + python -m pip install --upgrade pip + python -m pip install flake8 pytest taichi numba + if [ -f requirements-dev.txt ]; then pip install -r requirements-dev.txt; fi + pip uninstall brainpy -y + python setup.py install + - name: Lint with flake8 + run: | + # stop the build if there are Python syntax errors or undefined names + flake8 brainpy/ --count --select=E9,F63,F7,F82 --show-source --statistics + # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide + flake8 brainpy/ --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics + - name: Test with pytest + run: | + cd brainpy + pytest _src/ + # test_macos_py37: # runs-on: macos-latest # strategy: diff --git a/requirements-dev.txt b/requirements-dev.txt index 0e475e83d..167f39df9 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1,5 +1,4 @@ numpy -numba brainpylib jax jaxlib @@ -7,7 +6,6 @@ matplotlib msgpack tqdm pathos -taichi==1.7.0 # test requirements pytest From 76202d5baa427d5698f566a1260a66d8e4f51c8e Mon Sep 17 00:00:00 2001 From: He Sichao <1310722434@qq.com> Date: Fri, 23 Feb 2024 14:31:44 +0800 Subject: [PATCH 09/16] Resolve conflicts --- brainpy/_src/dependency_check.py | 2 +- brainpy/_src/dnn/tests/test_linear.py | 5 +++-- brainpy/_src/dnn/tests/test_mode.py | 6 ++++-- 3 files changed, 8 insertions(+), 5 deletions(-) diff --git a/brainpy/_src/dependency_check.py b/brainpy/_src/dependency_check.py index 183e99d98..2babb5023 100644 --- a/brainpy/_src/dependency_check.py +++ b/brainpy/_src/dependency_check.py @@ -17,7 +17,7 @@ 'import_brainpylib_gpu_ops', ] -_minimal_brainpylib_version = '0.1.10' +_minimal_brainpylib_version = '0.2.6' _minimal_taichi_version = (1, 7, 0) taichi = None diff --git a/brainpy/_src/dnn/tests/test_linear.py b/brainpy/_src/dnn/tests/test_linear.py index 41844cc8f..422f161f1 100644 --- a/brainpy/_src/dnn/tests/test_linear.py +++ b/brainpy/_src/dnn/tests/test_linear.py @@ -1,7 +1,8 @@ import pytest -import brainpy as bp -from absl.testing import parameterized from absl.testing import absltest +from absl.testing import parameterized + +import brainpy as bp import brainpy.math as bm from brainpy._src.dependency_check import import_taichi diff --git a/brainpy/_src/dnn/tests/test_mode.py b/brainpy/_src/dnn/tests/test_mode.py index 0c0107573..f0c67da12 100644 --- a/brainpy/_src/dnn/tests/test_mode.py +++ b/brainpy/_src/dnn/tests/test_mode.py @@ -1,10 +1,12 @@ import pytest -import brainpy.math as bm -from absl.testing import parameterized from absl.testing import absltest +from absl.testing import parameterized + import brainpy as bp +import brainpy.math as bm from brainpy._src.dependency_check import import_taichi + if import_taichi(error_if_not_found=False) is None: pytest.skip('no taichi', allow_module_level=True) From b46c3bf4b2dc30121d765181022891f595f77dff Mon Sep 17 00:00:00 2001 From: He Sichao <1310722434@qq.com> Date: Fri, 23 Feb 2024 14:34:13 +0800 Subject: [PATCH 10/16] Revert "Merge branch 'master' into dependency-optimize" This reverts commit c54c82214f6f463d9cf0cb9b53d469fe9521edd1, reversing changes made to 76202d5baa427d5698f566a1260a66d8e4f51c8e. --- brainpy/_src/dependency_check.py | 88 -- brainpy/_src/dnn/conv.py | 11 +- brainpy/_src/dnn/tests/test_activation.py | 3 +- brainpy/_src/dnn/tests/test_conv_layers.py | 11 +- brainpy/_src/dnn/tests/test_function.py | 6 +- brainpy/_src/dnn/tests/test_linear.py | 219 ----- brainpy/_src/dnn/tests/test_mode.py | 802 ------------------ brainpy/_src/dnn/tests/test_normalization.py | 5 +- brainpy/_src/dnn/tests/test_pooling_layers.py | 2 +- brainpy/_src/math/delayvars.py | 5 +- .../_src/math/object_transform/autograd.py | 45 +- brainpy/_src/math/object_transform/base.py | 4 +- .../_src/math/object_transform/controls.py | 136 +-- brainpy/_src/math/object_transform/jit.py | 69 +- brainpy/_src/math/object_transform/naming.py | 3 +- .../_src/math/object_transform/parallels.py | 460 ++++++++++ brainpy/_src/math/object_transform/tools.py | 75 +- .../_src/math/object_transform/variables.py | 45 +- brainpy/_src/tools/functions.py | 192 ----- brainpy/_src/tools/tests/test_functions.py | 24 - brainpy/math/compat_pytorch.py | 2 +- brainpy/math/oo_transform.py | 4 - brainpy/tools.py | 5 - docs/advanced_tutorials.rst | 51 +- docs/apis/brainpy.math.oo_transform.rst | 1 - docs/toolboxes.rst | 38 +- docs/tutorials.rst | 77 +- examples/dynamics_simulation/ei_nets.py | 2 +- 28 files changed, 699 insertions(+), 1686 deletions(-) create mode 100644 brainpy/_src/math/object_transform/parallels.py delete mode 100644 brainpy/_src/tools/functions.py delete mode 100644 brainpy/_src/tools/tests/test_functions.py diff --git a/brainpy/_src/dependency_check.py b/brainpy/_src/dependency_check.py index 68558a9a4..2babb5023 100644 --- a/brainpy/_src/dependency_check.py +++ b/brainpy/_src/dependency_check.py @@ -142,94 +142,6 @@ def import_brainpylib_cpu_ops(): return brainpylib_cpu_ops -def import_brainpylib_gpu_ops(): - global brainpylib_gpu_ops - if brainpylib_gpu_ops is None: - try: - from brainpylib import gpu_ops as brainpylib_gpu_ops - - for _name, _value in brainpylib_gpu_ops.registrations().items(): - xla_client.register_custom_call_target(_name, _value, platform="gpu") - - import brainpylib - if brainpylib.__version__ < _minimal_brainpylib_version: - raise SystemError(f'This version of brainpy needs brainpylib >= {_minimal_brainpylib_version}.') - if hasattr(brainpylib, 'check_brainpy_version'): - brainpylib.check_brainpy_version() - - except ImportError: - raise ImportError('Please install GPU version of brainpylib. \n' - 'See https://brainpy.readthedocs.io for installation instructions.') - - return brainpylib_gpu_ops -======= -import os -import sys -from jax.lib import xla_client - -__all__ = [ - 'import_taichi', - 'import_brainpylib_cpu_ops', - 'import_brainpylib_gpu_ops', -] - -_minimal_brainpylib_version = '0.2.6' -_minimal_taichi_version = (1, 7, 0) - -taichi = None -brainpylib_cpu_ops = None -brainpylib_gpu_ops = None - -taichi_install_info = (f'We need taichi=={_minimal_taichi_version}. ' - f'Currently you can install taichi=={_minimal_taichi_version} through:\n\n' - '> pip install taichi==1.7.0') -os.environ["TI_LOG_LEVEL"] = "error" - - -def import_taichi(): - global taichi - if taichi is None: - with open(os.devnull, 'w') as devnull: - old_stdout = sys.stdout - sys.stdout = devnull - try: - import taichi as taichi # noqa - except ModuleNotFoundError: - raise ModuleNotFoundError(taichi_install_info) - finally: - sys.stdout = old_stdout - - if taichi.__version__ != _minimal_taichi_version: - raise RuntimeError(taichi_install_info) - return taichi - - -def is_brainpylib_gpu_installed(): - return False if brainpylib_gpu_ops is None else True - - -def import_brainpylib_cpu_ops(): - global brainpylib_cpu_ops - if brainpylib_cpu_ops is None: - try: - from brainpylib import cpu_ops as brainpylib_cpu_ops - - for _name, _value in brainpylib_cpu_ops.registrations().items(): - xla_client.register_custom_call_target(_name, _value, platform="cpu") - - import brainpylib - if brainpylib.__version__ < _minimal_brainpylib_version: - raise SystemError(f'This version of brainpy needs brainpylib >= {_minimal_brainpylib_version}.') - if hasattr(brainpylib, 'check_brainpy_version'): - brainpylib.check_brainpy_version() - - except ImportError: - raise ImportError('Please install brainpylib. \n' - 'See https://brainpy.readthedocs.io for installation instructions.') - - return brainpylib_cpu_ops - - def import_brainpylib_gpu_ops(): global brainpylib_gpu_ops if brainpylib_gpu_ops is None: diff --git a/brainpy/_src/dnn/conv.py b/brainpy/_src/dnn/conv.py index e4b6e25d2..deead1f3b 100644 --- a/brainpy/_src/dnn/conv.py +++ b/brainpy/_src/dnn/conv.py @@ -160,7 +160,7 @@ def update(self, x): nonbatching = False if x.ndim == self.num_spatial_dims + 1: nonbatching = True - x = bm.unsqueeze(x, 0) + x = x.unsqueeze(0) w = self.w.value if self.mask is not None: try: @@ -190,9 +190,6 @@ def __repr__(self): class Conv1d(_GeneralConv): """One-dimensional convolution. - The input should a 2d array with the shape of ``[H, C]``, or - a 3d array with the shape of ``[B, H, C]``, where ``H`` is the feature size. - Parameters ---------- in_channels: int @@ -285,9 +282,6 @@ def _check_input_dim(self, x): class Conv2d(_GeneralConv): """Two-dimensional convolution. - The input should a 3d array with the shape of ``[H, W, C]``, or - a 4d array with the shape of ``[B, H, W, C]``. - Parameters ---------- in_channels: int @@ -381,9 +375,6 @@ def _check_input_dim(self, x): class Conv3d(_GeneralConv): """Three-dimensional convolution. - The input should a 3d array with the shape of ``[H, W, D, C]``, or - a 4d array with the shape of ``[B, H, W, D, C]``. - Parameters ---------- in_channels: int diff --git a/brainpy/_src/dnn/tests/test_activation.py b/brainpy/_src/dnn/tests/test_activation.py index 17054667d..ba2a49efd 100644 --- a/brainpy/_src/dnn/tests/test_activation.py +++ b/brainpy/_src/dnn/tests/test_activation.py @@ -1,6 +1,5 @@ -from absl.testing import absltest from absl.testing import parameterized - +from absl.testing import absltest import brainpy as bp import brainpy.math as bm diff --git a/brainpy/_src/dnn/tests/test_conv_layers.py b/brainpy/_src/dnn/tests/test_conv_layers.py index 05f523622..3c9fdfa87 100644 --- a/brainpy/_src/dnn/tests/test_conv_layers.py +++ b/brainpy/_src/dnn/tests/test_conv_layers.py @@ -1,15 +1,17 @@ # -*- coding: utf-8 -*- -import jax.numpy as jnp +from unittest import TestCase from absl.testing import absltest +import jax.numpy as jnp +import brainpy.math as bm from absl.testing import parameterized - import brainpy as bp import brainpy.math as bm class TestConv(parameterized.TestCase): def test_Conv2D_img(self): + bm.random.seed() img = jnp.zeros((2, 200, 198, 4)) for k in range(4): x = 30 + 60 * k @@ -22,7 +24,6 @@ def test_Conv2D_img(self): strides=(2, 1), padding='VALID', groups=4) out = net(img) print("out shape: ", out.shape) - self.assertEqual(out.shape, (2, 99, 196, 32)) # print("First output channel:") # plt.figure(figsize=(10, 10)) # plt.imshow(np.array(img)[0, :, :, 0]) @@ -30,6 +31,7 @@ def test_Conv2D_img(self): bm.clear_buffer_memory() def test_conv1D(self): + bm.random.seed() with bp.math.training_environment(): model = bp.layers.Conv1d(in_channels=3, out_channels=32, kernel_size=(3,)) @@ -37,7 +39,6 @@ def test_conv1D(self): out = model(input) print("out shape: ", out.shape) - self.assertEqual(out.shape, (2, 5, 32)) # print("First output channel:") # plt.figure(figsize=(10, 10)) # plt.imshow(np.array(out)[0, :, :]) @@ -53,7 +54,6 @@ def test_conv2D(self): out = model(input) print("out shape: ", out.shape) - self.assertEqual(out.shape, (2, 5, 5, 32)) # print("First output channel:") # plt.figure(figsize=(10, 10)) # plt.imshow(np.array(out)[0, :, :, 31]) @@ -67,7 +67,6 @@ def test_conv3D(self): input = bp.math.ones((2, 5, 5, 5, 3)) out = model(input) print("out shape: ", out.shape) - self.assertEqual(out.shape, (2, 5, 5, 5, 32)) bm.clear_buffer_memory() diff --git a/brainpy/_src/dnn/tests/test_function.py b/brainpy/_src/dnn/tests/test_function.py index 9ad15938d..269fec441 100644 --- a/brainpy/_src/dnn/tests/test_function.py +++ b/brainpy/_src/dnn/tests/test_function.py @@ -1,10 +1,12 @@ # -*- coding: utf-8 -*- +from unittest import TestCase + +import jax.numpy as jnp +import brainpy.math as bm from absl.testing import absltest from absl.testing import parameterized - import brainpy as bp -import brainpy.math as bm class TestFunction(parameterized.TestCase): diff --git a/brainpy/_src/dnn/tests/test_linear.py b/brainpy/_src/dnn/tests/test_linear.py index 7735563a9..422f161f1 100644 --- a/brainpy/_src/dnn/tests/test_linear.py +++ b/brainpy/_src/dnn/tests/test_linear.py @@ -219,224 +219,5 @@ def test_EventJitFPNormalLinear(self, prob, w_mu, w_sigma, shape): bm.clear_buffer_memory() -if __name__ == '__main__': - absltest.main() -======= -from absl.testing import absltest -from absl.testing import parameterized - -import brainpy as bp -import brainpy.math as bm - - -class TestLinear(parameterized.TestCase): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - bm.random.seed() - - @parameterized.product( - size=[(10,), - (20, 10), - (5, 8, 10)], - num_out=[20, 10, 5] - ) - def test_Dense1(self, size, num_out): - bm.random.seed() - f = bp.dnn.Linear(10, num_out) - x = bm.random.random(size) - y = f(x) - self.assertTrue(y.shape == size[:-1] + (num_out,)) - bm.clear_buffer_memory() - - @parameterized.product( - size=[(10,), - (20, 10), - (5, 8, 10)], - ) - def test_Identity(self, size): - bm.random.seed() - f = bp.dnn.Identity() - x = bm.random.random(size) - y = f(x) - self.assertTrue(y.shape == size) - bm.clear_buffer_memory() - - def test_AllToAll1(self): - bm.random.seed() - with bm.environment(mode=bm.BatchingMode()): - f = bp.dnn.AllToAll(10, 20, weight=.1, include_self=True) - x = bm.random.random((8, 10)) - y = f(x) - expected = bm.sum(x, axis=1, keepdims=True) * 0.1 - self.assertTrue(bm.allclose(y, expected)) - - with bm.environment(mode=bm.NonBatchingMode()): - f = bp.dnn.AllToAll(10, 20, weight=.1, include_self=True) - x = bm.random.random((10,)) - y = f(x) - expected = bm.sum(x, keepdims=True) * 0.1 - self.assertTrue(bm.allclose(y, expected)) - bm.clear_buffer_memory() - - def test_OneToOne(self): - bm.random.seed() - with bm.environment(mode=bm.BatchingMode()): - f = bp.dnn.OneToOne(10, weight=.1) - x = bm.random.random((8, 10)) - y = f(x) - expected = x * 0.1 - self.assertTrue(bm.allclose(y, expected)) - - with bm.environment(mode=bm.NonBatchingMode()): - f = bp.dnn.OneToOne(10, weight=.1) - x = bm.random.random((10,)) - y = f(x) - expected = x * 0.1 - self.assertTrue(bm.allclose(y, expected)) - bm.clear_buffer_memory() - - @parameterized.product( - conn=[ - # bp.conn.FixedProb(0.1, pre=100, post=100), - bp.conn.GridFour(pre=100, post=100), - bp.conn.GaussianProb(0.1, pre=100, post=100), - ] - ) - def test_MaskedLinear(self, conn): - bm.random.seed() - bm.random.DEFAULT.seed(123) - f = bp.dnn.MaskedLinear(conn, weight=bp.init.XavierNormal(seed=123)) - x = bm.random.random((16, 100)) - y = f(x) - self.assertTrue(y.shape == (16, 100)) - bm.clear_buffer_memory() - - @parameterized.product( - conn=[ - bp.conn.FixedProb(0.1, pre=100, post=100), - bp.conn.GridFour(pre=100, post=100), - bp.conn.GaussianProb(0.1, pre=100, post=100), - ] - ) - def test_CSRLinear(self, conn): - bm.random.seed() - f = bp.dnn.CSRLinear(conn, weight=bp.init.Normal()) - x = bm.random.random((16, 100)) - y = f(x) - self.assertTrue(y.shape == (16, 100)) - - x = bm.random.random((100,)) - y = f(x) - self.assertTrue(y.shape == (100,)) - bm.clear_buffer_memory() - - - @parameterized.product( - conn=[ - bp.conn.FixedProb(0.1, pre=100, post=100), - bp.conn.GridFour(pre=100, post=100), - bp.conn.GaussianProb(0.1, pre=100, post=100), - ] - ) - def test_EventCSRLinear(self,conn): - bm.random.seed() - f=bp.layers.EventCSRLinear(conn,weight=bp.init.Normal()) - x = bm.random.random((16, 100)) - y = f(x) - self.assertTrue(y.shape == (16, 100)) - x = bm.random.random((100,)) - y = f(x) - self.assertTrue(y.shape == (100,)) - bm.clear_buffer_memory() - - - @parameterized.product( - prob=[0.01, 0.05, 0.5], - weight=[0.01, 0.01], - shape=[(), (10,), (10, 20), (10, 20, 25)] - ) - def test_JitFPHomoLinear(self, prob, weight, shape): - bm.random.seed() - f = bp.dnn.JitFPHomoLinear(100, 200, prob, weight, seed=123) - x = bm.random.random(shape + (100,)) - y = f(x) - self.assertTrue(y.shape == shape + (200,)) - bm.clear_buffer_memory() - - @parameterized.product( - prob=[0.01, 0.05, 0.5], - w_low=[-0.01, -0.01], - w_high=[0.01, 0.01], - shape=[(), (10,), (10, 20), (10, 20, 25)] - ) - def test_JitFPUniformLinear(self, prob, w_low, w_high, shape): - bm.random.seed() - f = bp.dnn.JitFPUniformLinear(100, 200, prob, w_low, w_high, seed=123) - x = bm.random.random(shape + (100,)) - y = f(x) - self.assertTrue(y.shape == shape + (200,)) - bm.clear_buffer_memory() - - @parameterized.product( - prob=[0.01, 0.1, 0.5], - w_mu=[-0.01, -0.01], - w_sigma=[0.01, 0.01], - shape=[(), (10,), (10, 20), (10, 20, 25)] - ) - def test_JitFPNormalLinear(self, prob, w_mu, w_sigma, shape): - bm.random.seed() - f = bp.dnn.JitFPNormalLinear(100, 200, prob, w_mu, w_sigma, seed=123) - x = bm.random.random(shape + (100,)) - y = f(x) - self.assertTrue(y.shape == shape + (200,)) - bm.clear_buffer_memory() - - @parameterized.product( - prob=[0.01, 0.05, 0.5], - weight=[0.01, 0.01], - shape=[(), (10,), (10, 20), (10, 20, 25)] - ) - def test_EventJitFPHomoLinear(self, prob, weight, shape): - bm.random.seed() - f = bp.dnn.EventJitFPHomoLinear(100, 200, prob, weight, seed=123) - y = f(bm.random.random(shape + (100,)) < 0.1) - self.assertTrue(y.shape == shape + (200,)) - - y2 = f(bm.as_jax(bm.random.random(shape + (100,)) < 0.1, dtype=float)) - self.assertTrue(y2.shape == shape + (200,)) - bm.clear_buffer_memory() - - @parameterized.product( - prob=[0.01, 0.05, 0.5], - w_low=[-0.01, -0.01], - w_high=[0.01, 0.01], - shape=[(), (10,), (10, 20), (10, 20, 25)] - ) - def test_EventJitFPUniformLinear(self, prob, w_low, w_high, shape): - bm.random.seed() - f = bp.dnn.EventJitFPUniformLinear(100, 200, prob, w_low, w_high, seed=123) - y = f(bm.random.random(shape + (100,)) < 0.1) - self.assertTrue(y.shape == shape + (200,)) - - y2 = f(bm.as_jax(bm.random.random(shape + (100,)) < 0.1, dtype=float)) - self.assertTrue(y2.shape == shape + (200,)) - bm.clear_buffer_memory() - - @parameterized.product( - prob=[0.01, 0.1, 0.5], - w_mu=[-0.01, -0.01], - w_sigma=[0.01, 0.01], - shape=[(), (10,), (10, 20), (10, 20, 25)] - ) - def test_EventJitFPNormalLinear(self, prob, w_mu, w_sigma, shape): - bm.random.seed() - f = bp.dnn.EventJitFPNormalLinear(100, 200, prob, w_mu, w_sigma, seed=123) - y = f(bm.random.random(shape + (100,)) < 0.1) - self.assertTrue(y.shape == shape + (200,)) - - y2 = f(bm.as_jax(bm.random.random(shape + (100,)) < 0.1, dtype=float)) - self.assertTrue(y2.shape == shape + (200,)) - bm.clear_buffer_memory() - if __name__ == '__main__': absltest.main() diff --git a/brainpy/_src/dnn/tests/test_mode.py b/brainpy/_src/dnn/tests/test_mode.py index 28af8cc19..f0c67da12 100644 --- a/brainpy/_src/dnn/tests/test_mode.py +++ b/brainpy/_src/dnn/tests/test_mode.py @@ -803,807 +803,5 @@ def test_InstanceNorm(self, mode): output = layer(input) -if __name__ == '__main__': - absltest.main() -======= -from absl.testing import absltest -from absl.testing import parameterized - -import brainpy as bp -import brainpy.math as bm - - -class Test_Conv(parameterized.TestCase): - @parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(10), - bm.BatchingMode(), - bm.BatchingMode(10), ] - ) - def test_Conv1d(self, mode): - bm.random.seed() - input = bm.random.randn(10, 50, 3) - layer = bp.dnn.Conv1d(in_channels=3, - out_channels=4, - kernel_size=5, - mode=mode) - output = layer(input) - bm.clear_buffer_memory() - - def test_Conv1d_NonBatching(self): - bm.random.seed() - input = bm.random.randn(50, 3) - layer = bp.dnn.Conv1d(in_channels=3, - out_channels=4, - kernel_size=5, - mode=bm.NonBatchingMode()) - output = layer(input) - bm.clear_buffer_memory() - - @parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(10), - bm.BatchingMode(), - bm.BatchingMode(10), ] - ) - def test_Conv2d(self, mode): - bm.random.seed() - input = bm.random.randn(10, 50, 50, 3) - layer = bp.dnn.Conv2d(in_channels=3, - out_channels=4, - kernel_size=(5, 5), - mode=mode) - output = layer(input) - bm.clear_buffer_memory() - - def test_Conv2_NonBatching(self): - bm.random.seed() - input = bm.random.randn(10, 10, 3) - layer = bp.dnn.Conv2d(in_channels=3, - out_channels=4, - kernel_size=(5, 5), - mode=bm.NonBatchingMode()) - output = layer(input) - bm.clear_buffer_memory() - bm.clear_buffer_memory() - - @parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(10), - bm.BatchingMode(), - bm.BatchingMode(10), ] - ) - def test_Conv3d(self, mode): - bm.random.seed() - input = bm.random.randn(10, 50, 50, 50, 3) - layer = bp.dnn.Conv3d(in_channels=3, - out_channels=4, - kernel_size=(5, 5, 5), - mode=mode) - output = layer(input) - bm.clear_buffer_memory() - - def test_Conv3_NonBatching(self): - bm.random.seed() - input = bm.random.randn(10, 10, 10, 3) - layer = bp.dnn.Conv3d(in_channels=3, - out_channels=4, - kernel_size=(5, 5, 5), - mode=bm.NonBatchingMode()) - output = layer(input) - bm.clear_buffer_memory() - - @parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(10), - bm.BatchingMode(), - bm.BatchingMode(10), ] - ) - def test_ConvTranspose1d(self, mode): - bm.random.seed() - input = bm.random.randn(10, 50, 3) - layer = bp.dnn.ConvTranspose1d(in_channels=3, - out_channels=4, - kernel_size=5, - mode=mode - ) - output = layer(input) - bm.clear_buffer_memory() - - def test_ConvTranspose1d_NonBatching(self): - bm.random.seed() - input = bm.random.randn(10, 3) - layer = bp.dnn.ConvTranspose1d(in_channels=3, - out_channels=4, - kernel_size=5, - mode=bm.NonBatchingMode()) - output = layer(input) - bm.clear_buffer_memory() - - @parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(10), - bm.BatchingMode(), - bm.BatchingMode(10), ] - ) - def test_ConvTranspose2d(self, mode): - bm.random.seed() - input = bm.random.randn(10, 50, 50, 3) - layer = bp.dnn.ConvTranspose2d(in_channels=3, - out_channels=4, - kernel_size=(5, 5), - mode=mode - ) - output = layer(input) - bm.clear_buffer_memory() - - def test_ConvTranspose2d_NonBatching(self): - bm.random.seed() - input = bm.random.randn(10, 10, 3) - layer = bp.dnn.ConvTranspose2d(in_channels=3, - out_channels=4, - kernel_size=(5, 5), - mode=bm.NonBatchingMode()) - output = layer(input) - bm.clear_buffer_memory() - - @parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(10), - bm.BatchingMode(), - bm.BatchingMode(10), ] - ) - def test_ConvTranspose3d(self, mode): - bm.random.seed() - input = bm.random.randn(10, 50, 50, 50, 3) - layer = bp.dnn.ConvTranspose3d(in_channels=3, - out_channels=4, - kernel_size=(5, 5, 5), - mode=mode - ) - output = layer(input) - bm.clear_buffer_memory() - - def test_ConvTranspose3d_NonBatching(self): - bm.random.seed() - input = bm.random.randn(10, 10, 10, 3) - layer = bp.dnn.ConvTranspose3d(in_channels=3, - out_channels=4, - kernel_size=(5, 5, 5), - mode=bm.NonBatchingMode()) - output = layer(input) - bm.clear_buffer_memory() - - -class TestPool(parameterized.TestCase): - - @parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(10), - bm.BatchingMode(), - bm.BatchingMode(10), - bm.NonBatchingMode()] - ) - def test_MaxPool(self, mode): - bm.random.seed() - input = bm.random.randn(10, 5, 5, 4) - layer = bp.dnn.MaxPool(kernel_size=(3, 3), - channel_axis=-1, - mode=mode) - if mode in [bm.NonBatchingMode()]: - for i in input: - output = layer(i) - else: - output = layer(input) - - @parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(10), - bm.BatchingMode(), - bm.BatchingMode(10), - bm.NonBatchingMode()] - ) - def test_MinPool(self, mode): - bm.random.seed() - input = bm.random.randn(10, 5, 5, 4) - layer = bp.dnn.MaxPool(kernel_size=(3, 3), - channel_axis=-1, - mode=mode) - if mode in [bm.NonBatchingMode()]: - for i in input: - output = layer(i) - else: - output = layer(input) - - @parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(10), - bm.BatchingMode(), - bm.BatchingMode(10), - bm.NonBatchingMode()] - ) - def test_AvgPool(self, mode): - bm.random.seed() - input = bm.random.randn(10, 5, 5, 4) - layer = bp.dnn.AvgPool(kernel_size=(3, 3), - channel_axis=-1, - mode=mode) - if mode in [bm.NonBatchingMode()]: - for i in input: - output = layer(i) - else: - output = layer(input) - - @parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(10), - bm.BatchingMode(), - bm.BatchingMode(10), - bm.NonBatchingMode()] - ) - def test_AvgPool1d(self, mode): - bm.random.seed() - input = bm.random.randn(10, 5, 4) - layer = bp.dnn.AvgPool1d(kernel_size=3, - channel_axis=-1, - mode=mode) - if mode in [bm.NonBatchingMode()]: - for i in input: - output = layer(i) - else: - output = layer(input) - - @parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(10), - bm.BatchingMode(), - bm.BatchingMode(10), - bm.NonBatchingMode()] - ) - def test_AvgPool2d(self, mode): - bm.random.seed() - input = bm.random.randn(10, 5, 5, 4) - layer = bp.dnn.AvgPool2d(kernel_size=(3, 3), - channel_axis=-1, - mode=mode) - if mode in [bm.NonBatchingMode()]: - for i in input: - output = layer(i) - else: - output = layer(input) - - @parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(10), - bm.BatchingMode(), - bm.BatchingMode(10), - bm.NonBatchingMode()] - ) - def test_AvgPool3d(self, mode): - bm.random.seed() - input = bm.random.randn(10, 5, 5, 5, 4) - layer = bp.dnn.AvgPool3d(kernel_size=(3, 3, 3), - channel_axis=-1, - mode=mode) - if mode in [bm.NonBatchingMode()]: - for i in input: - output = layer(i) - else: - output = layer(input) - - @parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(10), - bm.BatchingMode(), - bm.BatchingMode(10), - bm.NonBatchingMode()] - ) - def test_MaxPool1d(self, mode): - bm.random.seed() - input = bm.random.randn(10, 5, 4) - layer = bp.dnn.MaxPool1d(kernel_size=3, - channel_axis=-1, - mode=mode) - if mode in [bm.NonBatchingMode()]: - for i in input: - output = layer(i) - else: - output = layer(input) - - @parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(10), - bm.BatchingMode(), - bm.BatchingMode(10), - bm.NonBatchingMode()] - ) - def test_MaxPool2d(self, mode): - bm.random.seed() - input = bm.random.randn(10, 5, 5, 4) - layer = bp.dnn.MaxPool2d(kernel_size=(3, 3), - channel_axis=-1, - mode=mode) - if mode in [bm.NonBatchingMode()]: - for i in input: - output = layer(i) - else: - output = layer(input) - - @parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(10), - bm.BatchingMode(), - bm.BatchingMode(10), - bm.NonBatchingMode()] - ) - def test_MaxPool3d(self, mode): - bm.random.seed() - input = bm.random.randn(10, 5, 5, 5, 4) - layer = bp.dnn.MaxPool3d(kernel_size=(3, 3, 3), - channel_axis=-1, - mode=mode) - if mode in [bm.NonBatchingMode()]: - for i in input: - output = layer(i) - else: - output = layer(input) - - @parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(10), - bm.BatchingMode(), - bm.BatchingMode(10), - bm.NonBatchingMode()] - ) - def test_AdaptiveAvgPool1d(self, mode): - bm.random.seed() - input = bm.random.randn(10, 5, 4) - layer = bp.dnn.AdaptiveAvgPool1d(target_shape=3, - channel_axis=-1, - mode=mode) - if mode in [bm.NonBatchingMode()]: - for i in input: - output = layer(i) - else: - output = layer(input) - - @parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(10), - bm.BatchingMode(), - bm.BatchingMode(10), - bm.NonBatchingMode()] - ) - def test_AdaptiveAvgPool2d(self, mode): - bm.random.seed() - input = bm.random.randn(10, 5, 5, 4) - layer = bp.dnn.AdaptiveAvgPool2d(target_shape=(3, 3), - channel_axis=-1, - mode=mode) - if mode in [bm.NonBatchingMode()]: - for i in input: - output = layer(i) - else: - output = layer(input) - - @parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(10), - bm.BatchingMode(), - bm.BatchingMode(10), - bm.NonBatchingMode()] - ) - def test_AdaptiveAvgPool3d(self, mode): - bm.random.seed() - input = bm.random.randn(10, 5, 5, 5, 4) - layer = bp.dnn.AdaptiveAvgPool3d(target_shape=(3, 3, 3), - channel_axis=-1, - mode=mode) - if mode in [bm.NonBatchingMode()]: - for i in input: - output = layer(i) - else: - output = layer(input) - - @parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(10), - bm.BatchingMode(), - bm.BatchingMode(10), - bm.NonBatchingMode()] - ) - def test_AdaptiveMaxPool1d(self, mode): - bm.random.seed() - input = bm.random.randn(10, 5, 4) - layer = bp.dnn.AdaptiveMaxPool1d(target_shape=3, - channel_axis=-1, - mode=mode) - if mode in [bm.NonBatchingMode()]: - for i in input: - output = layer(i) - else: - output = layer(input) - - @parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(10), - bm.BatchingMode(), - bm.BatchingMode(10), - bm.NonBatchingMode()] - ) - def test_AdaptiveMaxPool2d(self, mode): - bm.random.seed() - input = bm.random.randn(10, 5, 5, 4) - layer = bp.dnn.AdaptiveMaxPool2d(target_shape=(3, 3), - channel_axis=-1, - mode=mode) - if mode in [bm.NonBatchingMode()]: - for i in input: - output = layer(i) - else: - output = layer(input) - - @parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(10), - bm.BatchingMode(), - bm.BatchingMode(10), - bm.NonBatchingMode()] - ) - def test_AdaptiveMaxPool3d(self, mode): - bm.random.seed() - input = bm.random.randn(10, 5, 5, 5, 4) - layer = bp.dnn.AdaptiveMaxPool3d(target_shape=(3, 3, 3), - channel_axis=-1, - mode=mode) - if mode in [bm.NonBatchingMode()]: - for i in input: - output = layer(i) - else: - output = layer(input) - - -class Test_Dropout(parameterized.TestCase): - @parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(10), - bm.BatchingMode(), - bm.BatchingMode(10), - bm.NonBatchingMode()] - ) - def test_Dropout(self, mode): - bp.share.save(fit=False) - bm.random.seed() - input = bm.random.randn(10, 5, 5, 5, 4) - layer = bp.dnn.Dropout(prob=0.2, - mode=mode) - output = layer(input) - - -class Test_function(parameterized.TestCase): - @parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(10), - bm.BatchingMode(), - bm.BatchingMode(10), - bm.NonBatchingMode()] - ) - def test_Flatten(self, mode): - bm.random.seed() - layer = bp.dnn.Flatten(mode=mode) - input = bm.random.randn(10, 5, 5, 5, 4) - output = layer(input) - - -class Test_linear(parameterized.TestCase): - - @parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(10), - bm.BatchingMode(), - bm.BatchingMode(10), - bm.NonBatchingMode()] - ) - def test_linear(self, mode): - bm.random.seed() - input = bm.random.randn(10, 9, 8, 7) - layer = bp.dnn.Linear(num_in=7, - num_out=6, - mode=mode) - output = layer(input) - - @parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(10), - bm.BatchingMode(), - bm.BatchingMode(10), - bm.NonBatchingMode()] - ) - def test_AllToAll(self, mode): - bm.random.seed() - input = bm.random.randn(10, 10) - layer = bp.dnn.AllToAll(num_pre=10, - num_post=20, - weight=0.1, - mode=mode) - if mode in [bm.NonBatchingMode()]: - for i in input: - output = layer(i) - else: - output = layer(input) - - @parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(10), - bm.BatchingMode(), - bm.BatchingMode(10), - bm.NonBatchingMode()] - ) - def test_OneToOne(self, mode): - bm.random.seed() - input = bm.random.randn(10, 10) - layer = bp.dnn.OneToOne(num=10, - weight=0.1, - mode=mode) - output = layer(input) - - @parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(10), - bm.BatchingMode(), - bm.BatchingMode(10), - bm.NonBatchingMode()] - ) - def test_MaskedLinear(self, mode): - bm.random.seed() - input = bm.random.randn(100, 100) - layer = bp.dnn.MaskedLinear(conn=bp.conn.FixedProb(0.1, pre=100, post=100), - weight=0.1, - mode=mode) - output = layer(input) - - @parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(10), - bm.BatchingMode(), - bm.BatchingMode(10), - bm.NonBatchingMode()] - ) - def test_CSRLinear(self, mode): - bm.random.seed() - input = bm.random.randn(100, 100) - layer = bp.dnn.CSRLinear(conn=bp.conn.FixedProb(0.1, pre=100, post=100), - weight=0.1, - mode=mode) - output = layer(input) - - @parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(10), - bm.BatchingMode(), - bm.BatchingMode(10), - bm.NonBatchingMode()] - ) - def test_EventCSRLinear(self, mode): - bm.random.seed() - input = bm.random.randn(100, 100) - layer = bp.dnn.EventCSRLinear(conn=bp.conn.FixedProb(0.1, pre=100, post=100), - weight=0.1, - mode=mode) - output = layer(input) - - @parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(10), - bm.BatchingMode(), - bm.BatchingMode(10), - bm.NonBatchingMode()] - ) - def test_JitFPHomoLinear(self, mode): - bm.random.seed() - layer = bp.dnn.JitFPHomoLinear(num_in=100, - num_out=200, - prob=0.1, - weight=0.01, - seed=100, - mode=mode) - input = bm.random.randn(10, 100) - output = layer(input) - - @parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(10), - bm.BatchingMode(), - bm.BatchingMode(10), - bm.NonBatchingMode()] - ) - def test_JitFPUniformLinear(self, mode): - bm.random.seed() - layer = bp.dnn.JitFPUniformLinear(num_in=100, - num_out=200, - prob=0.1, - w_low=-0.01, - w_high=0.01, - seed=100, - mode=mode) - input = bm.random.randn(10, 100) - output = layer(input) - - @parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(10), - bm.BatchingMode(), - bm.BatchingMode(10), - bm.NonBatchingMode()] - ) - def test_JitFPNormalLinear(self, mode): - bm.random.seed() - layer = bp.dnn.JitFPNormalLinear(num_in=100, - num_out=200, - prob=0.1, - w_mu=-0.01, - w_sigma=0.01, - seed=100, - mode=mode) - input = bm.random.randn(10, 100) - output = layer(input) - - @parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(10), - bm.BatchingMode(), - bm.BatchingMode(10), - bm.NonBatchingMode()] - ) - def test_EventJitFPHomoLinear(self, mode): - bm.random.seed() - layer = bp.dnn.EventJitFPHomoLinear(num_in=100, - num_out=200, - prob=0.1, - weight=0.01, - seed=100, - mode=mode) - input = bm.random.randn(10, 100) - output = layer(input) - - @parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(10), - bm.BatchingMode(), - bm.BatchingMode(10), - bm.NonBatchingMode()] - ) - def test_EventJitFPNormalLinear(self, mode): - bm.random.seed() - layer = bp.dnn.EventJitFPNormalLinear(num_in=100, - num_out=200, - prob=0.1, - w_mu=-0.01, - w_sigma=0.01, - seed=100, - mode=mode) - input = bm.random.randn(10, 100) - output = layer(input) - - @parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(10), - bm.BatchingMode(), - bm.BatchingMode(10), - bm.NonBatchingMode()] - ) - def test_EventJitFPUniformLinear(self, mode): - bm.random.seed() - layer = bp.dnn.EventJitFPUniformLinear(num_in=100, - num_out=200, - prob=0.1, - w_low=-0.01, - w_high=0.01, - seed=100, - mode=mode) - input = bm.random.randn(10, 100) - output = layer(input) - - -class Test_Normalization(parameterized.TestCase): - - @parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(10), - bm.BatchingMode(), - bm.BatchingMode(10)], - fit=[True, False] - ) - def test_BatchNorm1d(self, fit, mode): - bm.random.seed() - bp.share.save(fit=fit) - layer = bp.dnn.BatchNorm1d(num_features=100, - mode=mode, - affine=False) - input = bm.random.randn(10, 5, 100) - output = layer(input) - - @parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(10), - bm.BatchingMode(), - bm.BatchingMode(10)], - fit=[True, False] - ) - def test_BatchNorm2d(self, fit, mode): - bm.random.seed() - bp.share.save(fit=fit) - layer = bp.dnn.BatchNorm2d(num_features=100, - mode=mode, - affine=False) - input = bm.random.randn(10, 5, 6, 100) - output = layer(input) - - @parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(10), - bm.BatchingMode(), - bm.BatchingMode(10)], - fit=[True, False] - ) - def test_BatchNorm3d(self, fit, mode): - bm.random.seed() - bp.share.save(fit=fit) - layer = bp.dnn.BatchNorm3d(num_features=100, - mode=mode, - affine=False) - input = bm.random.randn(10, 5, 6, 7, 100) - output = layer(input) - - @parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(10), - bm.BatchingMode(), - bm.BatchingMode(10), - bm.NonBatchingMode()], - ) - def test_LayerNorm(self, mode): - bm.random.seed() - layer = bp.dnn.LayerNorm(normalized_shape=3, - mode=mode, - elementwise_affine=False - ) - input = bm.random.randn(10, 5, 3) - outout = layer(input) - - @parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(10), - bm.BatchingMode(), - bm.BatchingMode(10), - bm.NonBatchingMode()], - ) - def test_GroupNorm(self, mode): - bm.random.seed() - layer = bp.dnn.GroupNorm(num_groups=2, - num_channels=6, - affine=False, - mode=mode - ) - input = bm.random.randn(20, 10, 10, 6) - output = layer(input) - - @parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(10), - bm.BatchingMode(), - bm.BatchingMode(10), - bm.NonBatchingMode()], - ) - def test_InstanceNorm(self, mode): - bm.random.seed() - layer = bp.dnn.InstanceNorm(num_channels=6, - affine=False, - mode=mode - ) - input = bm.random.randn(20, 10, 10, 6) - output = layer(input) - - if __name__ == '__main__': absltest.main() diff --git a/brainpy/_src/dnn/tests/test_normalization.py b/brainpy/_src/dnn/tests/test_normalization.py index de2c9765b..fdc5b34e3 100644 --- a/brainpy/_src/dnn/tests/test_normalization.py +++ b/brainpy/_src/dnn/tests/test_normalization.py @@ -1,8 +1,7 @@ -from absl.testing import absltest +import brainpy.math as bm from absl.testing import parameterized - +from absl.testing import absltest import brainpy as bp -import brainpy.math as bm class Test_Normalization(parameterized.TestCase): diff --git a/brainpy/_src/dnn/tests/test_pooling_layers.py b/brainpy/_src/dnn/tests/test_pooling_layers.py index 5748edd8b..34f8f5cd5 100644 --- a/brainpy/_src/dnn/tests/test_pooling_layers.py +++ b/brainpy/_src/dnn/tests/test_pooling_layers.py @@ -3,8 +3,8 @@ import jax import jax.numpy as jnp import numpy as np -from absl.testing import absltest from absl.testing import parameterized +from absl.testing import absltest import brainpy as bp import brainpy.math as bm diff --git a/brainpy/_src/math/delayvars.py b/brainpy/_src/math/delayvars.py index 676e4286b..eb8e27c8f 100644 --- a/brainpy/_src/math/delayvars.py +++ b/brainpy/_src/math/delayvars.py @@ -11,7 +11,7 @@ from brainpy import check from brainpy.check import is_float, is_integer, jit_error from brainpy.errors import UnsupportedError -from .compat_numpy import broadcast_to, expand_dims, concatenate +from .compat_numpy import vstack, broadcast_to from .environment import get_dt, get_float from .interoperability import as_jax from .ndarray import ndarray, Array @@ -392,7 +392,6 @@ def reset( dtype=delay_target.dtype), batch_axis=batch_axis) else: - self.data.value self.data._value = jnp.zeros((self.num_delay_step,) + delay_target.shape, dtype=delay_target.dtype) @@ -473,7 +472,7 @@ def update(self, value: Union[numbers.Number, Array, jax.Array] = None): elif self.update_method == CONCAT_UPDATE: if self.num_delay_step >= 2: - self.data.value = concatenate([expand_dims(value, 0), self.data[:-1]], axis=0) + self.data.value = vstack([broadcast_to(value, self.data.shape[1:]), self.data[1:]]) else: self.data[:] = value diff --git a/brainpy/_src/math/object_transform/autograd.py b/brainpy/_src/math/object_transform/autograd.py index ad8a5ccf6..f5e091675 100644 --- a/brainpy/_src/math/object_transform/autograd.py +++ b/brainpy/_src/math/object_transform/autograd.py @@ -28,8 +28,10 @@ get_stack_cache, cache_stack) from .base import (BrainPyObject, ObjectTransform) -from .variables import (Variable, VariableStack) -from .tools import eval_shape +from .variables import (Variable, + VariableStack, + current_transform_number, + new_transform) __all__ = [ 'grad', # gradient of scalar function @@ -201,21 +203,36 @@ def __call__(self, *args, **kwargs): elif not self._eval_dyn_vars: # evaluate dynamical variables stack = get_stack_cache(self.target) if stack is None: - with VariableStack() as stack: - rets = eval_shape(self._transform, - [v.value for v in self._grad_vars], # variables for gradients - {}, # dynamical variables - *args, - **kwargs) + with new_transform(self): + with VariableStack() as stack: + if current_transform_number() > 1: + rets = self._transform( + [v.value for v in self._grad_vars], # variables for gradients + {}, # dynamical variables + *args, + **kwargs + ) + else: + rets = jax.eval_shape( + self._transform, + [v.value for v in self._grad_vars], # variables for gradients + {}, # dynamical variables + *args, + **kwargs + ) cache_stack(self.target, stack) - self._dyn_vars = stack - self._dyn_vars.remove_by_id(*[id(v) for v in self._grad_vars]) - self._eval_dyn_vars = True + self._dyn_vars = stack + self._dyn_vars.remove_by_id(*[id(v) for v in self._grad_vars]) + self._eval_dyn_vars = True - # if not the outermost transformation - if not stack.is_first_stack(): - return self._return(rets) + # if not the outermost transformation + if current_transform_number(): + return self._return(rets) + else: + self._dyn_vars = stack + self._dyn_vars.remove_by_id(*[id(v) for v in self._grad_vars]) + self._eval_dyn_vars = True rets = self._transform( [v.value for v in self._grad_vars], # variables for gradients diff --git a/brainpy/_src/math/object_transform/base.py b/brainpy/_src/math/object_transform/base.py index c52845a06..aaf053ae7 100644 --- a/brainpy/_src/math/object_transform/base.py +++ b/brainpy/_src/math/object_transform/base.py @@ -6,6 +6,7 @@ """ import numbers +import os import warnings from collections import namedtuple from typing import Any, Tuple, Callable, Sequence, Dict, Union, Optional @@ -13,13 +14,14 @@ import jax import numpy as np -from brainpy._src.math.modes import Mode +from brainpy import errors from brainpy._src.math.ndarray import (Array, ) from brainpy._src.math.object_transform.collectors import (ArrayCollector, Collector) from brainpy._src.math.object_transform.naming import (get_unique_name, check_name_uniqueness) from brainpy._src.math.object_transform.variables import (Variable, VariableView, TrainVar, VarList, VarDict) +from brainpy._src.math.modes import Mode from brainpy._src.math.sharding import BATCH_AXIS variable_ = None diff --git a/brainpy/_src/math/object_transform/controls.py b/brainpy/_src/math/object_transform/controls.py index 3edeb08e8..032a0fab6 100644 --- a/brainpy/_src/math/object_transform/controls.py +++ b/brainpy/_src/math/object_transform/controls.py @@ -21,12 +21,17 @@ cache_stack ) from .tools import ( - eval_shape, + evaluate_dyn_vars, dynvar_deprecation, node_deprecation, abstract ) -from .variables import (Variable, VariableStack) +from .variables import ( + Variable, + VariableStack, + new_transform, + current_transform_number, +) __all__ = [ 'make_loop', @@ -537,13 +542,15 @@ def cond( node_deprecation(child_objs) dyn_vars = get_stack_cache((true_fun, false_fun)) - if not jax.config.jax_disable_jit and dyn_vars is None: - with VariableStack() as dyn_vars: - rets = eval_shape(true_fun, *operands, with_stack=True)[1] - _ = eval_shape(false_fun, *operands, with_stack=True) - cache_stack((true_fun, false_fun), dyn_vars) - if not dyn_vars.is_first_stack(): - return rets + if not jax.config.jax_disable_jit: + if dyn_vars is None: + with new_transform('cond'): + dyn_vars1, rets = evaluate_dyn_vars(true_fun, *operands, use_eval_shape=current_transform_number() <= 1) + dyn_vars2, rets = evaluate_dyn_vars(false_fun, *operands, use_eval_shape=current_transform_number() <= 1) + dyn_vars = dyn_vars1 + dyn_vars2 + cache_stack((true_fun, false_fun), dyn_vars) + if current_transform_number() > 0: + return rets dyn_vars = VariableStack() if dyn_vars is None else dyn_vars dyn_values, res = _get_cond_transform(dyn_vars, pred, true_fun, false_fun)(operands) for k in dyn_values.keys(): @@ -674,16 +681,20 @@ def ifelse( else: dyn_vars = get_stack_cache(tuple(branches)) if dyn_vars is None: - with VariableStack() as dyn_vars: - rets = [eval_shape(fun, *operands, with_stack=True)[1] for fun in branches] - trees = [jax.tree_util.tree_structure(ret) for ret in rets] - if not _all_equal(trees): - msg = 'All returns in branches should have the same tree structure. But we got:\n' - for tree in trees: - msg += f'- {tree}\n' - raise TypeError(msg) + with new_transform('ifelse'): + with VariableStack() as dyn_vars: + if current_transform_number() > 1: + rets = [branch(*operands) for branch in branches] + else: + rets = [jax.eval_shape(branch, *operands) for branch in branches] + trees = [jax.tree_util.tree_structure(ret) for ret in rets] + if not _all_equal(trees): + msg = 'All returns in branches should have the same tree structure. But we got:\n' + for tree in trees: + msg += f'- {tree}\n' + raise TypeError(msg) cache_stack(tuple(branches), dyn_vars) - if not dyn_vars.is_first_stack(): + if current_transform_number(): return rets[0] branches = [_cond_transform_fun(fun, dyn_vars) for fun in branches] @@ -869,23 +880,28 @@ def for_loop( if jit is None: # jax disable jit jit = not jax.config.jax_disable_jit - stack = get_stack_cache((body_fun, unroll_kwargs)) + dyn_vars = get_stack_cache((body_fun, unroll_kwargs)) if jit: - if stack is None: - transform = _get_for_loop_transform(body_fun, VariableStack(), bar, progress_bar, - remat, reverse, unroll, unroll_kwargs) + if dyn_vars is None: # TODO: better cache mechanism? - with VariableStack() as stack: - rets = eval_shape(transform, operands) - cache_stack((body_fun, unroll_kwargs), stack) # cache - if not stack.is_first_stack(): + with new_transform('for_loop'): + with VariableStack() as dyn_vars: + transform = _get_for_loop_transform(body_fun, VariableStack(), bar, + progress_bar, remat, reverse, unroll, + unroll_kwargs) + if current_transform_number() > 1: + rets = transform(operands) + else: + rets = jax.eval_shape(transform, operands) + cache_stack((body_fun, unroll_kwargs), dyn_vars) # cache + if current_transform_number(): return rets[1] del rets else: - stack = VariableStack() + dyn_vars = VariableStack() # TODO: cache mechanism? - transform = _get_for_loop_transform(body_fun, stack, bar, + transform = _get_for_loop_transform(body_fun, dyn_vars, bar, progress_bar, remat, reverse, unroll, unroll_kwargs) if jit: @@ -893,11 +909,11 @@ def for_loop( else: with jax.disable_jit(): dyn_vals, out_vals = transform(operands) - for key in stack.keys(): - stack[key]._value = dyn_vals[key] + for key in dyn_vars.keys(): + dyn_vars[key]._value = dyn_vals[key] if progress_bar: bar.close() - del dyn_vals, stack + del dyn_vals, dyn_vars return out_vals @@ -995,21 +1011,26 @@ def scan( num_total = min([op.shape[0] for op in jax.tree_util.tree_flatten(operands)[0]]) bar = tqdm(total=num_total) - stack = get_stack_cache(body_fun) - if not jax.config.jax_disable_jit and stack is None: - transform = _get_scan_transform(body_fun, VariableStack(), bar, progress_bar, remat, reverse, unroll) - with VariableStack() as stack: - rets = eval_shape(transform, init, operands) - cache_stack(body_fun, stack) # cache - if not stack.is_first_stack(): - return rets[0][1], rets[1] - del rets - - stack = VariableStack() if stack is None else stack - transform = _get_scan_transform(body_fun, stack, bar, progress_bar, remat, reverse, unroll) + dyn_vars = get_stack_cache(body_fun) + if not jax.config.jax_disable_jit: + if dyn_vars is None: + with new_transform('scan'): + with VariableStack() as dyn_vars: + transform = _get_scan_transform(body_fun, VariableStack(), bar, progress_bar, remat, reverse, unroll) + if current_transform_number() > 1: + rets = transform(init, operands) + else: + rets = jax.eval_shape(transform, init, operands) + cache_stack(body_fun, dyn_vars) # cache + if current_transform_number(): + return rets[0][1], rets[1] + del rets + + dyn_vars = VariableStack() if dyn_vars is None else dyn_vars + transform = _get_scan_transform(body_fun, dyn_vars, bar, progress_bar, remat, reverse, unroll) (dyn_vals, carry), out_vals = transform(init, operands) - for key in stack.keys(): - stack[key]._value = dyn_vals[key] + for key in dyn_vars.keys(): + dyn_vars[key]._value = dyn_vals[key] if progress_bar: bar.close() return carry, out_vals @@ -1108,6 +1129,7 @@ def while_loop( No longer need to provide ``child_objs``. This function is capable of automatically collecting the children objects used in the target ``func``. + """ dynvar_deprecation(dyn_vars) node_deprecation(child_objs) @@ -1115,16 +1137,18 @@ def while_loop( if not isinstance(operands, (list, tuple)): operands = (operands,) - stack = get_stack_cache((body_fun, cond_fun)) - if not jax.config.jax_disable_jit and stack is None: - with VariableStack() as stack: - _ = eval_shape(cond_fun, *operands, with_stack=True) - rets = eval_shape(body_fun, *operands, with_stack=True)[1] - cache_stack((body_fun, cond_fun), stack) - if not stack.is_first_stack(): - return rets - stack = VariableStack() if stack is None else stack - dyn_values, out = _get_while_transform(cond_fun, body_fun, stack)(operands) - for k, v in stack.items(): + dyn_vars = get_stack_cache((body_fun, cond_fun)) + if not jax.config.jax_disable_jit: + if dyn_vars is None: + with new_transform('while_loop'): + dyn_vars1, _ = evaluate_dyn_vars(cond_fun, *operands, use_eval_shape=current_transform_number() <= 1) + dyn_vars2, rets = evaluate_dyn_vars(body_fun, *operands, use_eval_shape=current_transform_number() <= 1) + dyn_vars = dyn_vars1 + dyn_vars2 + cache_stack((body_fun, cond_fun), dyn_vars) + if current_transform_number(): + return rets + dyn_vars = VariableStack() if dyn_vars is None else dyn_vars + dyn_values, out = _get_while_transform(cond_fun, body_fun, dyn_vars)(operands) + for k, v in dyn_vars.items(): v._value = dyn_values[k] return out diff --git a/brainpy/_src/math/object_transform/jit.py b/brainpy/_src/math/object_transform/jit.py index 73eab2f91..7bb36f4e2 100644 --- a/brainpy/_src/math/object_transform/jit.py +++ b/brainpy/_src/math/object_transform/jit.py @@ -11,15 +11,23 @@ from typing import Callable, Union, Optional, Sequence, Dict, Any, Iterable import jax +from jax.sharding import Sharding from brainpy import tools, check -from .base import BrainPyObject, ObjectTransform -from .naming import get_stack_cache, cache_stack from .tools import (dynvar_deprecation, node_deprecation, - eval_shape) -from .variables import (Variable, VariableStack) + evaluate_dyn_vars_with_cache, + evaluate_dyn_vars, + _partial_fun) +from .base import BrainPyObject, ObjectTransform +from .naming import get_stack_cache, cache_stack from ..ndarray import Array +from .variables import (Variable, + VariableStack, + outermost_transform, + transform_stack, + current_transform_number, + new_transform) RandomState = None @@ -143,12 +151,16 @@ def _transform_function(self, variable_data: Dict, *args, **kwargs): return changes, out def _get_transform(self, *args, **kwargs): - with VariableStack() as self._dyn_vars: - rets = eval_shape(self.fun, - *args, - **kwargs, - static_argnums=self._static_argnums, - static_argnames=self._static_argnames) + with new_transform(self): + self._dyn_vars, rets = evaluate_dyn_vars( + self.fun, + *args, + static_argnums=self._static_argnums, + static_argnames=self._static_argnames, + use_eval_shape=current_transform_number() <= 1, + **kwargs + ) + # in_shardings if self._in_shardings is None: in_shardings = None @@ -174,18 +186,18 @@ def _get_transform(self, *args, **kwargs): _dyn_vars_sharing = get_shardings(self._dyn_vars.subset_by_not_instance(RandomState)) out_shardings = (_dyn_vars_sharing,) + out_shardings - # jit - self._transform = jax.jit( - self._transform_function, - static_argnums=jax.tree_util.tree_map(lambda a: a + 1, self._static_argnums), - static_argnames=self._static_argnames, - donate_argnums=self._donate_argnums, - inline=self._inline, - keep_unused=self._keep_unused, - abstracted_axes=self._abstracted_axes, - in_shardings=in_shardings, - out_shardings=out_shardings, - ) + # jit + self._transform = jax.jit( + self._transform_function, + static_argnums=jax.tree_util.tree_map(lambda a: a + 1, self._static_argnums), + static_argnames=self._static_argnames, + donate_argnums=self._donate_argnums, + inline=self._inline, + keep_unused=self._keep_unused, + abstracted_axes=self._abstracted_axes, + in_shardings=in_shardings, + out_shardings=out_shardings, + ) return rets def __call__(self, *args, **kwargs): @@ -195,7 +207,7 @@ def __call__(self, *args, **kwargs): if self._transform is None: # initialize the transformation rets = self._get_transform(*args, **kwargs) # if not the outermost transformation - if not self._dyn_vars.is_first_stack(): + if current_transform_number(): return rets # call the transformed function @@ -465,8 +477,15 @@ def call_fun(self, *args, **kwargs): cache = get_stack_cache(hash_v) # TODO: better cache mechanism if cache is None: fun2 = partial(fun, self) - with VariableStack() as stack: - _ = eval_shape(fun2, *args, **kwargs, static_argnums=static_argnums, static_argnames=static_argnames) + + with jax.ensure_compile_time_eval(): + if len(static_argnums) or len(static_argnames): + fun3, args_, kwargs_ = _partial_fun(fun2, args, kwargs, static_argnums, static_argnames) + else: + args_, kwargs_, fun3 = args, kwargs, fun2 + with VariableStack() as stack: + _ = jax.eval_shape(fun3, *args_, **kwargs_) + del args_, kwargs_ _transform = jax.jit( _make_transform(fun2, stack), static_argnums=jax.tree_util.tree_map(lambda a: a + 1, static_argnums), diff --git a/brainpy/_src/math/object_transform/naming.py b/brainpy/_src/math/object_transform/naming.py index 1181e003b..1c8ca6ef9 100644 --- a/brainpy/_src/math/object_transform/naming.py +++ b/brainpy/_src/math/object_transform/naming.py @@ -41,7 +41,7 @@ def get_unique_name(type_: str): return name -def clear_name_cache(ignore_warn=True): +def clear_name_cache(ignore_warn=False): """Clear the cached names.""" _name2id.clear() _typed_names.clear() @@ -57,7 +57,6 @@ def cache_stack(func, stack): def clear_stack_cache(): - """Clear the cached stack.""" for k in tuple(_fun2stack.keys()): del _fun2stack[k] diff --git a/brainpy/_src/math/object_transform/parallels.py b/brainpy/_src/math/object_transform/parallels.py new file mode 100644 index 000000000..1eddce048 --- /dev/null +++ b/brainpy/_src/math/object_transform/parallels.py @@ -0,0 +1,460 @@ +# -*- coding: utf-8 -*- + +""" +The parallel compilation tools for JAX backend. + +1. Vectorize compilation is implemented by the 'vmap()' function +2. Parallel compilation is implemented by the 'pmap()' function + +""" + + +import functools + +import jax +import jax.numpy as jnp +import numpy as np +from jax.interpreters.partial_eval import DynamicJaxprTracer +from jax.interpreters.partial_eval import JaxprTracer +from jax.interpreters.pxla import ShardedDeviceArray + +try: + from jax.errors import UnexpectedTracerError +except ImportError: + from jax.core import UnexpectedTracerError + +from brainpy import errors +from brainpy._src.math.random import RandomState +from brainpy._src.math.ndarray import Array +from brainpy.tools import change_func_name +from .base import BrainPyObject, ArrayCollector + +__all__ = [ + 'vmap', + 'pmap', +] + + +def _make_vmap(func, nonbatched_vars, batched_vars, in_axes, out_axes, + batch_idx, axis_name, f_name=None): + @functools.partial(jax.vmap, in_axes=in_axes, out_axes=out_axes, axis_name=axis_name) + def vmapped_func(nonbatched_data, batched_data, *args, **kwargs): + nonbatched_vars.assign(nonbatched_data) + batched_vars.assign(batched_data) + out = func(*args, **kwargs) + nonbatched_changes = nonbatched_vars.dict() + batched_changes = batched_vars.dict() + return nonbatched_changes, batched_changes, out + + def call(*args, **kwargs): + n = args[batch_idx[0]].shape[batch_idx[1]] + nonbatched_data = nonbatched_vars.dict() + batched_data = {key: val.split_keys(n) for key, val in batched_vars.items()} + try: + out, dyn_changes, rand_changes = vmapped_func(nonbatched_data, batched_data, *args, **kwargs) + except UnexpectedTracerError as e: + nonbatched_vars.assign(nonbatched_data) + batched_vars.assign(batched_data) + raise errors.JaxTracerError() from e + # for key, v in dyn_changes.items(): + # dyn_vars[key] = reduce_func(v) + # for key, v in rand_changes.items(): + # rand_vars[key] = reduce_func(v) + return out + + return change_func_name(name=f_name, f=call) if f_name else call + + +def vmap(func, dyn_vars=None, batched_vars=None, + in_axes=0, out_axes=0, axis_name=None, + reduce_func=None, auto_infer=False): + """Vectorization compilation for class objects. + + Vectorized compile a function or a module to run in parallel on a single device. + + Examples + -------- + + Parameters + ---------- + func : BrainPyObject, function, callable + The function or the module to compile. + dyn_vars : dict, sequence + batched_vars : dict + in_axes : optional, int, sequence of int + Specify which input array axes to map over. If each positional argument to + ``obj_or_func`` is an array, then ``in_axes`` can be an integer, a None, + or a tuple of integers and Nones with length equal to the number of + positional arguments to ``obj_or_func``. An integer or ``None`` + indicates which array axis to map over for all arguments (with ``None`` + indicating not to map any axis), and a tuple indicates which axis to map + for each corresponding positional argument. Axis integers must be in the + range ``[-ndim, ndim)`` for each array, where ``ndim`` is the number of + dimensions (axes) of the corresponding input array. + + If the positional arguments to ``obj_or_func`` are container types, the + corresponding element of ``in_axes`` can itself be a matching container, + so that distinct array axes can be mapped for different container + elements. ``in_axes`` must be a container tree prefix of the positional + argument tuple passed to ``obj_or_func``. + + At least one positional argument must have ``in_axes`` not None. The sizes + of the mapped input axes for all mapped positional arguments must all be + equal. + + Arguments passed as keywords are always mapped over their leading axis + (i.e. axis index 0). + out_axes : optional, int, tuple/list/dict + Indicate where the mapped axis should appear in the output. All outputs + with a mapped axis must have a non-None ``out_axes`` specification. Axis + integers must be in the range ``[-ndim, ndim)`` for each output array, + where ``ndim`` is the number of dimensions (axes) of the array returned + by the :func:`vmap`-ed function, which is one more than the number of + dimensions (axes) of the corresponding array returned by ``obj_or_func``. + axis_name : optional + + Returns + ------- + obj_or_func : Any + Batched/vectorized version of ``obj_or_func`` with arguments that correspond to + those of ``obj_or_func``, but with extra array axes at positions indicated by + ``in_axes``, and a return value that corresponds to that of ``obj_or_func``, but + with extra array axes at positions indicated by ``out_axes``. + + """ + # if isinstance(func, DynamicalSystem): + # if len(func.steps): # DynamicalSystem has step functions + # + # # dynamical variables + # dyn_vars = (dyn_vars or func.vars().unique()) + # dyn_vars, rand_vars = ArrayCollector(), ArrayCollector() + # for key, val in dyn_vars.items(): + # if isinstance(val, RandomState): + # rand_vars[key] = val + # else: + # dyn_vars[key] = val + # + # # in axes + # if in_axes is None: + # in_axes = {key: (None, 0) for key in func.steps.keys()} + # elif isinstance(in_axes, int): + # in_axes = {key: (None, 0, in_axes) for key in func.steps.keys()} + # elif isinstance(in_axes, (tuple, list)): + # in_axes = {key: (None, 0) + tuple(in_axes) for key in func.steps.keys()} + # elif isinstance(in_axes, dict): + # keys = list(func.steps.keys()) + # if keys[0] not in in_axes: + # in_axes = {key: (None, 0, in_axes) for key in keys} + # else: + # in_axes = {key: (None, 0) + tuple(in_axes[key]) for key in keys} + # assert isinstance(in_axes, dict) + # + # # batch size index + # batch_idx = {} + # for key, axes in in_axes.items(): + # for i, axis in enumerate(axes[2:]): + # if axis is not None: + # batch_idx[key] = (i, axis) + # break + # else: + # raise ValueError(f'Found no batch axis: {axes}.') + # + # # out axes + # if out_axes is None: + # out_axes = {key: 0 for key in func.steps.keys()} + # elif isinstance(out_axes, int): + # out_axes = {key: out_axes for key in func.steps.keys()} + # elif isinstance(out_axes, (tuple, list)): + # out_axes = {key: tuple(out_axes) + (0, 0) for key in func.steps.keys()} + # elif isinstance(out_axes, dict): + # keys = list(func.steps.keys()) + # if keys[0] not in out_axes: + # out_axes = {key: (out_axes, 0, 0) for key in keys} + # else: + # out_axes = {key: tuple(out_axes[key]) + (0, 0) for key in keys} + # assert isinstance(out_axes, dict) + # + # # reduce_func + # if reduce_func is None: + # reduce_func = lambda x: x.mean(axis=0) + # + # # vectorized map functions + # for key in func.steps.keys(): + # func.steps[key] = _make_vmap(func=func.steps[key], + # dyn_vars=dyn_vars, + # rand_vars=rand_vars, + # in_axes=in_axes[key], + # out_axes=out_axes[key], + # axis_name=axis_name, + # batch_idx=batch_idx[key], + # reduce_func=reduce_func, + # f_name=key) + # + # return func + + if callable(func): + if auto_infer: + if dyn_vars is not None: + dyn_vars = dyn_vars + elif isinstance(func, BrainPyObject): # BrainPyObject has '__call__()' implementation + dyn_vars = func.vars().unique() + elif hasattr(func, '__self__'): + if isinstance(func.__self__, BrainPyObject): + dyn_vars = func.__self__.vars().unique() + + if dyn_vars is None: + return jax.vmap(func, + in_axes=in_axes, + out_axes=out_axes, + axis_name=axis_name) + + else: + if isinstance(dyn_vars, Array): + dyn_vars = [dyn_vars] + if isinstance(dyn_vars, (tuple, list)): + dyn_vars = {f'_vmap_v{i}': v for i, v in enumerate(dyn_vars)} + assert isinstance(dyn_vars, dict) + + # dynamical variables + _dyn_vars, _rand_vars = ArrayCollector(), ArrayCollector() + for key, val in dyn_vars.items(): + if isinstance(val, RandomState): + _rand_vars[key] = val + else: + _dyn_vars[key] = val + + # in axes + if in_axes is None: + in_axes = (None, 0) + elif isinstance(in_axes, (int, dict)): + in_axes = (None, 0, in_axes) + elif isinstance(in_axes, (tuple, list)): + in_axes = (None, 0) + tuple(in_axes) + assert isinstance(in_axes, (tuple, list)) + + # batch size index + batch_idx = {} + for key, axes in batch_idx.items(): + for i, axis in enumerate(axes[2:]): + if axis is not None: + batch_idx[key] = (i, axis) + break + else: + raise ValueError(f'Found no batch axis: {axes}.') + + # out axes + if out_axes is None: + out_axes = 0 + elif isinstance(out_axes, (int, dict)): + out_axes = (out_axes, 0, 0) + elif isinstance(out_axes, (tuple, list)): + out_axes = tuple(out_axes) + (0, 0) + assert isinstance(out_axes, (list, tuple)) + + # reduce_func + if reduce_func is None: + reduce_func = lambda x: x.mean(axis=0) + + # jit function + return _make_vmap(func=func, + nonbatched_vars=_dyn_vars, + batched_vars=_rand_vars, + in_axes=in_axes, + out_axes=out_axes, + axis_name=axis_name, + batch_idx=batch_idx) + + else: + raise errors.BrainPyError(f'Only support instance of {BrainPyObject.__name__}, or a callable ' + f'function, but we got {type(func)}.') + + +def _device_reshape(x): + """Reshape an input array in order to broadcast to multiple devices.""" + num_device = jax.local_device_count() + + if not hasattr(x, 'ndim'): + raise errors.BrainPyError(f'Expected Array, got {type(x)}. If you are trying to pass a scalar to ' + f'parallel, first convert it to a Array, for example np.float(0.5)') + if x.ndim == 0: + return np.broadcast_to(x, [num_device]) + if x.shape[0] % num_device != 0: + raise errors.BrainPyError(f'Must be able to equally divide batch {x.shape} among ' + f'{num_device} devices, but does not go equally.') + return x.reshape((num_device, x.shape[0] // num_device) + x.shape[1:]) + + +def _make_pmap(func, dyn_vars, rand_vars, reduce_func, axis_name=None, in_axes=0, + out_axes=0, static_broadcasted_argnums=(), devices=None, backend=None, + axis_size=None, donate_argnums=(), global_arg_shapes=None, f_name=None): + @functools.partial(jax.pmap, in_axes=in_axes, out_axes=out_axes, axis_name=axis_name, + static_broadcasted_argnums=static_broadcasted_argnums, devices=devices, + backend=backend, axis_size=axis_size, donate_argnums=donate_argnums, + global_arg_shapes=global_arg_shapes) + def pmapped_func(dyn_data, rand_data, *args, **kwargs): + dyn_vars.assign(dyn_data) + rand_vars.assign(rand_data) + out = func(*args, **kwargs) + dyn_changes = dyn_vars.dict() + rand_changes = rand_vars.dict() + return out, dyn_changes, rand_changes + + def call(*args): + un_replicated = [k for k, v in dyn_vars.items() + if not isinstance(v.value, (ShardedDeviceArray, JaxprTracer, DynamicJaxprTracer))] + if len(un_replicated): + raise errors.BrainPyError(f'Some variables were not replicated: {un_replicated}.' + f'did you forget to call xx.replicate() on them?') + _args = [] + for i, x in enumerate(args): + if i + 2 in static_broadcasted_argnums: + _args.append(x) + else: + _args.append(jax.tree_map(_device_reshape, [x])[0]) + dyn_data = dyn_vars.dict() + rand_data = rand_vars.dict() + output, dyn_changes, rand_changes = pmapped_func(dyn_data, rand_data, *_args) + dyn_vars.assign(dyn_changes) + rand_vars.assign(rand_changes) + return jax.tree_map(reduce_func, output) + + return change_func_name(name=f_name, f=call) if f_name else call + + +def pmap(func, dyn_vars=None, axis_name=None, in_axes=0, out_axes=0, static_broadcasted_argnums=(), + devices=None, backend=None, axis_size=None, donate_argnums=(), global_arg_shapes=None, + reduce_func=None): + """Parallel compilation for class objects. + + Parallel compile a function or a module to run on multiple devices in parallel. + + Parameters + ---------- + func + axis_name + in_axes + out_axes + static_broadcasted_argnums + devices + backend + axis_size + donate_argnums + global_arg_shapes + + Returns + ------- + + + Examples + -------- + + + """ + + # if isinstance(func, DynamicalSystem): + # if len(func.steps): # DynamicalSystem has step functions + # + # # dynamical variables + # all_vars = (dyn_vars or func.vars().unique()) + # dyn_vars = ArrayCollector() + # rand_vars = ArrayCollector() + # for key, val in all_vars.items(): + # if isinstance(val, RandomState): + # rand_vars[key] = val + # else: + # dyn_vars[key] = val + # + # # reduce function + # if reduce_func is None: + # reduce_func = jnp.concatenate + # + # # static broadcast-ed arguments + # if static_broadcasted_argnums is None: + # static_broadcasted_argnums = () + # elif isinstance(static_broadcasted_argnums, int): + # static_broadcasted_argnums = (static_broadcasted_argnums + 2,) + # elif isinstance(static_broadcasted_argnums, (tuple, list)): + # static_broadcasted_argnums = tuple(argnum + 2 for argnum in static_broadcasted_argnums) + # assert isinstance(static_broadcasted_argnums, (tuple, list)) + # + # # jit functions + # for key in func.steps.keys(): + # step = func.steps[key] + # func.steps[key] = _make_pmap(dyn_vars=dyn_vars, + # rand_vars=rand_vars, + # func=step, + # axis_name=axis_name, + # in_axes=in_axes, + # out_axes=out_axes, + # static_broadcasted_argnums=static_broadcasted_argnums, + # devices=devices, + # backend=backend, + # axis_size=axis_size, + # donate_argnums=donate_argnums, + # global_arg_shapes=global_arg_shapes, + # reduce_func=reduce_func, + # f_name=key) + # return func + + if callable(func): + if dyn_vars is not None: + dyn_vars = dyn_vars + elif isinstance(func, BrainPyObject): # BrainPyObject has '__call__()' implementation + dyn_vars = func.vars().unique() + elif hasattr(func, '__self__'): + if isinstance(func.__self__, BrainPyObject): + dyn_vars = func.__self__.vars().unique() + + if dyn_vars is None: + return jax.pmap(func, + axis_name=axis_name, + in_axes=in_axes, + out_axes=out_axes, + static_broadcasted_argnums=static_broadcasted_argnums, + devices=devices, + backend=backend, + axis_size=axis_size, + donate_argnums=donate_argnums, + global_arg_shapes=global_arg_shapes) + else: + # dynamical variables + dyn_vars = ArrayCollector() + rand_vars = ArrayCollector() + for key, val in dyn_vars.items(): + if isinstance(val, RandomState): + rand_vars[key] = val + else: + dyn_vars[key] = val + + # static broadcast-ed arguments + if static_broadcasted_argnums is None: + static_broadcasted_argnums = () + elif isinstance(static_broadcasted_argnums, int): + static_broadcasted_argnums = (static_broadcasted_argnums + 2,) + elif isinstance(static_broadcasted_argnums, (tuple, list)): + static_broadcasted_argnums = tuple(argnum + 2 for argnum in static_broadcasted_argnums) + assert isinstance(static_broadcasted_argnums, (tuple, list)) + + # reduce function + if reduce_func is None: + reduce_func = jnp.concatenate + + # jit function + func.__call__ = _make_pmap(dyn_vars=dyn_vars, + rand_vars=rand_vars, + func=func, + axis_name=axis_name, + in_axes=in_axes, + out_axes=out_axes, + static_broadcasted_argnums=static_broadcasted_argnums, + devices=devices, + backend=backend, + axis_size=axis_size, + donate_argnums=donate_argnums, + global_arg_shapes=global_arg_shapes, + reduce_func=reduce_func) + return func + + else: + raise errors.BrainPyError(f'Only support instance of {BrainPyObject.__name__}, or a callable function, ' + f'but we got {type(func)}.') diff --git a/brainpy/_src/math/object_transform/tools.py b/brainpy/_src/math/object_transform/tools.py index 632c6d79e..7b519590a 100644 --- a/brainpy/_src/math/object_transform/tools.py +++ b/brainpy/_src/math/object_transform/tools.py @@ -132,65 +132,19 @@ def evaluate_dyn_vars_with_cache( return stack -def _partial_fun2( - fun: Callable, - args: tuple, - kwargs: dict, - static_argnums: Sequence[int] = (), - static_argnames: Sequence[str] = () -): - num_args = len(args) - - # arguments - static_args = dict() - dyn_args = [] - dyn_arg_ids = dict() - static_argnums = list(static_argnums) - dyn_i = 0 - for i in range(num_args): - if i in static_argnums: - static_argnums.remove(i) - static_args[i] = args[i] - else: - dyn_args.append(args[i]) - dyn_arg_ids[i] = dyn_i - dyn_i += 1 - if len(static_argnums) > 0: - raise ValueError(f"Invalid static_argnums: {static_argnums}") - - # keyword arguments - static_kwargs, dyn_kwargs = {}, {} - for k, arg in kwargs.items(): - if k in static_argnames: - static_kwargs[k] = arg - else: - dyn_kwargs[k] = arg - del args, kwargs, static_argnums, static_argnames - - @wraps(fun) - def new_fun(*dynargs, **dynkwargs): - return fun(*[dynargs[dyn_arg_ids[id_]] if id_ in dyn_arg_ids else static_args[id_] for id_ in range(num_args)], - **static_kwargs, - **dynkwargs) - - return new_fun, dyn_args, dyn_kwargs - - def eval_shape( fun: Callable, *args, static_argnums: Sequence[int] = (), static_argnames: Sequence[str] = (), - with_stack: bool = False, **kwargs ): """Compute the shape/dtype of ``fun`` without any FLOPs. Args: fun: The callable function. - *args: The positional arguments. - **kwargs: The keyword arguments. - with_stack: Whether evaluate the function within a local variable stack. + *args: + **kwargs: static_argnums: The static argument indices. static_argnames: The static argument names. @@ -199,30 +153,21 @@ def eval_shape( """ # reorganize the function if len(static_argnums) or len(static_argnames): - f2, args, kwargs = _partial_fun2(fun, args, kwargs, static_argnums=static_argnums, static_argnames=static_argnames) + f2, args, kwargs = _partial_fun(fun, args, kwargs, + static_argnums=static_argnums, + static_argnames=static_argnames) else: - f2 = fun + f2, args, kwargs = fun, args, kwargs # evaluate the function fun_in_eval_shape.append(fun) try: - if with_stack: + with jax.ensure_compile_time_eval(): with VariableStack() as stack: if len(fun_in_eval_shape) > 1: - returns = f2(*args, **kwargs) + returns = fun(*args, **kwargs) else: - returns = jax.eval_shape(f2, *args, **kwargs) - else: - stack = None - if len(fun_in_eval_shape) > 1: - returns = f2(*args, **kwargs) - else: - returns = jax.eval_shape(f2, *args, **kwargs) + returns = jax.eval_shape(fun, *args, **kwargs) finally: fun_in_eval_shape.pop() - del f2 - if with_stack: - return stack, returns - else: - return returns - + return stack, returns diff --git a/brainpy/_src/math/object_transform/variables.py b/brainpy/_src/math/object_transform/variables.py index b7babae8d..5014da0bf 100644 --- a/brainpy/_src/math/object_transform/variables.py +++ b/brainpy/_src/math/object_transform/variables.py @@ -1,3 +1,4 @@ +from contextlib import contextmanager from typing import Optional, Any, List, Callable, Sequence, Union, Dict, Tuple import jax @@ -189,14 +190,6 @@ def remove_by_id(self, *ids, error_when_absent=False): remove_var_by_id = remove_by_id - @classmethod - def num_of_stack(self): - return len(var_stack_list) - - @classmethod - def is_first_stack(self): - return len(var_stack_list) == 0 - def __enter__(self) -> 'VariableStack': self.collect_values() # recollect the original value of each variable var_stack_list.append(self) @@ -217,6 +210,42 @@ def __add__(self, other: dict): var_stack_list: List[VariableStack] = [] +transform_stack: List[Callable] = [] + + +@contextmanager +def new_transform(transform: Any): + transform_stack.append(transform) + try: + yield + finally: + transform_stack.pop() + + +def outermost_stack(): + if len(var_stack_list): + return var_stack_list[0] + else: + return None + + +def outermost_transform(): + if len(transform_stack): + return transform_stack[0] + else: + return None + + +def current_transform_number(): + return len(transform_stack) + + +def _stack_add_read(var: 'Variable'): + pass + + +def _stack_add_write(var: 'Variable'): + pass @register_pytree_node_class diff --git a/brainpy/_src/tools/functions.py b/brainpy/_src/tools/functions.py deleted file mode 100644 index cbc710dba..000000000 --- a/brainpy/_src/tools/functions.py +++ /dev/null @@ -1,192 +0,0 @@ -import inspect -from functools import partial -from operator import attrgetter -from types import MethodType - -__all__ = [ - 'compose', 'pipe' -] - - -def identity(x): - """ Identity function. Return x - - >>> identity(3) - 3 - """ - return x - - -def instanceproperty(fget=None, fset=None, fdel=None, doc=None, classval=None): - """ Like @property, but returns ``classval`` when used as a class attribute - - >>> class MyClass(object): - ... '''The class docstring''' - ... @instanceproperty(classval=__doc__) - ... def __doc__(self): - ... return 'An object docstring' - ... @instanceproperty - ... def val(self): - ... return 42 - ... - >>> MyClass.__doc__ - 'The class docstring' - >>> MyClass.val is None - True - >>> obj = MyClass() - >>> obj.__doc__ - 'An object docstring' - >>> obj.val - 42 - """ - if fget is None: - return partial(instanceproperty, fset=fset, fdel=fdel, doc=doc, - classval=classval) - return InstanceProperty(fget=fget, fset=fset, fdel=fdel, doc=doc, - classval=classval) - - -class InstanceProperty(property): - """ Like @property, but returns ``classval`` when used as a class attribute - - Should not be used directly. Use ``instanceproperty`` instead. - """ - - def __init__(self, fget=None, fset=None, fdel=None, doc=None, - classval=None): - self.classval = classval - property.__init__(self, fget=fget, fset=fset, fdel=fdel, doc=doc) - - def __get__(self, obj, type=None): - if obj is None: - return self.classval - return property.__get__(self, obj, type) - - def __reduce__(self): - state = (self.fget, self.fset, self.fdel, self.__doc__, self.classval) - return InstanceProperty, state - - -class Compose(object): - """ A composition of functions - - See Also: - compose - """ - __slots__ = 'first', 'funcs' - - def __init__(self, funcs): - funcs = tuple(reversed(funcs)) - self.first = funcs[0] - self.funcs = funcs[1:] - - def __call__(self, *args, **kwargs): - ret = self.first(*args, **kwargs) - for f in self.funcs: - ret = f(ret) - return ret - - def __getstate__(self): - return self.first, self.funcs - - def __setstate__(self, state): - self.first, self.funcs = state - - @instanceproperty(classval=__doc__) - def __doc__(self): - def composed_doc(*fs): - """Generate a docstring for the composition of fs. - """ - if not fs: - # Argument name for the docstring. - return '*args, **kwargs' - - return '{f}({g})'.format(f=fs[0].__name__, g=composed_doc(*fs[1:])) - - try: - return ( - 'lambda *args, **kwargs: ' + - composed_doc(*reversed((self.first,) + self.funcs)) - ) - except AttributeError: - # One of our callables does not have a `__name__`, whatever. - return 'A composition of functions' - - @property - def __name__(self): - try: - return '_of_'.join( - (f.__name__ for f in reversed((self.first,) + self.funcs)) - ) - except AttributeError: - return type(self).__name__ - - def __repr__(self): - return '{.__class__.__name__}{!r}'.format( - self, tuple(reversed((self.first,) + self.funcs))) - - def __eq__(self, other): - if isinstance(other, Compose): - return other.first == self.first and other.funcs == self.funcs - return NotImplemented - - def __ne__(self, other): - equality = self.__eq__(other) - return NotImplemented if equality is NotImplemented else not equality - - def __hash__(self): - return hash(self.first) ^ hash(self.funcs) - - # Mimic the descriptor behavior of python functions. - # i.e. let Compose be called as a method when bound to a class. - # adapted from - # docs.python.org/3/howto/descriptor.html#functions-and-methods - def __get__(self, obj, objtype=None): - return self if obj is None else MethodType(self, obj) - - # introspection with Signature is only possible from py3.3+ - @instanceproperty - def __signature__(self): - base = inspect.signature(self.first) - last = inspect.signature(self.funcs[-1]) - return base.replace(return_annotation=last.return_annotation) - - __wrapped__ = instanceproperty(attrgetter('first')) - - -def compose(*funcs): - """ Compose functions to operate in series. - - Returns a function that applies other functions in sequence. - - Functions are applied from right to left so that - ``compose(f, g, h)(x, y)`` is the same as ``f(g(h(x, y)))``. - - If no arguments are provided, the identity function (f(x) = x) is returned. - - >>> inc = lambda i: i + 1 - >>> compose(str, inc)(3) - '4' - """ - if not funcs: - return identity - if len(funcs) == 1: - return funcs[0] - else: - return Compose(funcs) - - -def pipe(*funcs): - """ Pipe a value through a sequence of functions - - I.e. ``pipe(f, g, h)(data)`` is equivalent to ``h(g(f(data)))`` - - We think of the value as progressing through a pipe of several - transformations, much like pipes in UNIX - - - >>> double = lambda i: 2 * i - >>> pipe(double, str)(3) - '6' - """ - return compose(*reversed(funcs)) diff --git a/brainpy/_src/tools/tests/test_functions.py b/brainpy/_src/tools/tests/test_functions.py deleted file mode 100644 index c285e561a..000000000 --- a/brainpy/_src/tools/tests/test_functions.py +++ /dev/null @@ -1,24 +0,0 @@ - -import unittest - -import brainpy as bp -import brainpy.math as bm - - -class TestFunction(unittest.TestCase): - def test_compose(self): - f = lambda a: a + 1 - g = lambda a: a * 10 - fun1 = bp.tools.compose(f, g) - fun2 = bp.tools.pipe(g, f) - - arr = bm.random.randn(10) - r1 = fun1(arr) - r2 = fun2(arr) - groundtruth = f(g(arr)) - self.assertTrue(bm.allclose(r1, r2)) - self.assertTrue(bm.allclose(r1, groundtruth)) - bm.clear_buffer_memory() - - - diff --git a/brainpy/math/compat_pytorch.py b/brainpy/math/compat_pytorch.py index 3b0c3f517..e4570f6fd 100644 --- a/brainpy/math/compat_pytorch.py +++ b/brainpy/math/compat_pytorch.py @@ -12,7 +12,7 @@ arccos as arccos, acosh as acosh, arccosh as arccosh, - # add as add, + add as add, addcdiv as addcdiv, addcmul as addcmul, angle as angle, diff --git a/brainpy/math/oo_transform.py b/brainpy/math/oo_transform.py index 7654731d8..548a987d0 100644 --- a/brainpy/math/oo_transform.py +++ b/brainpy/math/oo_transform.py @@ -59,7 +59,3 @@ eval_shape as eval_shape, ) -from brainpy._src.math.object_transform.variables import ( - VariableStack as VariableStack, -) - diff --git a/brainpy/tools.py b/brainpy/tools.py index 233269dc5..0f3a4c0ef 100644 --- a/brainpy/tools.py +++ b/brainpy/tools.py @@ -45,9 +45,4 @@ ) -from brainpy._src.tools.functions import ( - compose as compose, - pipe as pipe, -) - diff --git a/docs/advanced_tutorials.rst b/docs/advanced_tutorials.rst index 0b78315ab..5c8cba0fd 100644 --- a/docs/advanced_tutorials.rst +++ b/docs/advanced_tutorials.rst @@ -3,52 +3,13 @@ Advanced Tutorials This section contains tutorials that illustrate more advanced features of BrainPy. -Advanced Math -------------- .. toctree:: - :maxdepth: 1 - - tutorial_advanced/compilation.ipynb - tutorial_advanced/differentiation.ipynb - - -Interoperation --------------- - -.. toctree:: - :maxdepth: 1 - - tutorial_advanced/integrate_flax_into_brainpy.ipynb - tutorial_advanced/integrate_bp_lif_into_flax.ipynb - tutorial_advanced/integrate_bp_convlstm_into_flax.ipynb - - -Brain Dynamics Dedicated Operators ----------------------------------- - -.. toctree:: - :maxdepth: 1 - - tutorial_advanced/operator_custom_with_numba.ipynb - tutorial_advanced/operator_custom_with_taichi.ipynb - - -Developer Guides ----------------- - -.. toctree:: - :maxdepth: 1 - - tutorial_advanced/contributing.md - - -Others ------- - -.. toctree:: - :maxdepth: 1 - - tutorial_advanced/advanced_lowdim_analysis.ipynb + :maxdepth: 2 + tutorial_advanced/1_advanced_math.rst + tutorial_advanced/2_interoperation.rst + tutorial_advanced/3_dedicated_operators.rst + tutorial_advanced/4_developer_guides.rst + tutorial_advanced/5_others.rst diff --git a/docs/apis/brainpy.math.oo_transform.rst b/docs/apis/brainpy.math.oo_transform.rst index 9ed9cf46a..754e0d81d 100644 --- a/docs/apis/brainpy.math.oo_transform.rst +++ b/docs/apis/brainpy.math.oo_transform.rst @@ -77,5 +77,4 @@ Helpers for Object-oriented Transformations :template: classtemplate.rst eval_shape - VariableStack diff --git a/docs/toolboxes.rst b/docs/toolboxes.rst index cc3a38575..11bf53115 100644 --- a/docs/toolboxes.rst +++ b/docs/toolboxes.rst @@ -1,16 +1,7 @@ BDP Toolboxes ================== - - - This section contains detailed toolboxes BrainPy uses for brain dynamics modeling. - - -Differential Equations ------------------------ - - .. toctree:: :maxdepth: 1 @@ -19,34 +10,11 @@ Differential Equations tutorial_toolbox/fde_numerical_solvers tutorial_toolbox/dde_numerical_solvers tutorial_toolbox/joint_equations - - -Toolbox for Modeling -------------------- - -.. toctree:: - :maxdepth: 1 - tutorial_toolbox/synaptic_connections tutorial_toolbox/synaptic_weights - tutorial_toolbox/inputs - - -Toolbox for Training --------------------- - -.. toctree:: - :maxdepth: 1 - tutorial_toolbox/optimizers + tutorial_toolbox/state_saving_and_loading.ipynb + tutorial_toolbox/state_resetting.ipynb tutorial_toolbox/surrogate_gradient + tutorial_toolbox/inputs - -State Resetting, Saving and Loading ------------------------------------ - -.. toctree:: - :maxdepth: 1 - - tutorial_toolbox/state_saving_and_loading.ipynb - tutorial_toolbox/state_resetting.ipynb \ No newline at end of file diff --git a/docs/tutorials.rst b/docs/tutorials.rst index 57d18332b..7c9a1c876 100644 --- a/docs/tutorials.rst +++ b/docs/tutorials.rst @@ -3,76 +3,11 @@ BDP Tutorials This section contains tutorials on how to use BrainPy to accomplish model building, simulation, training, and analysis. - -Math Foundation ---------------- - -.. toctree:: - :maxdepth: 1 - - tutorial_math/variables - tutorial_math/control_flows - tutorial_math/Numpy_like_Operations.ipynb - tutorial_math/Dedicated_Operators.ipynb - tutorial_math/einops_in_brainpy.ipynb - - -Model Building with Existing Modules ------------------------------------- - -.. toctree:: - :maxdepth: 1 - - tutorial_building/overview_of_dynamic_model - tutorial_building/build_conductance_neurons_v2.ipynb - tutorial_building/phenon_synapse_models.ipynb - tutorial_building/kinetic_synapse_models.ipynb - tutorial_building/build_network_models - - -Model Building by Customizing New Modules ------------------------------------------ - -.. toctree:: - :maxdepth: 1 - - tutorial_building/customize_neuron_models - tutorial_building/customize_synapse_models - tutorial_building/how_to_customze_a_synapse.ipynb - - -Model Simulation ----------------- - -.. toctree:: - :maxdepth: 1 - - tutorial_simulation/simulation_dsrunner.ipynb - tutorial_simulation/parallel_for_parameter_exploration.ipynb - tutorial_simulation/monitor_per_multiple_steps.ipynb - - -Model Training --------------- - -This tutorial shows how to train a dynamical system from data or task. - -.. toctree:: - :maxdepth: 1 - - tutorial_training/build_training_models.ipynb - tutorial_training/offline_training.ipynb - tutorial_training/online_training.ipynb - tutorial_training/bp_training.ipynb - tutorial_training/esn_introduction.ipynb - - -Model Analysis --------------- - .. toctree:: - :maxdepth: 1 + :maxdepth: 2 - tutorial_analysis/lowdim_analysis - tutorial_analysis/highdim_analysis - tutorial_analysis/decision_making_model + tutorial_math/index + tutorial_building/index + tutorial_simulation/index + tutorial_training/index + tutorial_analysis/index diff --git a/examples/dynamics_simulation/ei_nets.py b/examples/dynamics_simulation/ei_nets.py index 9c7daff55..f98527458 100644 --- a/examples/dynamics_simulation/ei_nets.py +++ b/examples/dynamics_simulation/ei_nets.py @@ -228,7 +228,7 @@ def __init__(self): ) def update(self, input): - spk = self.delay.at('delay') + spk = self.delay.at('I') self.E(self.syn1(spk[:3200])) self.I(self.syn2(spk[3200:])) self.delay(self.N(input)) From 456731b9a672b946c4d68fc95db7e7b29c25f41a Mon Sep 17 00:00:00 2001 From: Chaoming Wang Date: Thu, 29 Feb 2024 13:40:01 +0800 Subject: [PATCH 11/16] upgrade dependency --- brainpy/_src/connect/random_conn.py | 285 ++++--------- brainpy/_src/dependency_check.py | 52 +-- brainpy/_src/dnn/linear.py | 390 +++++++++--------- .../_src/dyn/projections/tests/test_STDP.py | 11 +- brainpy/_src/math/event/_csr_matvec.py | 47 +-- brainpy/_src/math/jitconn/_event_matvec.py | 228 +--------- brainpy/_src/math/jitconn/_matvec.py | 362 ++++------------ brainpy/_src/math/op_register/base.py | 34 +- .../op_register/numba_approach/__init__.py | 14 +- .../numba_approach/cpu_translation.py | 17 +- brainpy/_src/math/op_register/numba_based.py | 11 +- brainpy/_src/math/sparse/_bsr_mm.py | 4 +- brainpy/_src/math/sparse/_csr_mv.py | 55 +-- brainpy/errors.py | 14 +- 14 files changed, 432 insertions(+), 1092 deletions(-) diff --git a/brainpy/_src/connect/random_conn.py b/brainpy/_src/connect/random_conn.py index 9438c3306..0e4ee769c 100644 --- a/brainpy/_src/connect/random_conn.py +++ b/brainpy/_src/connect/random_conn.py @@ -1,4 +1,5 @@ # -*- coding: utf-8 -*- + from functools import partial from typing import Optional @@ -9,10 +10,8 @@ from brainpy.errors import ConnectorError from brainpy.tools import numba_seed, numba_jit, numba_range, format_seed from brainpy._src.tools.package import SUPPORT_NUMBA -from brainpy._src.dependency_check import import_numba from .base import * -numba = import_numba(error_if_not_found=False) __all__ = [ 'FixedProb', @@ -1099,171 +1098,43 @@ def __init__(self, dist=1, prob=1., pre_ratio=1., seed=None, include_self=True, rng = np.random if SUPPORT_NUMBA else self.rng - # @njit(parallel=True) - # def _connect_1d_jit_parallel(pre_pos, pre_size, post_size, n_dim): - # all_post_ids = np.zeros(post_size[0], dtype=get_idx_type()) - # all_pre_ids = np.zeros(post_size[0], dtype=get_idx_type()) - # size = 0 - # - # if rng.random() < pre_ratio: - # normalized_pos = np.zeros(n_dim) - # for i in prange(n_dim): # Use prange for potential parallelism - # pre_len = pre_size[i] - # post_len = post_size[i] - # normalized_pos[i] = pre_pos[i] * post_len / pre_len - # for i in prange(post_size[0]): - # post_pos = np.asarray((i,)) - # d = np.abs(pre_pos[0] - post_pos[0]) # Adjust the distance calculation - # if d <= dist: - # if d == 0. and not include_self: - # continue - # if rng.random() <= prob: - # all_post_ids[size] = pos2ind(post_pos, post_size) - # all_pre_ids[size] = pos2ind(pre_pos, pre_size) - # size += 1 - # return all_pre_ids[:size], all_post_ids[:size] # Return filled part of the arrays - - if numba is not None: - from numba import njit - @njit - def _connect_1d_jit(pre_pos, pre_size, post_size, n_dim): - all_post_ids = np.zeros(post_size[0], dtype=IDX_DTYPE) - all_pre_ids = np.zeros(post_size[0], dtype=IDX_DTYPE) - size = 0 - - if rng.random() < pre_ratio: - normalized_pos = np.zeros(n_dim) - for i in range(n_dim): - pre_len = pre_size[i] - post_len = post_size[i] - normalized_pos[i] = pre_pos[i] * post_len / pre_len - for i in range(post_size[0]): - post_pos = np.asarray((i,)) - d = np.abs(pre_pos[0] - post_pos[0]) - if d <= dist: - if d == 0. and not include_self: - continue - if rng.random() <= prob: - all_post_ids[size] = pos2ind(post_pos, post_size) - all_pre_ids[size] = pos2ind(pre_pos, pre_size) - size += 1 - return all_pre_ids[:size], all_post_ids[:size] - - @njit - def _connect_2d_jit(pre_pos, pre_size, post_size, n_dim): - max_size = post_size[0] * post_size[1] - all_post_ids = np.zeros(max_size, dtype=IDX_DTYPE) - all_pre_ids = np.zeros(max_size, dtype=IDX_DTYPE) - size = 0 - - if rng.random() < pre_ratio: - normalized_pos = np.zeros(n_dim) - for i in range(n_dim): - pre_len = pre_size[i] - post_len = post_size[i] - normalized_pos[i] = pre_pos[i] * post_len / pre_len - for i in range(post_size[0]): - for j in range(post_size[1]): - post_pos = np.asarray((i, j)) - d = np.sqrt(np.sum(np.square(pre_pos - post_pos))) - if d <= dist: - if d == 0. and not include_self: - continue - if rng.random() <= prob: - all_post_ids[size] = pos2ind(post_pos, post_size) - all_pre_ids[size] = pos2ind(pre_pos, pre_size) - size += 1 - return all_pre_ids[:size], all_post_ids[:size] # Return filled part of the arrays - - @njit - def _connect_3d_jit(pre_pos, pre_size, post_size, n_dim): - max_size = post_size[0] * post_size[1] * post_size[2] - all_post_ids = np.zeros(max_size, dtype=IDX_DTYPE) - all_pre_ids = np.zeros(max_size, dtype=IDX_DTYPE) - size = 0 - - if rng.random() < pre_ratio: - normalized_pos = np.zeros(n_dim) - for i in range(n_dim): - pre_len = pre_size[i] - post_len = post_size[i] - normalized_pos[i] = pre_pos[i] * post_len / pre_len - for i in range(post_size[0]): - for j in range(post_size[1]): - for k in range(post_size[2]): - post_pos = np.asarray((i, j, k)) - d = np.sqrt(np.sum(np.square(pre_pos - post_pos))) - if d <= dist: - if d == 0. and not include_self: - continue - if rng.random() <= prob: - all_post_ids[size] = pos2ind(post_pos, post_size) - all_pre_ids[size] = pos2ind(pre_pos, pre_size) - size += 1 - return all_pre_ids[:size], all_post_ids[:size] - - @njit - def _connect_4d_jit(pre_pos, pre_size, post_size, n_dim): - max_size = post_size[0] * post_size[1] * post_size[2] * post_size[3] - all_post_ids = np.zeros(max_size, dtype=IDX_DTYPE) - all_pre_ids = np.zeros(max_size, dtype=IDX_DTYPE) - size = 0 - - if rng.random() < pre_ratio: - normalized_pos = np.zeros(n_dim) - for i in range(n_dim): - pre_len = pre_size[i] - post_len = post_size[i] - normalized_pos[i] = pre_pos[i] * post_len / pre_len - for i in range(post_size[0]): - for j in range(post_size[1]): - for k in range(post_size[2]): - for l in range(post_size[3]): - post_pos = np.asarray((i, j, k, l)) - d = np.sqrt(np.sum(np.square(pre_pos - post_pos))) - if d <= dist: - if d == 0. and not include_self: - continue - if rng.random() <= prob: - all_post_ids[size] = pos2ind(post_pos, post_size) - all_pre_ids[size] = pos2ind(pre_pos, pre_size) - size += 1 - return all_pre_ids[:size], all_post_ids[:size] - - self._connect_1d_jit = _connect_1d_jit - self._connect_2d_jit = _connect_2d_jit - self._connect_3d_jit = _connect_3d_jit - self._connect_4d_jit = _connect_4d_jit - - def _connect_1d(pre_pos, pre_size, post_size, n_dim): - all_post_ids = [] - all_pre_ids = [] + @numba_jit + def _connect_1d_jit(pre_pos, pre_size, post_size, n_dim): + all_post_ids = np.zeros(post_size[0], dtype=IDX_DTYPE) + all_pre_ids = np.zeros(post_size[0], dtype=IDX_DTYPE) + size = 0 + if rng.random() < pre_ratio: - normalized_pos = [] + normalized_pos = np.zeros(n_dim) for i in range(n_dim): pre_len = pre_size[i] post_len = post_size[i] - normalized_pos.append(pre_pos[i] * post_len / pre_len) + normalized_pos[i] = pre_pos[i] * post_len / pre_len for i in range(post_size[0]): post_pos = np.asarray((i,)) - d = np.sum(np.abs(pre_pos - post_pos)) + d = np.abs(pre_pos[0] - post_pos[0]) if d <= dist: if d == 0. and not include_self: continue if rng.random() <= prob: - all_post_ids.append(pos2ind(post_pos, post_size)) - all_pre_ids.append(pos2ind(pre_pos, pre_size)) - return all_pre_ids, all_post_ids + all_post_ids[size] = pos2ind(post_pos, post_size) + all_pre_ids[size] = pos2ind(pre_pos, pre_size) + size += 1 + return all_pre_ids[:size], all_post_ids[:size] + + @numba_jit + def _connect_2d_jit(pre_pos, pre_size, post_size, n_dim): + max_size = post_size[0] * post_size[1] + all_post_ids = np.zeros(max_size, dtype=IDX_DTYPE) + all_pre_ids = np.zeros(max_size, dtype=IDX_DTYPE) + size = 0 - def _connect_2d(pre_pos, pre_size, post_size, n_dim): - all_post_ids = [] - all_pre_ids = [] if rng.random() < pre_ratio: - normalized_pos = [] + normalized_pos = np.zeros(n_dim) for i in range(n_dim): pre_len = pre_size[i] post_len = post_size[i] - normalized_pos.append(pre_pos[i] * post_len / pre_len) + normalized_pos[i] = pre_pos[i] * post_len / pre_len for i in range(post_size[0]): for j in range(post_size[1]): post_pos = np.asarray((i, j)) @@ -1271,20 +1142,25 @@ def _connect_2d(pre_pos, pre_size, post_size, n_dim): if d <= dist: if d == 0. and not include_self: continue - if np.random.random() <= prob: - all_post_ids.append(pos2ind(post_pos, post_size)) - all_pre_ids.append(pos2ind(pre_pos, pre_size)) - return all_pre_ids, all_post_ids - - def _connect_3d(pre_pos, pre_size, post_size, n_dim): - all_post_ids = [] - all_pre_ids = [] + if rng.random() <= prob: + all_post_ids[size] = pos2ind(post_pos, post_size) + all_pre_ids[size] = pos2ind(pre_pos, pre_size) + size += 1 + return all_pre_ids[:size], all_post_ids[:size] # Return filled part of the arrays + + @numba_jit + def _connect_3d_jit(pre_pos, pre_size, post_size, n_dim): + max_size = post_size[0] * post_size[1] * post_size[2] + all_post_ids = np.zeros(max_size, dtype=IDX_DTYPE) + all_pre_ids = np.zeros(max_size, dtype=IDX_DTYPE) + size = 0 + if rng.random() < pre_ratio: - normalized_pos = [] + normalized_pos = np.zeros(n_dim) for i in range(n_dim): pre_len = pre_size[i] post_len = post_size[i] - normalized_pos.append(pre_pos[i] * post_len / pre_len) + normalized_pos[i] = pre_pos[i] * post_len / pre_len for i in range(post_size[0]): for j in range(post_size[1]): for k in range(post_size[2]): @@ -1293,20 +1169,25 @@ def _connect_3d(pre_pos, pre_size, post_size, n_dim): if d <= dist: if d == 0. and not include_self: continue - if np.random.random() <= prob: - all_post_ids.append(pos2ind(post_pos, post_size)) - all_pre_ids.append(pos2ind(pre_pos, pre_size)) - return all_pre_ids, all_post_ids - - def _connect_4d(pre_pos, pre_size, post_size, n_dim): - all_post_ids = [] - all_pre_ids = [] + if rng.random() <= prob: + all_post_ids[size] = pos2ind(post_pos, post_size) + all_pre_ids[size] = pos2ind(pre_pos, pre_size) + size += 1 + return all_pre_ids[:size], all_post_ids[:size] + + @numba_jit + def _connect_4d_jit(pre_pos, pre_size, post_size, n_dim): + max_size = post_size[0] * post_size[1] * post_size[2] * post_size[3] + all_post_ids = np.zeros(max_size, dtype=IDX_DTYPE) + all_pre_ids = np.zeros(max_size, dtype=IDX_DTYPE) + size = 0 + if rng.random() < pre_ratio: - normalized_pos = [] + normalized_pos = np.zeros(n_dim) for i in range(n_dim): pre_len = pre_size[i] post_len = post_size[i] - normalized_pos.append(pre_pos[i] * post_len / pre_len) + normalized_pos[i] = pre_pos[i] * post_len / pre_len for i in range(post_size[0]): for j in range(post_size[1]): for k in range(post_size[2]): @@ -1316,15 +1197,16 @@ def _connect_4d(pre_pos, pre_size, post_size, n_dim): if d <= dist: if d == 0. and not include_self: continue - if np.random.random() <= prob: - all_post_ids.append(pos2ind(post_pos, post_size)) - all_pre_ids.append(pos2ind(pre_pos, pre_size)) - return all_pre_ids, all_post_ids + if rng.random() <= prob: + all_post_ids[size] = pos2ind(post_pos, post_size) + all_pre_ids[size] = pos2ind(pre_pos, pre_size) + size += 1 + return all_pre_ids[:size], all_post_ids[:size] - self._connect_1d = numba_jit(_connect_1d) - self._connect_2d = numba_jit(_connect_2d) - self._connect_3d = numba_jit(_connect_3d) - self._connect_4d = numba_jit(_connect_4d) + self._connect_1d_jit = _connect_1d_jit + self._connect_2d_jit = _connect_2d_jit + self._connect_3d_jit = _connect_3d_jit + self._connect_4d_jit = _connect_4d_jit def build_coo(self, isOptimized=True): if len(self.pre_size) != len(self.post_size): @@ -1336,41 +1218,16 @@ def build_coo(self, isOptimized=True): # connections n_dim = len(self.pre_size) - if not isOptimized: - if n_dim == 1: - f = self._connect_1d - elif n_dim == 2: - f = self._connect_2d - elif n_dim == 3: - f = self._connect_3d - elif n_dim == 4: - f = self._connect_4d - else: - raise NotImplementedError('Does not support the network dimension bigger than 4.') + if n_dim == 1: + f = self._connect_1d_jit + elif n_dim == 2: + f = self._connect_2d_jit + elif n_dim == 3: + f = self._connect_3d_jit + elif n_dim == 4: + f = self._connect_4d_jit else: - if numba is None: - if n_dim == 1: - f = self._connect_1d - elif n_dim == 2: - f = self._connect_2d - elif n_dim == 3: - f = self._connect_3d - elif n_dim == 4: - f = self._connect_4d - else: - raise NotImplementedError('Does not support the network dimension bigger than 4.') - else: - if n_dim == 1: - f = self._connect_1d_jit - elif n_dim == 2: - f = self._connect_2d_jit - elif n_dim == 3: - f = self._connect_3d_jit - elif n_dim == 4: - f = self._connect_4d_jit - else: - raise NotImplementedError('Does not support the network dimension bigger than 4.') - + raise NotImplementedError('Does not support the network dimension bigger than 4.') pre_size = np.asarray(self.pre_size) post_size = np.asarray(self.post_size) diff --git a/brainpy/_src/dependency_check.py b/brainpy/_src/dependency_check.py index 2babb5023..2820c7081 100644 --- a/brainpy/_src/dependency_check.py +++ b/brainpy/_src/dependency_check.py @@ -1,4 +1,3 @@ -import functools import os import sys @@ -7,12 +6,8 @@ __all__ = [ 'import_taichi', 'raise_taichi_not_found', - 'check_taichi_func', - 'check_taichi_class', 'import_numba', 'raise_numba_not_found', - 'check_numba_func', - 'check_numba_class', 'import_brainpylib_cpu_ops', 'import_brainpylib_gpu_ops', ] @@ -20,8 +15,8 @@ _minimal_brainpylib_version = '0.2.6' _minimal_taichi_version = (1, 7, 0) -taichi = None numba = None +taichi = None brainpylib_cpu_ops = None brainpylib_gpu_ops = None @@ -29,8 +24,7 @@ f'Currently you can install taichi=={_minimal_taichi_version} through:\n\n' '> pip install taichi==1.7.0') numba_install_info = ('We need numba. Please install numba by pip . \n' - '> pip install numba' - ) + '> pip install numba') os.environ["TI_LOG_LEVEL"] = "error" @@ -55,30 +49,10 @@ def import_taichi(error_if_not_found=True): return taichi -def raise_taichi_not_found(): +def raise_taichi_not_found(*args, **kwargs): raise ModuleNotFoundError(taichi_install_info) -def check_taichi_func(func): - @functools.wraps(func) - def wrapper(*args, **kwargs): - if taichi is None: - raise_taichi_not_found() - return func(*args, **kwargs) - - return wrapper - - -def check_taichi_class(cls): - class Wrapper(cls): - def __init__(self, *args, **kwargs): - if taichi is None: - raise_taichi_not_found() - super().__init__(*args, **kwargs) - - return Wrapper - - def import_numba(error_if_not_found=True): global numba if numba is None: @@ -96,26 +70,6 @@ def raise_numba_not_found(): raise ModuleNotFoundError(numba_install_info) -def check_numba_func(func): - @functools.wraps(func) - def wrapper(*args, **kwargs): - if numba is None: - raise_numba_not_found() - return func(*args, **kwargs) - - return wrapper - - -def check_numba_class(cls): - class Wrapper(cls): - def __init__(self, *args, **kwargs): - if numba is None: - raise_numba_not_found() - super().__init__(*args, **kwargs) - - return Wrapper - - def is_brainpylib_gpu_installed(): return False if brainpylib_gpu_ops is None else True diff --git a/brainpy/_src/dnn/linear.py b/brainpy/_src/dnn/linear.py index 7a92bc8b2..c524fb0bf 100644 --- a/brainpy/_src/dnn/linear.py +++ b/brainpy/_src/dnn/linear.py @@ -11,17 +11,16 @@ from brainpy import math as bm from brainpy._src import connect, initialize as init from brainpy._src.context import share +from brainpy._src.dependency_check import import_taichi from brainpy._src.dnn.base import Layer from brainpy._src.mixin import SupportOnline, SupportOffline, SupportSTDP -from brainpy._src.dependency_check import import_numba, import_taichi, check_numba_func, check_taichi_func from brainpy.check import is_initializer from brainpy.connect import csr2csc -from brainpy.errors import MathError +from brainpy.errors import MathError, PackageMissingError from brainpy.initialize import XavierNormal, ZeroInit, Initializer, parameter from brainpy.types import ArrayType, Sharding ti = import_taichi(error_if_not_found=False) -numba = import_numba(error_if_not_found=False) __all__ = [ 'Dense', 'Linear', @@ -239,152 +238,108 @@ def update(self, x): return x -# @numba.njit(nogil=True, fastmath=True, parallel=False) -# def _cpu_dense_on_pre(weight, spike, trace, w_min, w_max, out_w): -# out_w[:] = weight -# for i in numba.prange(spike.shape[0]): -# if spike[i]: -# out_w[i] = np.clip(out_w[i] + trace, w_min, w_max) - -dense_on_pre_prim = None if ti is not None: + + # @numba.njit(nogil=True, fastmath=True, parallel=False) + # def _cpu_dense_on_post(weight, spike, trace, w_min, w_max, out_w): + # out_w[:] = weight + # for i in numba.prange(spike.shape[0]): + # if spike[i]: + # out_w[:, i] = np.clip(out_w[:, i] + trace, w_min, w_max) + @ti.kernel - def _cpu_dense_on_pre(weight: ti.types.ndarray(ndim=2), - spike: ti.types.ndarray(ndim=1), - trace: ti.types.ndarray(ndim=1), - w_min: ti.types.ndarray(ndim=1), - w_max: ti.types.ndarray(ndim=1), - out_w: ti.types.ndarray(ndim=2)): - trace0 = trace[0] + def _dense_on_post( + old_w: ti.types.ndarray(ndim=2), + post_spike: ti.types.ndarray(ndim=1), + pre_trace: ti.types.ndarray(ndim=1), + w_min: ti.types.ndarray(ndim=1), + w_max: ti.types.ndarray(ndim=1), + out_w: ti.types.ndarray(ndim=2) + ): w_min0 = w_min[0] w_max0 = w_max[0] - for i, j in ti.ndrange(out_w.shape[0], out_w.shape[1]): - out_w[i, j] = weight[i, j] - for i in range(spike.shape[0]): - if spike[i]: - for j in range(out_w.shape[1]): - new_value = out_w[i, j] + trace0 - if new_value < w_min0: - out_w[i, j] = w_min0 - elif new_value > w_max0: - out_w[i, j] = w_max0 - else: - out_w[i, j] = new_value + num_pre, num_post = out_w.shape + + for i, j in ti.ndrange(num_pre, num_post): + if post_spike[j]: + new_value = out_w[i, j] + pre_trace[i] + if new_value < w_min0: + out_w[i, j] = w_min0 + elif new_value > w_max0: + out_w[i, j] = w_max0 + else: + out_w[i, j] = new_value + else: + out_w[i, j] = old_w[i, j] + + + dense_on_post_prim = bm.XLACustomOp(cpu_kernel=_dense_on_post, gpu_kernel=_dense_on_post) + + # @numba.njit(nogil=True, fastmath=True, parallel=False) + # def _cpu_dense_on_pre(weight, spike, trace, w_min, w_max, out_w): + # out_w[:] = weight + # for i in numba.prange(spike.shape[0]): + # if spike[i]: + # out_w[i] = np.clip(out_w[i] + trace, w_min, w_max) @ti.kernel - def _gpu_dense_on_pre(weight: ti.types.ndarray(ndim=1), - spike: ti.types.ndarray(ndim=1), - trace: ti.types.ndarray(ndim=1), - w_min: ti.types.ndarray(ndim=1), - w_max: ti.types.ndarray(ndim=1), - out_w: ti.types.ndarray(ndim=1)): - trace0 = trace[0] + def _dense_on_pre( + old_w: ti.types.ndarray(ndim=2), + pre_spike: ti.types.ndarray(ndim=1), + post_trace: ti.types.ndarray(ndim=1), + w_min: ti.types.ndarray(ndim=1), + w_max: ti.types.ndarray(ndim=1), + out_w: ti.types.ndarray(ndim=2) + ): w_min0 = w_min[0] w_max0 = w_max[0] - for i, j in ti.ndrange(out_w.shape[0], out_w.shape[1]): - out_w[i, j] = weight[i, j] - for i in range(spike.shape[0]): - if spike[i]: - for j in range(out_w.shape[1]): - new_value = out_w[i, j] + trace0 - if new_value < w_min0: - out_w[i, j] = w_min0 - elif new_value > w_max0: - out_w[i, j] = w_max0 - else: - out_w[i, j] = new_value + num_pre, num_post = out_w.shape + + for i, j in ti.ndrange(num_pre, num_post): + if pre_spike[i]: + new_value = out_w[i, j] + post_trace[j] + if new_value < w_min0: + out_w[i, j] = w_min0 + elif new_value > w_max0: + out_w[i, j] = w_max0 + else: + out_w[i, j] = new_value + else: + out_w[i, j] = old_w[i, j] - dense_on_pre_prim = bm.XLACustomOp(cpu_kernel=_cpu_dense_on_pre, - gpu_kernel=_gpu_dense_on_pre) + dense_on_pre_prim = bm.XLACustomOp(cpu_kernel=_dense_on_pre, gpu_kernel=_dense_on_pre) + +else: + dense_on_pre_prim = None + dense_on_post_prim = None -@check_taichi_func def dense_on_pre(weight, spike, trace, w_min, w_max): + if dense_on_pre_prim is None: + raise PackageMissingError.by_purpose('taichi', 'custom operators') + if w_min is None: w_min = -np.inf if w_max is None: w_max = np.inf - trace = jnp.atleast_1d(trace) w_min = jnp.atleast_1d(w_min) w_max = jnp.atleast_1d(w_max) return dense_on_pre_prim(weight, spike, trace, w_min, w_max, outs=[jax.ShapeDtypeStruct(weight.shape, weight.dtype)])[0] -# @numba.njit(nogil=True, fastmath=True, parallel=False) -# def _cpu_dense_on_post(weight, spike, trace, w_min, w_max, out_w): -# out_w[:] = weight -# for i in numba.prange(spike.shape[0]): -# if spike[i]: -# out_w[:, i] = np.clip(out_w[:, i] + trace, w_min, w_max) - -dense_on_post_prim = None -if ti is not None: - @ti.kernel - def _cpu_dense_on_post(weight: ti.types.ndarray(ndim=2), - spike: ti.types.ndarray(ndim=1), - trace: ti.types.ndarray(ndim=1), - w_min: ti.types.ndarray(ndim=1), - w_max: ti.types.ndarray(ndim=1), - out_w: ti.types.ndarray(ndim=2)): - trace0 = trace[0] - w_min0 = w_min[0] - w_max0 = w_max[0] - for i, j in ti.ndrange(out_w.shape[0], out_w.shape[1]): - out_w[i, j] = weight[i, j] - for i in range(spike.shape[0]): - if spike[i]: - for j in range(out_w.shape[0]): - new_value = out_w[j, i] + trace0 - if new_value < w_min0: - out_w[j, i] = w_min0 - elif new_value > w_max0: - out_w[j, i] = w_max0 - else: - out_w[j, i] = new_value - - - @ti.kernel - def _gpu_dense_on_post(weight: ti.types.ndarray(ndim=2), - spike: ti.types.ndarray(ndim=1), - trace: ti.types.ndarray(ndim=1), - w_min: ti.types.ndarray(ndim=1), - w_max: ti.types.ndarray(ndim=1), - out_w: ti.types.ndarray(ndim=2)): - trace0 = trace[0] - w_min0 = w_min[0] - w_max0 = w_max[0] - for i, j in ti.ndrange(out_w.shape[0], out_w.shape[1]): - out_w[i, j] = weight[i, j] - for i in range(spike.shape[0]): - if spike[i]: - for j in range(out_w.shape[0]): - new_value = out_w[j, i] + trace0 - if new_value < w_min0: - out_w[j, i] = w_min0 - elif new_value > w_max0: - out_w[j, i] = w_max0 - else: - out_w[j, i] = new_value - - - dense_on_post_prim = bm.XLACustomOp(cpu_kernel=_cpu_dense_on_post, - gpu_kernel=_gpu_dense_on_post) - - -@check_taichi_func def dense_on_post(weight, spike, trace, w_min, w_max): + if dense_on_post_prim is None: + raise PackageMissingError.by_purpose('taichi', 'custom operators') + if w_min is None: w_min = -np.inf if w_max is None: w_max = np.inf - trace = jnp.atleast_1d(trace) w_min = jnp.atleast_1d(w_min) w_max = jnp.atleast_1d(w_max) - if dense_on_post_prim is None: - import_taichi() return dense_on_post_prim(weight, spike, trace, w_min, w_max, outs=[jax.ShapeDtypeStruct(weight.shape, weight.dtype)])[0] @@ -756,107 +711,168 @@ def _batch_csrmv(self, x): transpose=self.transpose) -# @numba.njit(nogil=True, fastmath=True, parallel=False) -# def _cpu_csr_on_pre_update(w, indices, indptr, spike, trace, w_min, w_max, out_w): -# out_w[:] = w -# w_min = w_min[()] -# w_max = w_max[()] -# for i in numba.prange(spike.shape[0]): # pre id -# if spike[i]: -# for k in range(indptr[i], indptr[i + 1]): # synapse id -# j = indices[k] # post id -# # out_w[k] = np.clip(out_w[k] + trace[j], w_min, w_max) -# out_w[k] = np.minimum(np.maximum(out_w[k] + trace[j], w_min), w_max) - -csr_on_pre_update_prim = None if ti is not None: @ti.kernel - def _cpu_csr_on_pre_update(w: ti.types.ndarray(ndim=1), - indices: ti.types.ndarray(ndim=1), - indptr: ti.types.ndarray(ndim=1), - spike: ti.types.ndarray(ndim=1), - trace: ti.types.ndarray(ndim=1), - w_min: ti.types.ndarray(ndim=1), - w_max: ti.types.ndarray(ndim=1), - out_w: ti.types.ndarray(ndim=1)): - trace0 = trace[0] + def _csr_on_pre_update( + old_w: ti.types.ndarray(ndim=1), # vector with shape of (num_syn) + indices: ti.types.ndarray(ndim=1), # vector with shape of (num_syn) + indptr: ti.types.ndarray(ndim=1), # vector with shape of (num_pre + 1) + spike: ti.types.ndarray(ndim=1), # vector with shape of (num_pre,) + trace: ti.types.ndarray(ndim=1), # vector with shape of (num_post,) + w_min: ti.types.ndarray(ndim=1), # scalar + w_max: ti.types.ndarray(ndim=1), # scalar + out_w: ti.types.ndarray(ndim=1) # vector with shape of (num_syn) + ): + w_min0 = w_min[0] + w_max0 = w_max[0] + num_pre = spike.shape[0] + for i_pre in range(num_pre): + if spike[i_pre]: + for i_syn in range(indptr[i_pre], indptr[i_pre + 1]): + out_w[i_syn] = min(max(old_w[i_syn] + trace[indices[i_syn]], w_min0), w_max0) + else: + for i_syn in range(indptr[i_pre], indptr[i_pre + 1]): + out_w[i_syn] = old_w[i_syn] + + + csr_on_pre_update_prim = bm.XLACustomOp(cpu_kernel=_csr_on_pre_update, gpu_kernel=_csr_on_pre_update) + + + @ti.kernel + def _coo_on_pre_update( + old_w: ti.types.ndarray(ndim=1), # vector with shape of (num_syn) + pre_ids: ti.types.ndarray(ndim=1), # vector with shape of (num_syn) + post_ids: ti.types.ndarray(ndim=1), # vector with shape of (num_syn) + pre_spike: ti.types.ndarray(ndim=1), # vector with shape of (num_pre,) + post_trace: ti.types.ndarray(ndim=1), # vector with shape of (num_post,) + w_min: ti.types.ndarray(ndim=1), # scalar + w_max: ti.types.ndarray(ndim=1), # scalar + out_w: ti.types.ndarray(ndim=1) # vector with shape of (num_syn) + ): + w_min0 = w_min[0] + w_max0 = w_max[0] + num_syn = old_w.shape[0] + for i_syn in range(num_syn): + if pre_spike[pre_ids[i_syn]]: # pre spike + out_w[i_syn] = min(max(old_w[i_syn] + post_trace[post_ids[i_syn]], w_min0), w_max0) + else: + out_w[i_syn] = old_w[i_syn] + + + coo_on_pre_update_prim = bm.XLACustomOp(cpu_kernel=_coo_on_pre_update, gpu_kernel=_coo_on_pre_update) + + + @ti.kernel + def _coo_on_post_update( + old_w: ti.types.ndarray(ndim=1), # vector with shape of (num_syn) + pre_ids: ti.types.ndarray(ndim=1), # vector with shape of (num_syn) + post_ids: ti.types.ndarray(ndim=1), # vector with shape of (num_syn) + post_spike: ti.types.ndarray(ndim=1), # vector with shape of (num_pre,) + pre_trace: ti.types.ndarray(ndim=1), # vector with shape of (num_post,) + w_min: ti.types.ndarray(ndim=1), # scalar + w_max: ti.types.ndarray(ndim=1), # scalar + out_w: ti.types.ndarray(ndim=1) # vector with shape of (num_syn) + ): w_min0 = w_min[0] w_max0 = w_max[0] - for i in range(out_w.shape[0]): - out_w[i] = w[i] - for i in range(spike.shape[0]): - if spike[i]: - for k in range(indptr[i], indptr[i + 1]): - j = indices[k] - out_w[k] = min(max(out_w[k] + trace[j], w_min0), w_max0) + num_syn = old_w.shape[0] + for i_syn in range(num_syn): + if post_spike[post_ids[i_syn]]: # pre spike + out_w[i_syn] = min(max(old_w[i_syn] + pre_trace[pre_ids[i_syn]], w_min0), w_max0) + else: + out_w[i_syn] = old_w[i_syn] + + + coo_on_post_update_prim = bm.XLACustomOp(cpu_kernel=_coo_on_post_update, gpu_kernel=_coo_on_post_update) + # @numba.njit(nogil=True, fastmath=True, parallel=False) + # def _cpu_csc_on_pre_update(w, post_ids, indptr, w_ids, spike, trace, w_min, w_max, out_w): + # out_w[:] = w + # w_min = w_min[()] + # w_max = w_max[()] + # for i in numba.prange(spike.shape[0]): # post id + # if spike[i]: + # for k in range(indptr[i], indptr[i + 1]): + # j = post_ids[k] # pre id + # l = w_ids[k] # syn id + # out_w[l] = np.minimum(np.maximum(out_w[l] + trace[j], w_min), w_max) + @ti.kernel - def _gpu_csr_on_pre_update(w: ti.types.ndarray(ndim=1), - indices: ti.types.ndarray(ndim=1), - indptr: ti.types.ndarray(ndim=1), - spike: ti.types.ndarray(ndim=1), - trace: ti.types.ndarray(ndim=1), - w_min: ti.types.ndarray(ndim=1), - w_max: ti.types.ndarray(ndim=1), - out_w: ti.types.ndarray(ndim=1)): - trace0 = trace[0] + def _csc_on_post_update( + old_w: ti.types.ndarray(ndim=1), # vector with shape of (num_syn) + indices: ti.types.ndarray(ndim=1), # vector with shape of (num_syn) + indptr: ti.types.ndarray(ndim=1), # vector with shape of (num_post + 1) + w_ids: ti.types.ndarray(ndim=1), # vector with shape of (num_syn) + post_spike: ti.types.ndarray(ndim=1), # vector with shape of (num_post,) + pre_trace: ti.types.ndarray(ndim=1), # vector with shape of (num_pre,) + w_min: ti.types.ndarray(ndim=1), # scalar + w_max: ti.types.ndarray(ndim=1), # scalar + out_w: ti.types.ndarray(ndim=1), # vector with shape of (num_syn) + ): w_min0 = w_min[0] w_max0 = w_max[0] - for i in range(out_w.shape[0]): - out_w[i] = w[i] - for i in range(spike.shape[0]): - if spike[i]: - for k in range(indptr[i], indptr[i + 1]): - j = indices[k] - out_w[k] = min(max(out_w[k] + trace[j], w_min0), w_max0) + num_post = post_spike.shape[0] + for i_post in range(num_post): + if post_spike[i_post]: + for k in range(indptr[i_post], indptr[i_post + 1]): + i_syn = w_ids[k] # syn id + out_w[i_syn] = min(max(old_w[i_syn] + pre_trace[indices[k]], w_min0), w_max0) + else: + for k in range(indptr[i_post], indptr[i_post + 1]): + i_syn = w_ids[k] # syn id + out_w[i_syn] = old_w[i_syn] + + + csc_on_post_update_prim = bm.XLACustomOp(cpu_kernel=_csc_on_post_update, gpu_kernel=_csc_on_post_update) - csr_on_pre_update_prim = bm.XLACustomOp(cpu_kernel=_cpu_csr_on_pre_update, - gpu_kernel=_gpu_csr_on_pre_update) +else: + csr_on_pre_update_prim = None + coo_on_pre_update_prim = None + csc_on_post_update_prim = None -@check_taichi_func def csr_on_pre_update(w, indices, indptr, spike, trace, w_min=None, w_max=None): + if csr_on_pre_update_prim is None: + raise PackageMissingError.by_purpose('taichi', 'customized operators') + if w_min is None: w_min = -np.inf if w_max is None: w_max = np.inf - trace = jnp.atleast_1d(trace) w_min = jnp.atleast_1d(w_min) w_max = jnp.atleast_1d(w_max) - if csr_on_pre_update_prim is None: - import_taichi() return csr_on_pre_update_prim(w, indices, indptr, spike, trace, w_min, w_max, outs=[jax.ShapeDtypeStruct(w.shape, w.dtype)])[0] -csc_on_pre_update_prim = None -if numba is not None: - @numba.njit(nogil=True, fastmath=True, parallel=False) - def _cpu_csc_on_pre_update(w, post_ids, indptr, w_ids, spike, trace, w_min, w_max, out_w): - out_w[:] = w - w_min = w_min[()] - w_max = w_max[()] - for i in numba.prange(spike.shape[0]): # post id - if spike[i]: - for k in range(indptr[i], indptr[i + 1]): - j = post_ids[k] # pre id - l = w_ids[k] # syn id - out_w[l] = np.minimum(np.maximum(out_w[l] + trace[j], w_min), w_max) +def coo_on_pre_update(w, pre_ids, post_ids, spike, trace, w_min=None, w_max=None): + if coo_on_pre_update_prim is None: + raise PackageMissingError.by_purpose('taichi', 'customized operators') + if w_min is None: + w_min = -np.inf + if w_max is None: + w_max = np.inf + w_min = jnp.atleast_1d(w_min) + w_max = jnp.atleast_1d(w_max) + return coo_on_pre_update_prim(w, pre_ids, post_ids, spike, trace, w_min, w_max, + outs=[jax.ShapeDtypeStruct(w.shape, w.dtype)])[0] - csc_on_pre_update_prim = bm.XLACustomOp(_cpu_csc_on_pre_update) +def csc_on_post_update(w, post_ids, indptr, w_ids, post_spike, pre_trace, w_min=None, w_max=None): + if csc_on_post_update_prim is None: + raise PackageMissingError.by_purpose('taichi', 'customized operators') -@check_numba_func -def csc_on_post_update(w, post_ids, indptr, w_ids, spike, trace, w_min=None, w_max=None): if w_min is None: w_min = -np.inf if w_max is None: w_max = np.inf - return csc_on_pre_update_prim(w, post_ids, indptr, w_ids, spike, trace, w_min, w_max, - outs=[jax.ShapeDtypeStruct(w.shape, w.dtype)])[0] + w_min = jnp.atleast_1d(w_min) + w_max = jnp.atleast_1d(w_max) + return csc_on_post_update_prim(w, post_ids, indptr, w_ids, post_spike, pre_trace, w_min, w_max, + outs=[jax.ShapeDtypeStruct(w.shape, w.dtype)])[0] class CSCLinear(Layer): diff --git a/brainpy/_src/dyn/projections/tests/test_STDP.py b/brainpy/_src/dyn/projections/tests/test_STDP.py index 7ffc4e763..18d9d9dc9 100644 --- a/brainpy/_src/dyn/projections/tests/test_STDP.py +++ b/brainpy/_src/dyn/projections/tests/test_STDP.py @@ -1,12 +1,11 @@ # -*- coding: utf-8 -*- -import pytest import numpy as np +import pytest from absl.testing import parameterized import brainpy as bp import brainpy.math as bm - from brainpy._src.dependency_check import import_taichi if import_taichi(error_if_not_found=False) is None: @@ -18,7 +17,7 @@ class Test_STDP(parameterized.TestCase): @parameterized.product( - comm_method=['dense', 'csr', 'masked_linear', 'all2all', 'one2one'], + comm_method=['csr', 'dense', 'masked_linear', 'all2all', 'one2one'], delay=[None, 0., 2.], syn_model=['exp', 'dual_exp', 'ampa'], out_model=['cuba', 'coba', 'mg'] @@ -102,9 +101,11 @@ def update(self, I_pre, I_post): duration = 300. I_pre = bp.inputs.section_input([0, 30, 0, 30, 0, 30, 0, 30, 0, 30, 0, 30, 0], - [5, 15, 15, 15, 15, 15, 100, 15, 15, 15, 15, 15, duration - 255]) + [5, 15, 15, 15, 15, 15, 100, 15, 15, 15, 15, 15, + duration - 255]) I_post = bp.inputs.section_input([0, 30, 0, 30, 0, 30, 0, 30, 0, 30, 0, 30, 0], - [10, 15, 15, 15, 15, 15, 90, 15, 15, 15, 15, 15, duration - 250]) + [10, 15, 15, 15, 15, 15, 90, 15, 15, 15, 15, 15, + duration - 250]) net = STDPNet(1, 1) diff --git a/brainpy/_src/math/event/_csr_matvec.py b/brainpy/_src/math/event/_csr_matvec.py index bac809388..6b7f7da02 100644 --- a/brainpy/_src/math/event/_csr_matvec.py +++ b/brainpy/_src/math/event/_csr_matvec.py @@ -17,7 +17,7 @@ import numpy as np from jax.interpreters import ad -from brainpy._src.dependency_check import import_taichi, check_taichi_func +from brainpy._src.dependency_check import import_taichi from brainpy._src.math.interoperability import as_jax from brainpy._src.math.op_register import XLACustomOp from brainpy._src.math.sparse._csr_mv import raw_csrmv_taichi as normal_csrmv_taichi @@ -30,7 +30,6 @@ ti = import_taichi(error_if_not_found=False) -@check_taichi_func def csrmv( data: Union[float, jax.Array], indices: jax.Array, @@ -45,48 +44,6 @@ def csrmv( This function supports JAX transformations, including `jit()`, `grad()`, `vmap()` and `pmap()`. - Parameters - ---------- - data: ndarray, float - An array of shape ``(nse,)``. - indices: ndarray - An array of shape ``(nse,)``. - indptr: ndarray - An array of shape ``(shape[0] + 1,)`` and dtype ``indices.dtype``. - events: ndarray - An array of shape ``(shape[0] if transpose else shape[1],)`` - and dtype ``data.dtype``. - shape: tuple - A length-2 tuple representing the matrix shape. - transpose: bool - A boolean specifying whether to transpose the sparse matrix - before computing. - If ``transpose=True``, the operator will compute based on the - event-driven property of the ``events`` vector. - - Returns - ------- - y : Array - The array of shape ``(shape[1] if transpose else shape[0],)`` representing - the matrix vector product. - """ - return csrmv_taichi(data, indices, indptr, events, shape=shape, transpose=transpose) - - -def csrmv_taichi( - data: Union[float, jax.Array], - indices: jax.Array, - indptr: jax.Array, - events: jax.Array, - *, - shape: Tuple[int, int], - transpose: bool = False -) -> jax.Array: - """Product of a sparse CSR matrix and a dense event vector. - - This function supports JAX transformations, including `jit()`, `grad()`, - `vmap()` and `pmap()`. - Parameters ---------- data: ndarray, float @@ -164,7 +121,7 @@ def raw_csrmv_taichi( transpose: bool = False ): if ti is None: - raise PackageMissingError(name='taichi==1.7.0', purpose='customized operators') + raise PackageMissingError.by_purpose(name='taichi==1.7.0', purpose='customized operators') if transpose: if events.dtype == jnp.bool_: diff --git a/brainpy/_src/math/jitconn/_event_matvec.py b/brainpy/_src/math/jitconn/_event_matvec.py index b2de30b01..976b72b96 100644 --- a/brainpy/_src/math/jitconn/_event_matvec.py +++ b/brainpy/_src/math/jitconn/_event_matvec.py @@ -6,10 +6,11 @@ import numpy as np from jax import numpy as jnp -from brainpy._src.dependency_check import import_taichi, check_taichi_func +from brainpy._src.dependency_check import import_taichi from brainpy._src.math.interoperability import as_jax from brainpy._src.math.jitconn._matvec import (mv_prob_homo, mv_prob_uniform, + mv_prob_normal, _general_checking, raw_mv_prob_homo, raw_mv_prob_uniform, @@ -30,64 +31,8 @@ 'event_mv_prob_normal', ] -@check_taichi_func -def event_mv_prob_homo( - events: jax.Array, - weight: float, - conn_prob: float, - seed: Optional[int] = None, - *, - shape: Tuple[int, int], - transpose: bool = False, - outdim_parallel: bool = True, -) -> jax.Array: - return event_mv_prob_homo_taichi(events, weight, conn_prob, seed, - shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel) - - -event_mv_prob_homo.__doc__ = mv_prob_homo.__doc__ - -@check_taichi_func -def event_mv_prob_uniform( - events: jax.Array, - w_low: float, - w_high: float, - conn_prob: float, - seed: Optional[int] = None, - *, - shape: Tuple[int, int], - transpose: bool = False, - outdim_parallel: bool = True, -) -> jax.Array: - return event_mv_prob_uniform_taichi(events, w_low, w_high, conn_prob, seed, - shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel) - -event_mv_prob_uniform.__doc__ = mv_prob_uniform.__doc__ - -@check_taichi_func -def event_mv_prob_normal( - events: jax.Array, - w_mu: float, - w_sigma: float, - conn_prob: float, - seed: Optional[int] = None, - *, - shape: Tuple[int, int], - transpose: bool = False, - outdim_parallel: bool = True, -) -> jax.Array: - return event_mv_prob_uniform_taichi(events, w_mu, w_sigma, conn_prob, seed, - shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel) - - -def event_mv_prob_homo_taichi( +def event_mv_prob_homo( events: jax.Array, weight: float, conn_prob: float, @@ -97,56 +42,8 @@ def event_mv_prob_homo_taichi( transpose: bool = False, outdim_parallel: bool = True, ) -> jax.Array: - r"""Perform the :math:`y=M@v` operation, - where :math:`M` is just-in-time randomly generated with a scalar `weight` at each position. - - This operator support ``jit()``, ``vmap()``, ``grad()`` and ``pmap()`` etc. transformations - on CPU and GPU devices. - - .. warning:: - - This API may change in the future. - - In this operation, :math:`M` is the random matrix with a connection probability - `conn_prob`, and at each connection the value is the same scalar `weight`. - - When ``transpose=True``, we perform an operation of :math:`y=M^T@v`. - - .. note:: - - Note that the just-in-time generated :math:`M` (`transpose=False`) is - different from the generated :math:`M^T` (`transpose=True`). - - If you pursue the same :math:`M` and :math:`M^T` when performing the just-in-time - matrix generation, you should set ``outdim_parallel=True``, with the sacrifice of - the speed compared with ``outdim_parallel=False``. - - Parameters - ---------- - events: Array, ndarray - The events. - weight: float - The value of the random matrix. - conn_prob: float - The connection probability. - shape: tuple of int - The matrix shape. - seed: int - The random number generation seed. - transpose: bool - Transpose the random matrix or not. - outdim_parallel: bool - Perform the parallel random generations along the out dimension or not. - It can be used to set the just-in-time generated :math:M^T: is the same - as the just-in-time generated :math:`M` when ``transpose=True``. - - Returns - ------- - out: Array, ndarray - The output of :math:`y = M @ v`. - """ if ti is None: - raise PackageMissingError(name='taichi==1.7.0', purpose='customized operators') + raise PackageMissingError.by_purpose('taichi', purpose='customized operators') events = as_jax(events) if isinstance(weight, float): weight = as_jax(weight) @@ -163,7 +60,10 @@ def event_mv_prob_homo_taichi( outdim_parallel=outdim_parallel)[0] -def event_mv_prob_uniform_taichi( +event_mv_prob_homo.__doc__ = mv_prob_homo.__doc__ + + +def event_mv_prob_uniform( events: jax.Array, w_low: float, w_high: float, @@ -174,58 +74,8 @@ def event_mv_prob_uniform_taichi( transpose: bool = False, outdim_parallel: bool = True, ) -> jax.Array: - r"""Perform the :math:`y=M@v` operation, - where :math:`M` is just-in-time randomly generated with a uniform distribution for its value. - - This operator support ``jit()``, ``vmap()``, ``grad()`` and ``pmap()`` etc. transformations - on CPU and GPU devices. - - .. warning:: - - This API may change in the future. - - In this operation, :math:`M` is the random matrix with a connection probability - `conn_prob`, and at each connection the value is the same scalar `weight`. - - When ``transpose=True``, we perform an operation of :math:`y=M^T@v`. - - .. note:: - - Note that the just-in-time generated :math:`M` (`transpose=False`) is - different from the generated :math:`M^T` (`transpose=True`). - - If you pursue the same :math:`M` and :math:`M^T` when performing the just-in-time - matrix generation, you should set ``outdim_parallel=True``, with the sacrifice of - the speed compared with ``outdim_parallel=False``. - - Parameters - ---------- - events: Array, ndarray - The events. - w_low: float - Lower boundary of the output interval. - w_high: float - Upper boundary of the output interval. - conn_prob: float - The connection probability. - shape: tuple of int - The matrix shape. - seed: int - The random number generation seed. - transpose: bool - Transpose the random matrix or not. - outdim_parallel: bool - Perform the parallel random generations along the out dimension or not. - It can be used to set the just-in-time generated :math:M^T: is the same - as the just-in-time generated :math:`M` when ``transpose=True``. - - Returns - ------- - out: Array, ndarray - The output of :math:`y = M @ v`. - """ if ti is None: - raise PackageMissingError(name='taichi==1.7.0', purpose='customized operators') + raise PackageMissingError.by_purpose('taichi', purpose='customized operators') events = as_jax(events) if isinstance(w_low, float): w_low = as_jax(w_low) @@ -242,7 +92,10 @@ def event_mv_prob_uniform_taichi( transpose=transpose, outdim_parallel=outdim_parallel)[0] -def event_mv_prob_normal_taichi( +event_mv_prob_uniform.__doc__ = mv_prob_uniform.__doc__ + + +def event_mv_prob_normal( events: jax.Array, w_mu: float, w_sigma: float, @@ -253,58 +106,8 @@ def event_mv_prob_normal_taichi( transpose: bool = False, outdim_parallel: bool = True, ) -> jax.Array: - r"""Perform the :math:`y=M@v` operation, - where :math:`M` is just-in-time randomly generated with a normal distribution for its value. - - This operator support ``jit()``, ``vmap()``, ``grad()`` and ``pmap()`` etc. transformations - on CPU and GPU devices. - - .. warning:: - - This API may change in the future. - - In this operation, :math:`M` is the random matrix with a connection probability - `conn_prob`, and at each connection the value is the same scalar `weight`. - - When ``transpose=True``, we perform an operation of :math:`y=M^T@v`. - - .. note:: - - Note that the just-in-time generated :math:`M` (`transpose=False`) is - different from the generated :math:`M^T` (`transpose=True`). - - If you pursue the same :math:`M` and :math:`M^T` when performing the just-in-time - matrix generation, you should set ``outdim_parallel=True``, with the sacrifice of - the speed compared with ``outdim_parallel=False``. - - Parameters - ---------- - events: Array, ndarray - The events. - w_mu: float - Mean (centre) of the distribution. - w_sigma: float - Standard deviation (spread or “width”) of the distribution. Must be non-negative. - conn_prob: float - The connection probability. - shape: tuple of int - The matrix shape. - seed: int - The random number generation seed. - transpose: bool - Transpose the random matrix or not. - outdim_parallel: bool - Perform the parallel random generations along the out dimension or not. - It can be used to set the just-in-time generated :math:M^T: is the same - as the just-in-time generated :math:`M` when ``transpose=True``. - - Returns - ------- - out: Array, ndarray - The output of :math:`y = M @ v`. - """ if ti is None: - raise PackageMissingError(name='taichi==1.7.0', purpose='customized operators') + raise PackageMissingError.by_purpose('taichi', purpose='customized operators') events = as_jax(events) if isinstance(w_mu, float): w_mu = as_jax(w_mu) @@ -321,9 +124,12 @@ def event_mv_prob_normal_taichi( transpose=transpose, outdim_parallel=outdim_parallel)[0] +event_mv_prob_normal.__doc__ = mv_prob_normal.__doc__ + if ti is not None: from brainpy._src.math.tifunc import (lfsr88_key, lfsr88_random_integers, lfsr88_uniform, lfsr88_normal) + # ------------- # CPU function # ------------- diff --git a/brainpy/_src/math/jitconn/_matvec.py b/brainpy/_src/math/jitconn/_matvec.py index 894294c79..00e5778f9 100644 --- a/brainpy/_src/math/jitconn/_matvec.py +++ b/brainpy/_src/math/jitconn/_matvec.py @@ -8,7 +8,7 @@ from jax import numpy as jnp from jax.interpreters import ad -from brainpy._src.dependency_check import import_taichi, check_taichi_func +from brainpy._src.dependency_check import import_taichi from brainpy._src.math.interoperability import as_jax from brainpy._src.math.ndarray import Array, _get_dtype from brainpy._src.math.op_register import XLACustomOp @@ -23,48 +23,6 @@ ] -def _general_checking(vector, clen, seed, shape, outdim_parallel, transpose, *weights): - if vector.ndim != 1: - raise ValueError('vector should be a 1D vector.') - if len(shape) != 2: - raise ValueError('shape should be a length-2 tuple.') - if seed.ndim != 1: - raise ValueError('seed must be a 1D scalar.') - if clen.ndim != 1: - raise ValueError('conn_prob must be a 1D scalar.') - - assert _get_dtype(clen) in [jnp.int16, jnp.int32, jnp.int64, jnp.uint16, jnp.uint32, jnp.uint64] - assert _get_dtype(seed) in [jnp.int16, jnp.int32, jnp.int64, jnp.uint16, jnp.uint32, jnp.uint64] - - for weight in weights: - if weight.ndim != 1: - raise ValueError('weight must be a 1D scalar.') - assert _get_dtype(weight) in [jnp.float16, jnp.float32, jnp.float64], '"weight" must be float valued.' - - if not isinstance(outdim_parallel, bool): - raise ValueError('outdim_parallel must be boolean value.') - if not isinstance(transpose, bool): - raise ValueError('transpose must be boolean value.') - - if transpose: - out_shape = (shape[1],) - if vector.shape[0] != shape[0]: - raise ValueError(f'Shape mismatch, vec {vector.shape} @ mat {shape}.') - shape = _reverse(shape) - else: - if vector.shape[0] != shape[1]: - raise ValueError(f'Shape mismatch, mat {shape} @ vec ({vector.shape[0]},).') - out_shape = (shape[0],) - - return shape, out_shape - - -def _non_event_checking(vector, clen, seed, shape, outdim_parallel, transpose, *weights): - assert _get_dtype(vector) in [jnp.float16, jnp.float32, jnp.float64] - return _general_checking(vector, clen, seed, shape, outdim_parallel, transpose, *weights) - - -@check_taichi_func def mv_prob_homo( vector: Union[Array, jax.Array], weight: float, @@ -123,11 +81,24 @@ def mv_prob_homo( out: Array, ndarray The output of :math:`y = M @ v`. """ - return mv_prob_homo_taichi(vector, weight, conn_prob, seed, shape=shape, transpose=transpose, - outdim_parallel=outdim_parallel) + if ti is None: + raise PackageMissingError.by_purpose('taichi', purpose='customized operators') + + vector = as_jax(vector) + if isinstance(weight, float): + weight = as_jax(weight, dtype=vector.dtype) + weight = jnp.atleast_1d(as_jax(weight)) + conn_len = jnp.ceil(1 / conn_prob) * 2 - 1 + clen = jnp.asarray(jnp.atleast_1d(conn_len), dtype=jnp.int32) + if seed is None: + with jax.ensure_compile_time_eval(): + seed = np.random.randint(0, int(1e8), 1) + seed = jnp.asarray(seed, dtype=jnp.uint32) + seed = jnp.atleast_1d(seed) + return raw_mv_prob_homo(vector, weight, clen, seed, shape=shape, + transpose=transpose, outdim_parallel=outdim_parallel)[0] -@check_taichi_func def mv_prob_uniform( vector: jax.Array, w_low: float, @@ -189,11 +160,24 @@ def mv_prob_uniform( out: Array, ndarray The output of :math:`y = M @ v`. """ - return mv_prob_uniform_taichi(vector, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, - outdim_parallel=outdim_parallel) + if ti is None: + raise PackageMissingError.by_purpose('taichi', purpose='customized operators') + + vector = as_jax(vector) + if isinstance(w_low, float): w_low = as_jax(w_low, dtype=vector.dtype) + if isinstance(w_high, float): w_high = as_jax(w_high, dtype=vector.dtype) + w_low = jnp.atleast_1d(as_jax(w_low)) + w_high = jnp.atleast_1d(as_jax(w_high)) + conn_len = jnp.ceil(1 / conn_prob) * 2 - 1 + conn_len = jnp.asarray(jnp.atleast_1d(conn_len), dtype=jnp.int32) + if seed is None: + with jax.ensure_compile_time_eval(): + seed = np.random.randint(0, int(1e8), 1) + seed = jnp.atleast_1d(jnp.asarray(seed, dtype=jnp.uint32)) + return raw_mv_prob_uniform(vector, w_low, w_high, conn_len, seed, shape=shape, + transpose=transpose, outdim_parallel=outdim_parallel)[0] -@check_taichi_func def mv_prob_normal( vector: jax.Array, w_mu: float, @@ -255,235 +239,8 @@ def mv_prob_normal( out: Array, ndarray The output of :math:`y = M @ v`. """ - return mv_prob_uniform_taichi(vector, w_mu, w_sigma, conn_prob, seed, - shape=shape, transpose=transpose, - outdim_parallel=outdim_parallel) - - -def mv_prob_homo_taichi( - vector: Union[Array, jax.Array], - weight: float, - conn_prob: float, - seed: Optional[int] = None, - *, - shape: Tuple[int, int], - transpose: bool = False, - outdim_parallel: bool = True, -) -> jax.Array: - r"""Perform the :math:`y=M@v` operation, - where :math:`M` is just-in-time randomly generated with a scalar `weight` at each position. - - This operator support ``jit()``, ``vmap()``, ``grad()`` and ``pmap()`` etc. transformations - on CPU and GPU devices. - - .. warning:: - - This API may change in the future. - - In this operation, :math:`M` is the random matrix with a connection probability - `conn_prob`, and at each connection the value is the same scalar `weight`. - - When ``transpose=True``, we perform an operation of :math:`y=M^T@v`. - - .. note:: - - Note that the just-in-time generated :math:`M` (`transpose=False`) is - different from the generated :math:`M^T` (`transpose=True`). - - If you pursue the same :math:`M` and :math:`M^T` when performing the just-in-time - matrix generation, you should set ``outdim_parallel=True``, with the sacrifice of - the speed compared with ``outdim_parallel=False``. - - Generally, the :math:`M` in ``f(outdim_parallel=True, transpose=False)`` is the same of - the :math:`M^T` used in ``f(outdim_parallel=False, transpose=True)``. - - Similarly, the :math:`M^T` in ``f(outdim_parallel=True, transpose=True)`` is the same - of the :math:`M` used in ``f(outdim_parallel=False, transpose=False)``. - - Parameters - ---------- - vector: Array, ndarray - The vector. - weight: float - The value of the random matrix. - conn_prob: float - The connection probability. - shape: tuple of int - The matrix shape. - seed: int - The random number generation seed. - transpose: bool - Transpose the random matrix or not. - outdim_parallel: bool - Perform the parallel random generations along the out dimension or not. - It can be used to set the just-in-time generated :math:M^T: is the same - as the just-in-time generated :math:`M` when ``transpose=True``. - - Returns - ------- - out: Array, ndarray - The output of :math:`y = M @ v`. - """ if ti is None: - raise PackageMissingError(name='taichi==1.7.0', purpose='customized operators') - - vector = as_jax(vector) - if isinstance(weight, float): - weight = as_jax(weight, dtype=vector.dtype) - weight = jnp.atleast_1d(as_jax(weight)) - conn_len = jnp.ceil(1 / conn_prob) * 2 - 1 - clen = jnp.asarray(jnp.atleast_1d(conn_len), dtype=jnp.int32) - if seed is None: - with jax.ensure_compile_time_eval(): - seed = np.random.randint(0, int(1e8), 1) - seed = jnp.asarray(seed, dtype=jnp.uint32) - seed = jnp.atleast_1d(seed) - return raw_mv_prob_homo(vector, weight, clen, seed, shape=shape, - transpose=transpose, outdim_parallel=outdim_parallel)[0] - - -def mv_prob_uniform_taichi( - vector: jax.Array, - w_low: float, - w_high: float, - conn_prob: float, - seed: Optional[int] = None, - *, - shape: Tuple[int, int], - transpose: bool = False, - outdim_parallel: bool = True, -) -> jax.Array: - r"""Perform the :math:`y=M@v` operation, - where :math:`M` is just-in-time randomly generated with a uniform distribution for its value. - - This operator support ``jit()``, ``vmap()``, ``grad()`` and ``pmap()`` etc. transformations - on CPU and GPU devices. - - .. warning:: - - This API may change in the future. - - In this operation, :math:`M` is the random matrix with a connection probability - `conn_prob`, and at each connection the value is the same scalar `weight`. - - When ``transpose=True``, we perform an operation of :math:`y=M^T@v`. - - .. note:: - - Note that the just-in-time generated :math:`M` (`transpose=False`) is - different from the generated :math:`M^T` (`transpose=True`). - - If you pursue the same :math:`M` and :math:`M^T` when performing the just-in-time - matrix generation, you should set ``outdim_parallel=True``, with the sacrifice of - the speed compared with ``outdim_parallel=False``. - - Parameters - ---------- - vector: Array, ndarray - The vector. - w_low: float - Lower boundary of the output interval. - w_high: float - Upper boundary of the output interval. - conn_prob: float - The connection probability. - shape: tuple of int - The matrix shape. - seed: int - The random number generation seed. - transpose: bool - Transpose the random matrix or not. - outdim_parallel: bool - Perform the parallel random generations along the out dimension or not. - It can be used to set the just-in-time generated :math:M^T: is the same - as the just-in-time generated :math:`M` when ``transpose=True``. - - Returns - ------- - out: Array, ndarray - The output of :math:`y = M @ v`. - """ - if ti is None: - raise PackageMissingError(name='taichi==1.7.0', purpose='customized operators') - - vector = as_jax(vector) - if isinstance(w_low, float): w_low = as_jax(w_low, dtype=vector.dtype) - if isinstance(w_high, float): w_high = as_jax(w_high, dtype=vector.dtype) - w_low = jnp.atleast_1d(as_jax(w_low)) - w_high = jnp.atleast_1d(as_jax(w_high)) - conn_len = jnp.ceil(1 / conn_prob) * 2 - 1 - conn_len = jnp.asarray(jnp.atleast_1d(conn_len), dtype=jnp.int32) - if seed is None: - with jax.ensure_compile_time_eval(): - seed = np.random.randint(0, int(1e8), 1) - seed = jnp.atleast_1d(jnp.asarray(seed, dtype=jnp.uint32)) - return raw_mv_prob_uniform(vector, w_low, w_high, conn_len, seed, shape=shape, - transpose=transpose, outdim_parallel=outdim_parallel)[0] - - -def mv_prob_normal_taichi( - vector: jax.Array, - w_mu: float, - w_sigma: float, - conn_prob: float, - seed: Optional[int] = None, - *, - shape: Tuple[int, int], - transpose: bool = False, - outdim_parallel: bool = True, -) -> jax.Array: - r"""Perform the :math:`y=M@v` operation, - where :math:`M` is just-in-time randomly generated with a normal distribution for its value. - - This operator support ``jit()``, ``vmap()``, ``grad()`` and ``pmap()`` etc. transformations - on CPU and GPU devices. - - .. warning:: - - This API may change in the future. - - In this operation, :math:`M` is the random matrix with a connection probability - `conn_prob`, and at each connection the value is the same scalar `weight`. - - When ``transpose=True``, we perform an operation of :math:`y=M^T@v`. - - .. note:: - - Note that the just-in-time generated :math:`M` (`transpose=False`) is - different from the generated :math:`M^T` (`transpose=True`). - - If you pursue the same :math:`M` and :math:`M^T` when performing the just-in-time - matrix generation, you should set ``outdim_parallel=True``, with the sacrifice of - the speed compared with ``outdim_parallel=False``. - - Parameters - ---------- - vector: Array, ndarray - The vector. - w_mu: float - Mean (centre) of the distribution. - w_sigma: float - Standard deviation (spread or “width”) of the distribution. Must be non-negative. - conn_prob: float - The connection probability. - shape: tuple of int - The matrix shape. - seed: int - The random number generation seed. - transpose: bool - Transpose the random matrix or not. - outdim_parallel: bool - Perform the parallel random generations along the out dimension or not. - It can be used to set the just-in-time generated :math:M^T: is the same - as the just-in-time generated :math:`M` when ``transpose=True``. - - Returns - ------- - out: Array, ndarray - The output of :math:`y = M @ v`. - """ - if ti is None: - raise PackageMissingError(name='taichi==1.7.0', purpose='customized operators') + raise PackageMissingError.by_purpose('taichi', purpose='customized operators') vector = as_jax(vector) if isinstance(w_mu, float): w_mu = as_jax(w_mu, dtype=vector.dtype) @@ -585,6 +342,47 @@ def raw_mv_prob_normal( outdim_parallel=outdim_parallel) +def _general_checking(vector, clen, seed, shape, outdim_parallel, transpose, *weights): + if vector.ndim != 1: + raise ValueError('vector should be a 1D vector.') + if len(shape) != 2: + raise ValueError('shape should be a length-2 tuple.') + if seed.ndim != 1: + raise ValueError('seed must be a 1D scalar.') + if clen.ndim != 1: + raise ValueError('conn_prob must be a 1D scalar.') + + assert _get_dtype(clen) in [jnp.int16, jnp.int32, jnp.int64, jnp.uint16, jnp.uint32, jnp.uint64] + assert _get_dtype(seed) in [jnp.int16, jnp.int32, jnp.int64, jnp.uint16, jnp.uint32, jnp.uint64] + + for weight in weights: + if weight.ndim != 1: + raise ValueError('weight must be a 1D scalar.') + assert _get_dtype(weight) in [jnp.float16, jnp.float32, jnp.float64], '"weight" must be float valued.' + + if not isinstance(outdim_parallel, bool): + raise ValueError('outdim_parallel must be boolean value.') + if not isinstance(transpose, bool): + raise ValueError('transpose must be boolean value.') + + if transpose: + out_shape = (shape[1],) + if vector.shape[0] != shape[0]: + raise ValueError(f'Shape mismatch, vec {vector.shape} @ mat {shape}.') + shape = _reverse(shape) + else: + if vector.shape[0] != shape[1]: + raise ValueError(f'Shape mismatch, mat {shape} @ vec ({vector.shape[0]},).') + out_shape = (shape[0],) + + return shape, out_shape + + +def _non_event_checking(vector, clen, seed, shape, outdim_parallel, transpose, *weights): + assert _get_dtype(vector) in [jnp.float16, jnp.float32, jnp.float64] + return _general_checking(vector, clen, seed, shape, outdim_parallel, transpose, *weights) + + def _mv_prob_homo_transpose( ct, vector, weight, clen, seed, *, outs, shape, transpose, outdim_parallel ): @@ -644,6 +442,7 @@ def _mv_prob_normal_transpose( assert type(clen) is not ad.UndefinedPrimal, 'Cannot differentiate through clen.' assert type(seed) is not ad.UndefinedPrimal, 'Cannot differentiate through seed.' + def _reverse(shape): return shape[::-1] @@ -774,9 +573,6 @@ def _mv_prob_homo_jvp_weight(w_dot, vector, weight, clen, seed, *, outs, shape, outdim_parallel=outdim_parallel) - - - def _define_mv_prob_homo_prim(cpu_kernel, gpu_kernel): prim = XLACustomOp(cpu_kernel=cpu_kernel, gpu_kernel=gpu_kernel) prim.defjvp(_mv_prob_homo_jvp_vector, _mv_prob_homo_jvp_weight, None, None) @@ -936,9 +732,6 @@ def _mv_prob_uniform_jvp_whigh(w_dot, vector, w_low, w_high, clen, seed, *, transpose=transpose, outdim_parallel=outdim_parallel) - - - def _define_mv_prob_uniform_prim(cpu_kernel, gpu_kernel): prim = XLACustomOp(cpu_kernel=cpu_kernel, gpu_kernel=gpu_kernel) prim.defjvp(_mv_prob_uniform_jvp_vector, @@ -1103,9 +896,6 @@ def _mv_prob_normal_jvp_w_sigma(w_dot, vector, w_mu, w_sigma, clen, seed, *, out transpose=transpose, outdim_parallel=outdim_parallel) - - - def _define_mv_prob_normal_prim(cpu_kernel, gpu_kernel): prim = XLACustomOp(cpu_kernel=cpu_kernel, gpu_kernel=gpu_kernel) prim.defjvp(_mv_prob_normal_jvp_vector, diff --git a/brainpy/_src/math/op_register/base.py b/brainpy/_src/math/op_register/base.py index ead0cf00e..ca070a197 100644 --- a/brainpy/_src/math/op_register/base.py +++ b/brainpy/_src/math/op_register/base.py @@ -5,8 +5,7 @@ import numpy as np from jax.interpreters import xla, batching, ad, mlir - -from brainpy._src.dependency_check import import_numba, check_numba_class, check_taichi_class +from brainpy._src.dependency_check import import_numba from brainpy._src.math.ndarray import Array from brainpy._src.math.object_transform.base import BrainPyObject @@ -22,8 +21,6 @@ from brainpy._src.math.op_register.ad_support import defjvp numba = import_numba(error_if_not_found=False) -if numba is not None: - from numba.core.dispatcher import Dispatcher __all__ = [ 'XLACustomOp', @@ -40,8 +37,7 @@ def shape(self) -> Tuple[int, ...]: def dtype(self) -> np.dtype: ... -@check_numba_class -@check_taichi_class + class XLACustomOp(BrainPyObject): """Creating a XLA custom call operator. @@ -110,24 +106,30 @@ def __init__( self.primitive.def_impl(partial(xla.apply_primitive, self.primitive)) # cpu function + cpu_checked = False if cpu_kernel is None: - pass - elif isinstance(cpu_kernel, Dispatcher): # numba - register_numba_cpu_translation_rule(self.primitive, cpu_kernel) - elif hasattr(cpu_kernel, '_is_wrapped_kernel') and cpu_kernel._is_wrapped_kernel: # taichi + cpu_checked = True + if numba is not None: # numba + from numba.core.dispatcher import Dispatcher + if isinstance(cpu_kernel, Dispatcher): + register_numba_cpu_translation_rule(self.primitive, cpu_kernel) + cpu_checked = True + if hasattr(cpu_kernel, '_is_wrapped_kernel') and cpu_kernel._is_wrapped_kernel: # taichi register_taichi_cpu_translation_rule(self.primitive, cpu_kernel) - else: + cpu_checked = True + if not cpu_checked: raise ValueError(f'"cpu_kernel" must be a numba jitted function or a taichi kernel function. ' f'But we got {cpu_kernel}') # gpu function + gpu_checked = False if gpu_kernel is None: - pass - elif hasattr(gpu_kernel, '_is_wrapped_kernel') and gpu_kernel._is_wrapped_kernel: # taichi + gpu_checked = True + if hasattr(gpu_kernel, '_is_wrapped_kernel') and gpu_kernel._is_wrapped_kernel: # taichi register_taichi_gpu_translation_rule(self.primitive, gpu_kernel) - else: - raise ValueError(f'"cpu_kernel" must be a taichi kernel function. ' - f'But we got {gpu_kernel}') + gpu_checked = True + if not gpu_checked: + raise ValueError(f'"cpu_kernel" must be a taichi kernel function. But we got {gpu_kernel}') # batching rule if batching_translation is None: diff --git a/brainpy/_src/math/op_register/numba_approach/__init__.py b/brainpy/_src/math/op_register/numba_approach/__init__.py index 2af5637b4..5bbd04e0c 100644 --- a/brainpy/_src/math/op_register/numba_approach/__init__.py +++ b/brainpy/_src/math/op_register/numba_approach/__init__.py @@ -8,16 +8,14 @@ from jax.interpreters import xla, batching, ad from jax.tree_util import tree_map -from brainpy._src.dependency_check import import_numba, check_numba_func, check_numba_class +from brainpy._src.dependency_check import import_numba from brainpy._src.math.ndarray import Array from brainpy._src.math.object_transform.base import BrainPyObject +from brainpy.errors import PackageMissingError +from .cpu_translation import _cpu_translation, compile_cpu_signature_with_numba numba = import_numba(error_if_not_found=False) -from .cpu_translation import _cpu_translation, compile_cpu_signature_with_numba - -if numba is not None: - from numba.core.dispatcher import Dispatcher __all__ = [ 'CustomOpByNumba', @@ -26,7 +24,6 @@ ] -@check_numba_class class CustomOpByNumba(BrainPyObject): """Creating a XLA custom call operator with Numba JIT on CPU backend. @@ -88,7 +85,6 @@ def __call__(self, *args, **kwargs): return res -@check_numba_func def register_op_with_numba( op_name: str, cpu_func: Callable, @@ -143,6 +139,9 @@ def register_op_with_numba( f'For more information, please refer to the documentation: ' f'https://brainpy.readthedocs.io/en/latest/tutorial_advanced/operator_custom_with_taichi.html.') + if numba is None: + raise PackageMissingError.by_purpose('numba', 'custom op with numba') + if out_shapes is None: raise RuntimeError('out_shapes cannot be None. It can be a `ShapedArray` or ' 'a sequence of `ShapedArray`. If it is a function, it takes as input the argument ' @@ -152,6 +151,7 @@ def register_op_with_numba( prim.multiple_results = multiple_results # user defined function + from numba.core.dispatcher import Dispatcher if not isinstance(cpu_func, Dispatcher): cpu_func = numba.jit(fastmath=True, nopython=True)(cpu_func) diff --git a/brainpy/_src/math/op_register/numba_approach/cpu_translation.py b/brainpy/_src/math/op_register/numba_approach/cpu_translation.py index 759ecc50c..4b06effdf 100644 --- a/brainpy/_src/math/op_register/numba_approach/cpu_translation.py +++ b/brainpy/_src/math/op_register/numba_approach/cpu_translation.py @@ -6,9 +6,15 @@ from jax.core import ShapedArray from jax.lib import xla_client -from brainpy._src.dependency_check import import_numba, check_numba_func +from brainpy._src.dependency_check import import_numba numba = import_numba(error_if_not_found=False) +ctypes.pythonapi.PyCapsule_New.argtypes = [ + ctypes.c_void_p, # void* pointer + ctypes.c_char_p, # const char *name + ctypes.c_void_p, # PyCapsule_Destructor destructor +] +ctypes.pythonapi.PyCapsule_New.restype = ctypes.py_object __all__ = [ '_cpu_translation', @@ -18,14 +24,7 @@ if numba is not None: from numba import types, carray, cfunc -ctypes.pythonapi.PyCapsule_New.argtypes = [ - ctypes.c_void_p, # void* pointer - ctypes.c_char_p, # const char *name - ctypes.c_void_p, # PyCapsule_Destructor destructor -] -ctypes.pythonapi.PyCapsule_New.restype = ctypes.py_object -@check_numba_func def _cpu_translation(func, abs_eval_fn, multiple_results, c, *inputs, **info): target_name, inputs, input_shapes, xla_output_shapes = \ compile_cpu_signature_with_numba(c, func, abs_eval_fn, multiple_results, inputs, info) @@ -102,7 +101,7 @@ def xla_cpu_custom_call_target(output_ptrs, input_ptrs): xla_client.register_custom_call_target(target_name, capsule, "cpu") return target_name -@check_numba_func + def compile_cpu_signature_with_numba( c, func, diff --git a/brainpy/_src/math/op_register/numba_based.py b/brainpy/_src/math/op_register/numba_based.py index 8c56e52aa..f461f4277 100644 --- a/brainpy/_src/math/op_register/numba_based.py +++ b/brainpy/_src/math/op_register/numba_based.py @@ -7,8 +7,9 @@ from jax.lib import xla_client from jaxlib.hlo_helpers import custom_call +from brainpy._src.dependency_check import import_numba +from brainpy.errors import PackageMissingError from .utils import _shape_to_layout -from brainpy._src.dependency_check import import_numba, check_numba_func numba = import_numba(error_if_not_found=False) if numba is not None: @@ -105,8 +106,10 @@ def _numba_xla_cpu_translation_rule(kernel, debug: bool, c, *ins, **kwargs): ) -@check_numba_func def register_numba_xla_cpu_translation_rule(primitive, cpu_kernel, debug=False): + if numba is None: + raise PackageMissingError.by_purpose("numba", 'register numba xla cpu translation rule') + # do not support after jax >= 0.4.24 xla.backend_specific_translations['cpu'][primitive] = partial(_numba_xla_cpu_translation_rule, cpu_kernel, @@ -170,7 +173,9 @@ def numba_cpu_custom_call_target(output_ptrs, input_ptrs): ).results -@check_numba_func def register_numba_mlir_cpu_translation_rule(primitive, cpu_kernel, debug=False): + if numba is None: + raise PackageMissingError.by_purpose("numba", 'register numba xla cpu translation rule') + rule = partial(_numba_mlir_cpu_translation_rule, cpu_kernel, debug) mlir.register_lowering(primitive, rule, platform='cpu') diff --git a/brainpy/_src/math/sparse/_bsr_mm.py b/brainpy/_src/math/sparse/_bsr_mm.py index 43ccac6c8..19800749d 100644 --- a/brainpy/_src/math/sparse/_bsr_mm.py +++ b/brainpy/_src/math/sparse/_bsr_mm.py @@ -10,13 +10,14 @@ from jax.interpreters import ad, xla from jax.lib import xla_client -from brainpy._src.dependency_check import import_brainpylib_gpu_ops, import_numba, check_numba_func +from brainpy._src.dependency_check import import_brainpylib_gpu_ops, import_numba from brainpy._src.math.interoperability import as_jax from brainpy._src.math.op_register import (compile_cpu_signature_with_numba, register_general_batching) from brainpy.errors import GPUOperatorNotFound numba = import_numba(error_if_not_found=False) + __all__ = [ 'bcsrmm', ] @@ -216,7 +217,6 @@ def blocksparse_matmat_multiply(dense_a, raise Exception('Invalid device: ', device) -@check_numba_func def bcsrmm( A_data: jax.Array, B_data: jax.Array, diff --git a/brainpy/_src/math/sparse/_csr_mv.py b/brainpy/_src/math/sparse/_csr_mv.py index dd25ef3d4..42969f435 100644 --- a/brainpy/_src/math/sparse/_csr_mv.py +++ b/brainpy/_src/math/sparse/_csr_mv.py @@ -3,17 +3,16 @@ from typing import Union, Tuple -import brainpy.math as bm import jax from jax import numpy as jnp from jax.experimental.sparse import csr from jax.interpreters import ad -from brainpy._src.dependency_check import import_taichi, check_taichi_func +import brainpy.math as bm +from brainpy._src.dependency_check import import_taichi from brainpy._src.math.interoperability import as_jax from brainpy._src.math.ndarray import Array -from brainpy._src.math.op_register import (register_general_batching, - XLACustomOp) +from brainpy._src.math.op_register import (register_general_batching, XLACustomOp) from brainpy._src.math.sparse._utils import csr_to_coo from brainpy.errors import PackageMissingError @@ -23,7 +22,7 @@ 'csrmv', ] -@check_taichi_func + def csrmv( data: Union[float, jnp.ndarray, Array], indices: Union[jnp.ndarray, Array], @@ -64,48 +63,6 @@ def csrmv( - ``vector``: - ``adaptive``: - Returns - ------- - y : ndarry - The array of shape ``(shape[1] if transpose else shape[0],)`` representing - the matrix vector product. - """ - return csrmv_taichi(data, indices, indptr, vector, shape=shape, transpose=transpose) - - -### TAICHI ### - -def csrmv_taichi( - data: Union[float, jnp.ndarray, Array], - indices: Union[jnp.ndarray, Array], - indptr: Union[jnp.ndarray, Array], - vector: Union[jnp.ndarray, Array], - *, - shape: Tuple[int, int], - transpose: bool = False, -) -> jax.Array: - """Product of CSR sparse matrix and a dense vector using cuSPARSE algorithm. - - This function supports JAX transformations, including `jit()`, `grad()`, - `vmap()` and `pmap()`. - - Parameters - ---------- - data: ndarray, float - An array of shape ``(nse,)``. - indices: ndarray - An array of shape ``(nse,)``. - indptr: ndarray - An array of shape ``(shape[0] + 1,)`` and dtype ``indices.dtype``. - vector: ndarray - An array of shape ``(shape[0] if transpose else shape[1],)`` - and dtype ``data.dtype``. - shape: tuple of int - A length-2 tuple representing the matrix shape. - transpose: bool - A boolean specifying whether to transpose the sparse matrix - before computing. - Returns ------- y : ndarry @@ -150,11 +107,11 @@ def raw_csrmv_taichi( transpose: bool = False, ): if ti is None: - raise PackageMissingError(name='taichi', purpose='customized operators') + raise PackageMissingError.by_purpose('taichi', purpose='customized operators') out_shape = shape[1] if transpose else shape[0] if data.shape[0] != 1: if bm.get_platform() == 'gpu': - return [_csr_matvec_cusparse_p.bind(data, indices, indptr, vector, shape=shape, transpose=transpose), ] + return [_csr_matvec_cusparse_p.bind(data, indices, indptr, vector, shape=shape, transpose=transpose)] else: if transpose: prim = _csr_matvec_transpose_heter_p diff --git a/brainpy/errors.py b/brainpy/errors.py index 37d4b9488..453c9c818 100644 --- a/brainpy/errors.py +++ b/brainpy/errors.py @@ -39,15 +39,11 @@ class PackageMissingError(BrainPyError): """The package missing error. """ - def __init__(self, name: str = None, purpose: str = None): - - if name is None: - super().__init__() - else: - assert purpose, '"purpose" cannot be None when "name" is provided.' - msg = (f'"{name}" must be installed when the user wants to use {purpose}. \n' - f'Please install through "pip install {name}".') - super().__init__(msg) + @classmethod + def by_purpose(cls, name, purpose): + err = (f'"{name}" must be installed when the user wants to use {purpose}. \n' + f'Please install through "pip install {name}".') + return cls(err) class BackendNotInstalled(BrainPyError): From 6165a95369dd4933434cbe23f9091a656e5357f6 Mon Sep 17 00:00:00 2001 From: Chaoming Wang Date: Thu, 29 Feb 2024 13:52:09 +0800 Subject: [PATCH 12/16] fix --- brainpy/_src/tools/progress.py | 519 +++++++++++++++++++ examples/dynamics_training/integrator_rnn.py | 2 +- 2 files changed, 520 insertions(+), 1 deletion(-) create mode 100644 brainpy/_src/tools/progress.py diff --git a/brainpy/_src/tools/progress.py b/brainpy/_src/tools/progress.py new file mode 100644 index 000000000..13b6a1574 --- /dev/null +++ b/brainpy/_src/tools/progress.py @@ -0,0 +1,519 @@ +"""Python utilities required by Keras.""" + +import binascii +import codecs +import importlib +import marshal +import os +import re +import sys +import time +import types as python_types + +import numpy as np + + +# isort: off + + +def func_dump(func): + """Serializes a user defined function. + + Args: + func: the function to serialize. + + Returns: + A tuple `(code, defaults, closure)`. + """ + if os.name == "nt": + raw_code = marshal.dumps(func.__code__).replace(b"\\", b"/") + code = codecs.encode(raw_code, "base64").decode("ascii") + else: + raw_code = marshal.dumps(func.__code__) + code = codecs.encode(raw_code, "base64").decode("ascii") + defaults = func.__defaults__ + if func.__closure__: + closure = tuple(c.cell_contents for c in func.__closure__) + else: + closure = None + return code, defaults, closure + + +def func_load(code, defaults=None, closure=None, globs=None): + """Deserializes a user defined function. + + Args: + code: bytecode of the function. + defaults: defaults of the function. + closure: closure of the function. + globs: dictionary of global objects. + + Returns: + A function object. + """ + if isinstance(code, (tuple, list)): # unpack previous dump + code, defaults, closure = code + if isinstance(defaults, list): + defaults = tuple(defaults) + + def ensure_value_to_cell(value): + """Ensures that a value is converted to a python cell object. + + Args: + value: Any value that needs to be casted to the cell type + + Returns: + A value wrapped as a cell object (see function "func_load") + """ + + def dummy_fn(): + value # just access it so it gets captured in .__closure__ + + cell_value = dummy_fn.__closure__[0] + if not isinstance(value, type(cell_value)): + return cell_value + return value + + if closure is not None: + closure = tuple(ensure_value_to_cell(_) for _ in closure) + try: + raw_code = codecs.decode(code.encode("ascii"), "base64") + except (UnicodeEncodeError, binascii.Error): + raw_code = code.encode("raw_unicode_escape") + code = marshal.loads(raw_code) + if globs is None: + globs = globals() + return python_types.FunctionType( + code, globs, name=code.co_name, argdefs=defaults, closure=closure + ) + + +class Progbar: + """Displays a progress bar. + + Args: + target: Total number of steps expected, None if unknown. + width: Progress bar width on screen. + verbose: Verbosity mode, 0 (silent), 1 (verbose), 2 (semi-verbose) + stateful_metrics: Iterable of string names of metrics that should *not* + be averaged over time. Metrics in this list will be displayed as-is. + All others will be averaged by the progbar before display. + interval: Minimum visual progress update interval (in seconds). + unit_name: Display name for step counts (usually "step" or "sample"). + """ + + def __init__( + self, + target, + width=30, + verbose=1, + interval=0.05, + stateful_metrics=None, + unit_name="step", + ): + self.target = target + self.width = width + self.verbose = verbose + self.interval = interval + self.unit_name = unit_name + if stateful_metrics: + self.stateful_metrics = set(stateful_metrics) + else: + self.stateful_metrics = set() + + self._dynamic_display = ( + (hasattr(sys.stdout, "isatty") and sys.stdout.isatty()) + or "ipykernel" in sys.modules + or "posix" in sys.modules + or "PYCHARM_HOSTED" in os.environ + ) + self._total_width = 0 + self._seen_so_far = 0 + # We use a dict + list to avoid garbage collection + # issues found in OrderedDict + self._values = {} + self._values_order = [] + self._start = time.time() + self._last_update = 0 + self._time_at_epoch_start = self._start + self._time_at_epoch_end = None + self._time_after_first_step = None + + def update(self, current, values=None, finalize=None): + """Updates the progress bar. + + Args: + current: Index of current step. + values: List of tuples: `(name, value_for_last_step)`. If `name` is + in `stateful_metrics`, `value_for_last_step` will be displayed + as-is. Else, an average of the metric over time will be + displayed. + finalize: Whether this is the last update for the progress bar. If + `None`, uses `current >= self.target`. Defaults to `None`. + """ + if finalize is None: + if self.target is None: + finalize = False + else: + finalize = current >= self.target + + values = values or [] + for k, v in values: + if k not in self._values_order: + self._values_order.append(k) + if k not in self.stateful_metrics: + # In the case that progress bar doesn't have a target value in + # the first epoch, both on_batch_end and on_epoch_end will be + # called, which will cause 'current' and 'self._seen_so_far' to + # have the same value. Force the minimal value to 1 here, + # otherwise stateful_metric will be 0s. + value_base = max(current - self._seen_so_far, 1) + if k not in self._values: + self._values[k] = [v * value_base, value_base] + else: + self._values[k][0] += v * value_base + self._values[k][1] += value_base + else: + # Stateful metrics output a numeric value. This representation + # means "take an average from a single value" but keeps the + # numeric formatting. + self._values[k] = [v, 1] + self._seen_so_far = current + + message = "" + now = time.time() + info = f" - {now - self._start:.0f}s" + if current == self.target: + self._time_at_epoch_end = now + if self.verbose == 1: + if now - self._last_update < self.interval and not finalize: + return + + prev_total_width = self._total_width + if self._dynamic_display: + message += "\b" * prev_total_width + message += "\r" + else: + message += "\n" + + if self.target is not None: + numdigits = int(np.log10(self.target)) + 1 + bar = ("%" + str(numdigits) + "d/%d [") % (current, self.target) + prog = float(current) / self.target + prog_width = int(self.width * prog) + if prog_width > 0: + bar += "=" * (prog_width - 1) + if current < self.target: + bar += ">" + else: + bar += "=" + bar += "." * (self.width - prog_width) + bar += "]" + else: + bar = "%7d/Unknown" % current + + self._total_width = len(bar) + message += bar + + time_per_unit = self._estimate_step_duration(current, now) + + if self.target is None or finalize: + info += self._format_time(time_per_unit, self.unit_name) + else: + eta = time_per_unit * (self.target - current) + if eta > 3600: + eta_format = "%d:%02d:%02d" % ( + eta // 3600, + (eta % 3600) // 60, + eta % 60, + ) + elif eta > 60: + eta_format = "%d:%02d" % (eta // 60, eta % 60) + else: + eta_format = "%ds" % eta + + info = f" - ETA: {eta_format}" + + for k in self._values_order: + info += f" - {k}:" + if isinstance(self._values[k], list): + avg = np.mean( + self._values[k][0] / max(1, self._values[k][1]) + ) + if abs(avg) > 1e-3: + info += f" {avg:.4f}" + else: + info += f" {avg:.4e}" + else: + info += f" {self._values[k]}" + + self._total_width += len(info) + if prev_total_width > self._total_width: + info += " " * (prev_total_width - self._total_width) + + if finalize: + info += "\n" + + message += info + print_msg(message, line_break=False) + message = "" + + elif self.verbose == 2: + if finalize: + numdigits = int(np.log10(self.target)) + 1 + count = ("%" + str(numdigits) + "d/%d") % (current, self.target) + info = count + info + for k in self._values_order: + info += f" - {k}:" + avg = np.mean( + self._values[k][0] / max(1, self._values[k][1]) + ) + if avg > 1e-3: + info += f" {avg:.4f}" + else: + info += f" {avg:.4e}" + if self._time_at_epoch_end: + time_per_epoch = ( + self._time_at_epoch_end - self._time_at_epoch_start + ) + avg_time_per_step = time_per_epoch / self.target + self._time_at_epoch_start = now + self._time_at_epoch_end = None + info += " -" + self._format_time(time_per_epoch, "epoch") + info += " -" + self._format_time( + avg_time_per_step, self.unit_name + ) + info += "\n" + message += info + print_msg(message, line_break=False) + message = "" + + self._last_update = now + + def add(self, n, values=None): + self.update(self._seen_so_far + n, values) + + def _format_time(self, time_per_unit, unit_name): + """format a given duration to display to the user. + + Given the duration, this function formats it in either milliseconds + or seconds and displays the unit (i.e. ms/step or s/epoch) + Args: + time_per_unit: the duration to display + unit_name: the name of the unit to display + Returns: + a string with the correctly formatted duration and units + """ + formatted = "" + if time_per_unit >= 1 or time_per_unit == 0: + formatted += f" {time_per_unit:.0f}s/{unit_name}" + elif time_per_unit >= 1e-3: + formatted += f" {time_per_unit * 1000.0:.0f}ms/{unit_name}" + else: + formatted += f" {time_per_unit * 1000000.0:.0f}us/{unit_name}" + return formatted + + def _estimate_step_duration(self, current, now): + """Estimate the duration of a single step. + + Given the step number `current` and the corresponding time `now` this + function returns an estimate for how long a single step takes. If this + is called before one step has been completed (i.e. `current == 0`) then + zero is given as an estimate. The duration estimate ignores the duration + of the (assumed to be non-representative) first step for estimates when + more steps are available (i.e. `current>1`). + + Args: + current: Index of current step. + now: The current time. + + Returns: Estimate of the duration of a single step. + """ + if current: + # there are a few special scenarios here: + # 1) somebody is calling the progress bar without ever supplying + # step 1 + # 2) somebody is calling the progress bar and supplies step one + # multiple times, e.g. as part of a finalizing call + # in these cases, we just fall back to the simple calculation + if self._time_after_first_step is not None and current > 1: + time_per_unit = (now - self._time_after_first_step) / ( + current - 1 + ) + else: + time_per_unit = (now - self._start) / current + + if current == 1: + self._time_after_first_step = now + return time_per_unit + else: + return 0 + + def _update_stateful_metrics(self, stateful_metrics): + self.stateful_metrics = self.stateful_metrics.union(stateful_metrics) + + +def make_batches(size, batch_size): + """Returns a list of batch indices (tuples of indices). + + Args: + size: Integer, total size of the data to slice into batches. + batch_size: Integer, batch size. + + Returns: + A list of tuples of array indices. + """ + num_batches = int(np.ceil(size / float(batch_size))) + return [ + (i * batch_size, min(size, (i + 1) * batch_size)) + for i in range(0, num_batches) + ] + + +def slice_arrays(arrays, start=None, stop=None): + """Slice an array or list of arrays. + + This takes an array-like, or a list of + array-likes, and outputs: + - arrays[start:stop] if `arrays` is an array-like + - [x[start:stop] for x in arrays] if `arrays` is a list + + Can also work on list/array of indices: `slice_arrays(x, indices)` + + Args: + arrays: Single array or list of arrays. + start: can be an integer index (start index) or a list/array of indices + stop: integer (stop index); should be None if `start` was a list. + + Returns: + A slice of the array(s). + + Raises: + ValueError: If the value of start is a list and stop is not None. + """ + if arrays is None: + return [None] + if isinstance(start, list) and stop is not None: + raise ValueError( + "The stop argument has to be None if the value of start " + f"is a list. Received start={start}, stop={stop}" + ) + elif isinstance(arrays, list): + if hasattr(start, "__len__"): + # hdf5 datasets only support list objects as indices + if hasattr(start, "shape"): + start = start.tolist() + return [None if x is None else x[start] for x in arrays] + return [ + None + if x is None + else None + if not hasattr(x, "__getitem__") + else x[start:stop] + for x in arrays + ] + else: + if hasattr(start, "__len__"): + if hasattr(start, "shape"): + start = start.tolist() + return arrays[start] + if hasattr(start, "__getitem__"): + return arrays[start:stop] + return [None] + + +def to_list(x): + """Normalizes a list/tensor into a list. + + If a tensor is passed, we return + a list of size 1 containing the tensor. + + Args: + x: target object to be normalized. + + Returns: + A list. + """ + if isinstance(x, list): + return x + return [x] + + +def to_snake_case(name): + intermediate = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", name) + insecure = re.sub("([a-z])([A-Z])", r"\1_\2", intermediate).lower() + # If the class is private the name starts with "_" which is not secure + # for creating scopes. We prefix the name with "private" in this case. + if insecure[0] != "_": + return insecure + return "private" + insecure + + +def check_for_unexpected_keys(name, input_dict, expected_values): + unknown = set(input_dict.keys()).difference(expected_values) + if unknown: + raise ValueError( + f"Unknown entries in {name} dictionary: {list(unknown)}. " + f"Only expected following keys: {expected_values}" + ) + + +def validate_kwargs( + kwargs, allowed_kwargs, error_message="Keyword argument not understood:" +): + """Checks that all keyword arguments are in the set of allowed keys.""" + for kwarg in kwargs: + if kwarg not in allowed_kwargs: + raise TypeError(error_message, kwarg) + + +def default(method): + """Decorates a method to detect overrides in subclasses.""" + method._is_default = True + return method + + +def is_default(method): + """Check if a method is decorated with the `default` wrapper.""" + return getattr(method, "_is_default", False) + + +def populate_dict_with_module_objects(target_dict, modules, obj_filter): + for module in modules: + for name in dir(module): + obj = getattr(module, name) + if obj_filter(obj): + target_dict[name] = obj + + +class LazyLoader(python_types.ModuleType): + """Lazily import a module, mainly to avoid pulling in large dependencies.""" + + def __init__(self, local_name, parent_module_globals, name): + self._local_name = local_name + self._parent_module_globals = parent_module_globals + super().__init__(name) + + def _load(self): + """Load the module and insert it into the parent's globals.""" + # Import the target module and insert it into the parent's namespace + module = importlib.import_module(self.__name__) + self._parent_module_globals[self._local_name] = module + # Update this object's dict so that if someone keeps a reference to the + # LazyLoader, lookups are efficient (__getattr__ is only called on + # lookups that fail). + self.__dict__.update(module.__dict__) + return module + + def __getattr__(self, item): + module = self._load() + return getattr(module, item) + + +def print_msg(message, line_break=True): + """Print the message to absl logging or stdout.""" + if line_break: + sys.stdout.write(message + "\n") + else: + sys.stdout.write(message) + sys.stdout.flush() diff --git a/examples/dynamics_training/integrator_rnn.py b/examples/dynamics_training/integrator_rnn.py index fc36845e6..d0dfca11b 100644 --- a/examples/dynamics_training/integrator_rnn.py +++ b/examples/dynamics_training/integrator_rnn.py @@ -30,7 +30,7 @@ def train_data(): class RNN(bp.DynamicalSystem): def __init__(self, num_in, num_hidden): super(RNN, self).__init__() - self.rnn = bp.layers.RNNCell(num_in, num_hidden, train_state=True) + self.rnn = bp.dyn.RNNCell(num_in, num_hidden, train_state=True) self.out = bp.layers.Dense(num_hidden, 1) def update(self, x): From 285b48d48e0192d23a6eb0eb8687cfdce5d827be Mon Sep 17 00:00:00 2001 From: Chaoming Wang Date: Thu, 29 Feb 2024 13:58:21 +0800 Subject: [PATCH 13/16] update --- examples/dynamics_training/integrator_rnn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/dynamics_training/integrator_rnn.py b/examples/dynamics_training/integrator_rnn.py index d0dfca11b..aeaf0c412 100644 --- a/examples/dynamics_training/integrator_rnn.py +++ b/examples/dynamics_training/integrator_rnn.py @@ -49,7 +49,7 @@ def loss(predictions, targets, l2_reg=2e-4): # define optimizer -lr = bp.optim.ExponentialDecay(lr=0.025, decay_steps=1, decay_rate=0.99975) +lr = bp.optim.ExponentialDecayLR(lr=0.025, decay_steps=1, decay_rate=0.99975) opt = bp.optim.Adam(lr=lr, eps=1e-1) # create a trainer From decc7cf2a75056db151158b9349badd161ccc79d Mon Sep 17 00:00:00 2001 From: Chaoming Wang Date: Thu, 29 Feb 2024 14:02:25 +0800 Subject: [PATCH 14/16] update --- brainpy/_src/dependency_check.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/brainpy/_src/dependency_check.py b/brainpy/_src/dependency_check.py index 2820c7081..b8bd6e99a 100644 --- a/brainpy/_src/dependency_check.py +++ b/brainpy/_src/dependency_check.py @@ -29,6 +29,11 @@ def import_taichi(error_if_not_found=True): + """Internal API to import taichi. + + If taichi is not found, it will raise a ModuleNotFoundError if error_if_not_found is True, + otherwise it will return None. + """ global taichi if taichi is None: with open(os.devnull, 'w') as devnull: @@ -54,6 +59,12 @@ def raise_taichi_not_found(*args, **kwargs): def import_numba(error_if_not_found=True): + """ + Internal API to import numba. + + If numba is not found, it will raise a ModuleNotFoundError if error_if_not_found is True, + otherwise it will return None. + """ global numba if numba is None: try: @@ -75,6 +86,9 @@ def is_brainpylib_gpu_installed(): def import_brainpylib_cpu_ops(): + """ + Internal API to import brainpylib cpu_ops. + """ global brainpylib_cpu_ops if brainpylib_cpu_ops is None: try: @@ -97,6 +111,9 @@ def import_brainpylib_cpu_ops(): def import_brainpylib_gpu_ops(): + """ + Internal API to import brainpylib gpu_ops. + """ global brainpylib_gpu_ops if brainpylib_gpu_ops is None: try: From c3870f3081e59df53fbc642d9ea84ff23b9bc7fb Mon Sep 17 00:00:00 2001 From: Chaoming Wang Date: Thu, 29 Feb 2024 14:31:00 +0800 Subject: [PATCH 15/16] update doc and dependency --- README.md | 24 +-- docs/quickstart/installation.rst | 262 +++---------------------------- setup.py | 11 +- 3 files changed, 32 insertions(+), 265 deletions(-) diff --git a/README.md b/README.md index 6d2ee4bf4..a7fe0b721 100644 --- a/README.md +++ b/README.md @@ -25,29 +25,7 @@ BrainPy is a flexible, efficient, and extensible framework for computational neu ## Installation -BrainPy is based on Python (>=3.8) and can be installed on Linux (Ubuntu 16.04 or later), macOS (10.12 or later), and Windows platforms. Install the latest version of BrainPy: - -```bash -$ pip install brainpy -U -``` - -In addition, many customized operators in BrainPy are implemented in ``brainpylib``. -Install the latest version of `brainpylib` by: - -```bash -# CPU installation for Linux, macOS and Windows -$ pip install --upgrade brainpylib -``` - -```bash -# CUDA 12 installation for Linux only -$ pip install --upgrade brainpylib-cu12x -``` - -```bash -# CUDA 11 installation for Linux only -$ pip install --upgrade brainpylib-cu11x -``` +BrainPy is based on Python (>=3.8) and can be installed on Linux (Ubuntu 16.04 or later), macOS (10.12 or later), and Windows platforms. For detailed installation instructions, please refer to the documentation: [Quickstart/Installation](https://brainpy.readthedocs.io/en/latest/quickstart/installation.html) diff --git a/docs/quickstart/installation.rst b/docs/quickstart/installation.rst index 2e0bb1905..46ce3822f 100644 --- a/docs/quickstart/installation.rst +++ b/docs/quickstart/installation.rst @@ -10,285 +10,71 @@ Installation Linux, and MacOS. It only relies on Python libraries. -Installation with pip ---------------------- +Minimum requirements +-------------------- -You can install ``BrainPy`` from the `pypi `_. -To do so, use: +To install brainpy with minimum requirements (only depends on ``jax``), you can use: .. code-block:: bash - pip install brainpy - -To update the latest BrainPy, you can use - -.. code-block:: bash - - pip install -U brainpy - - -If you want to install the pre-release version (the latest development version) -of BrainPy, you can use: - -.. code-block:: bash - - pip install --pre brainpy - - - -Installation from source ------------------------- - -If you decide not to use ``pip``, you can install ``BrainPy`` from -`GitHub `_, -or `OpenI `_. - -To do so, use: - -.. code-block:: bash - - pip install git+https://github.com/PKU-NIP-Lab/BrainPy + pip install brainpy[cpu_mini] # for CPU # or - pip install git+https://git.openi.org.cn/OpenI/BrainPy + pip install brainpy[cuda_mini] -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html # for GPU (Linux only) -Dependency 1: NumPy --------------------------------- -In order to make BrainPy work normally, users should install -several dependent Python packages. +CPU with all dependencies +------------------------- -The basic function of ``BrainPy`` only relies on `NumPy`_, which is very -easy to install through ``pip`` or ``conda``: +To install a CPU-only version of BrainPy, which might be useful for doing local development on a laptop, you can run .. code-block:: bash - pip install numpy - - # or - - conda install numpy - -Dependency 2: JAX ------------------ - -BrainPy relies on `JAX`_. JAX is a high-performance JIT compiler which enables -users to run Python code on CPU, GPU, and TPU devices. Core functionalities of -BrainPy (>=2.0.0) have been migrated to the JAX backend. - -Linux -^^^^^ - -Currently, JAX supports **Linux** (Ubuntu 16.04 or later) and **macOS** (10.12 or -later) platforms. The provided binary releases of `jax` and `jaxlib` for Linux and macOS -systems are available at + pip install brainpy[cpu] -- for CPU: https://storage.googleapis.com/jax-releases/jax_releases.html -- for GPU: https://storage.googleapis.com/jax-releases/jax_cuda_releases.html -If you want to install a CPU-only version of `jax` and `jaxlib`, you can run +GPU with all dependencies +------------------------- -.. code-block:: bash - - pip install --upgrade "jax[cpu]" - -If you want to install JAX with both CPU and NVidia GPU support, you must first install -`CUDA`_ and `CuDNN`_, if they have already been installed. Next, run +BrainPy supports NVIDIA GPUs that have SM version 5.2 (Maxwell) or newer. +To install a GPU-only version of BrainPy, you can run .. code-block:: bash - # CUDA 12 installation - # Note: wheels only available on linux. - pip install --upgrade "jax[cuda12_local]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html - - # CUDA 11 installation - # Note: wheels only available on linux. - pip install --upgrade "jax[cuda11_local]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html - -In the event of a version mismatch error with JAX, such as encountering an error message like: - -.. code-block:: text + pip install brainpy[cuda12] -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html # for CUDA 12.0 + pip install brainpy[cuda11] -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html # for CUDA 11.0 - CUDA backend failed to initialize: Found CUDA version 12000, but JAX was built against version 12020, which is newer. The copy of CUDA that is installed must be at least as new as the version against which JAX was built. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.) -You will need to employ an alternative installation method that aligns with your environment's CUDA version. This can be achieved using the following commands: -.. code-block:: bash +``brainpylib`` +-------------- - # CUDA 12 installation - pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html - # CUDA 11 installation - pip install --upgrade "jax[cuda11_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html +``brainpylib`` defines a set of useful operators for building and simulating spiking neural networks. -Alternatively, you can download the preferred release ".whl" file for jaxlib -from the above release links, and install it via ``pip``: +To install the ``brainpylib`` package on CPU devices, you can run .. code-block:: bash - pip install xxx-0.4.15-xxx.whl - - pip install jax==0.4.15 - -.. note:: - - Note that the versions of jaxlib and jax should be consistent. - - For example, if you are using jax==0.4.15, you would better install jax==0.4.15. - + pip install brainpylib -MacOS -^^^^^ -If you are using macOS Intel, we recommend you first to install the Miniconda Intel installer: +To install the ``brainpylib`` package on CUDA 11, you can run -1. Download the package in the link https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-x86_64.pkg -2. Then click the downloaded package and install it. - - -If you are using the latest M1 macOS version, you'd better to install the Miniconda M1 installer: - - -1. Download the package in the link https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-arm64.pkg -2. Then click the downloaded package and install it. - - -Finally, you can install `jax` and `jaxlib` as the same as the Linux platform. .. code-block:: bash - pip install --upgrade "jax[cpu]" - - - -Windows -^^^^^^^ - -For **Windows** users with Python >= 3.9, `jax` and `jaxlib` can be installed -directly from the PyPi channel. - -.. code-block:: bash + pip install brainpylib-cu11x - pip install jax jaxlib +To install the ``brainpylib`` package on CUDA 12, you can run -For **Windows** users with Python <= 3.8, `jax` and `jaxlib` can be installed -from the community supports. Specifically, you can install `jax` and `jaxlib` through: .. code-block:: bash - pip install "jax[cpu]" -f https://whls.blob.core.windows.net/unstable/index.html - -If you are using GPU, you can install GPU-versioned wheels through: - -.. code-block:: bash - - pip install "jax[cuda111]" -f https://whls.blob.core.windows.net/unstable/index.html - -Alternatively, you can manually install you favourite version of `jax` and `jaxlib` by -downloading binary releases of JAX for Windows from -https://whls.blob.core.windows.net/unstable/index.html . -Then install it via ``pip``: - -.. code-block:: bash - - pip install xxx-0.4.15-xxx.whl - - pip install jax==0.4.15 - -WSL -^^^ - -Moreover, for Windows 10+ system, we recommend using `Windows Subsystem for Linux (WSL)`_. -The installation guide can be found in -`WSL Installation Guide for Windows 10/11 `_. -Then, you can install JAX in WSL just like the installation step in Linux/MacOs. - - -Dependency 3: brainpylib ------------------------- - -Many customized operators in BrainPy are implemented in ``brainpylib``. -``brainpylib`` can also be installed from pypi according to your devices. -For windows, Linux and MacOS users, ``brainpylib`` supports CPU operators. -You can install CPU-version `brainpylib` by: - -.. code-block:: bash - - # CPU installation - pip install --upgrade brainpylib - -For Nvidia GPU users, ``brainpylib`` only support Linux system and WSL2 subsystem. You can install the CUDA-version by using: - -.. code-block:: bash - - # CUDA 12 installation - pip install --upgrade brainpylib-cu12x - -.. code-block:: bash - - # CUDA 11 installation - pip install --upgrade brainpylib-cu11x - -Dependency 4: taichi ------------------------- -Now BrainPy supports customized operators implemented in `taichi`_. You can install the latest version of `taichi`_ by: - -.. code-block:: bash - - pip install -i https://pypi.taichi.graphics/simple/ taichi-nightly - -.. _taichi: https://www.taichi-lang.org - -And you can try it in the `operator custom with taichi <../tutorial_advanced/operator_custom_with_taichi.html>`_ tutorial page -Attention: customized operators is still in the experimental stage. If you meet any problems, please contact us through the issue page. - -Running BrainPy with docker ------------------------- - -If you want to use BrainPy in docker, you can use the following command to pull the docker image: - -.. code:: bash - - docker pull brainpy/brainpy:latest - -You can then run the docker image by: - -.. code:: bash - - docker run -it --platform linux/amd64 brainpy/brainpy:latest - -Please notice that BrainPy docker image is based on the `ubuntu22.04` image, so it only support CPU version of BrainPy. - - -Running BrainPy online with binder ----------------------------------- - -Click on the following link to launch the Binder environment with the -BrainPy repository: - -|image1| - -Wait for the Binder environment to build. This might take a few moments. - -Once the environment is ready, you'll be redirected to a Jupyter -notebook interface within your web browser. - -.. |image1| image:: https://camo.githubusercontent.com/581c077bdbc6ca6899c86d0acc6145ae85e9d80e6f805a1071793dbe48917982/68747470733a2f2f6d7962696e6465722e6f72672f62616467655f6c6f676f2e737667 - :target: https://mybinder.org/v2/gh/brainpy/BrainPy-binder/main - - -.. _NumPy: https://numpy.org/ -.. _Matplotlib: https://matplotlib.org/ -.. _JAX: https://github.com/google/jax -.. _Windows Subsystem for Linux (WSL): https://docs.microsoft.com/en-us/windows/wsl/about -.. _build JAX from source: https://jax.readthedocs.io/en/latest/developer.html -.. _SymPy: https://github.com/sympy/sympy -.. _Numba: https://numba.pydata.org/ -.. _CUDA: https://developer.nvidia.com/cuda-downloads -.. _CuDNN: https://developer.nvidia.com/CUDNN + pip install brainpylib-cu12x diff --git a/setup.py b/setup.py index 766cd8c75..885bbf57b 100644 --- a/setup.py +++ b/setup.py @@ -56,7 +56,7 @@ author='BrainPy Team', author_email='chao.brain@qq.com', packages=packages, - python_requires='>=3.8', + python_requires='>=3.9', install_requires=['numpy>=1.15', 'jax>=0.4.13', 'tqdm'], url='https://github.com/brainpy/BrainPy', project_urls={ @@ -69,9 +69,11 @@ ], extras_require={ 'cpu': ['jaxlib>=0.4.13', 'brainpylib', 'numba', 'taichi==1.7.0'], - 'cuda11': ['jax[cuda11_pip]', 'brainpylib-cu11x', 'numba', 'taichi==1.7.0'], - 'cuda12': ['jax[cuda12_pip]', 'brainpylib-cu12x', 'numba', 'taichi==1.7.0'], - 'tpu': ['jax[tpu]', 'numba', 'taichi==1.7.0'], + 'cuda11': ['jaxlib[cuda11_pip]', 'brainpylib-cu11x', 'numba', 'taichi==1.7.0'], + 'cuda12': ['jaxlib[cuda12_pip]', 'brainpylib-cu12x', 'numba', 'taichi==1.7.0'], + 'tpu': ['jaxlib[tpu]', 'numba',], + 'cpu_mini': ['jaxlib>=0.4.13'], + 'cuda_mini': ['jaxlib[cuda12_pip]'], }, keywords=('computational neuroscience, ' 'brain-inspired computation, ' @@ -88,6 +90,7 @@ 'Programming Language :: Python :: 3.9', 'Programming Language :: Python :: 3.10', 'Programming Language :: Python :: 3.11', + 'Programming Language :: Python :: 3.12', 'Intended Audience :: Science/Research', 'License :: OSI Approved :: Apache Software License', 'Topic :: Scientific/Engineering :: Bio-Informatics', From da9bf7e0db7a88d2f597ed3a25ff8f731154e8cf Mon Sep 17 00:00:00 2001 From: Chaoming Wang Date: Thu, 29 Feb 2024 14:33:14 +0800 Subject: [PATCH 16/16] update dependency --- .github/workflows/CI.yml | 4 ++-- requirements-dev-raw.txt | 12 ++++++++++++ requirements-dev.txt | 3 +++ 3 files changed, 17 insertions(+), 2 deletions(-) create mode 100644 requirements-dev-raw.txt diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index b82507108..95bd8eafd 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -67,7 +67,7 @@ jobs: run: | python -m pip install --upgrade pip python -m pip install flake8 pytest taichi numba - if [ -f requirements-dev.txt ]; then pip install -r requirements-dev.txt; fi + if [ -f requirements-dev-raw.txt ]; then pip install -r requirements-dev-raw.txt; fi pip uninstall brainpy -y python setup.py install - name: Lint with flake8 @@ -164,7 +164,7 @@ jobs: run: | python -m pip install --upgrade pip python -m pip install flake8 pytest taichi numba - if [ -f requirements-dev.txt ]; then pip install -r requirements-dev.txt; fi + if [ -f requirements-dev-raw.txt ]; then pip install -r requirements-dev-raw.txt; fi pip uninstall brainpy -y python setup.py install - name: Lint with flake8 diff --git a/requirements-dev-raw.txt b/requirements-dev-raw.txt new file mode 100644 index 000000000..99361efa9 --- /dev/null +++ b/requirements-dev-raw.txt @@ -0,0 +1,12 @@ +numpy +jax +jaxlib +matplotlib +msgpack +tqdm +pathos + + +# test requirements +pytest +absl-py diff --git a/requirements-dev.txt b/requirements-dev.txt index 167f39df9..98398ae2d 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -6,6 +6,9 @@ matplotlib msgpack tqdm pathos +taichi +numba + # test requirements pytest