Skip to content

Commit

Permalink
Merge pull request #244 from chaoming0625/master
Browse files Browse the repository at this point in the history
update quickstart docs & enable jit error checking
  • Loading branch information
chaoming0625 authored Aug 10, 2022
2 parents cc2cd73 + 8ba66a2 commit b65b766
Show file tree
Hide file tree
Showing 17 changed files with 2,055 additions and 1,655 deletions.
19 changes: 12 additions & 7 deletions brainpy/dyn/synapses/delay_couplings.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,10 @@
from jax import vmap

import brainpy.math as bm
from brainpy.dyn.base import DynamicalSystem
from brainpy.dyn.base import SynConn, SynOut
from brainpy.dyn.synouts import CUBA
from brainpy.initialize import Initializer
from brainpy.dyn.neurons.input_groups import InputGroup, OutputGroup
from brainpy.modes import Mode, TrainingMode, normal
from brainpy.tools.checking import check_sequence
from brainpy.types import Tensor
Expand All @@ -19,7 +21,7 @@
]


class DelayCoupling(DynamicalSystem):
class DelayCoupling(SynConn):
"""Delay coupling.
Parameters
Expand Down Expand Up @@ -49,7 +51,10 @@ def __init__(
name: str = None,
mode: Mode = normal,
):
super(DelayCoupling, self).__init__(name=name, mode=mode)
super(DelayCoupling, self).__init__(name=name,
mode=mode,
pre=InputGroup(1),
post=OutputGroup(1))

# delay variable
if not isinstance(delay_var, bm.Variable):
Expand Down Expand Up @@ -201,8 +206,8 @@ def update(self, tdi):
indices = (slice(None, None, None), bm.arange(self.coupling_var1.size),)
else:
indices = (bm.arange(self.coupling_var1.size),)
f = vmap(lambda i: delay_var(self.delay_steps[:, i], *indices)) # (..., pre.num)
delays = f(bm.arange(self.coupling_var2.size).value) # (..., post.num, pre.num)
f = vmap(lambda steps: delay_var(steps, *indices), in_axes=1) # (..., pre.num)
delays = f(self.delay_steps) # (..., post.num, pre.num)
diffusive = (bm.moveaxis(delays, axis - 1, axis) -
bm.expand_dims(self.coupling_var2, axis=axis - 1)) # (..., pre.num, post.num)
diffusive = (self.conn_mat * diffusive).sum(axis=axis - 1)
Expand Down Expand Up @@ -284,8 +289,8 @@ def update(self, tdi):
indices = (slice(None, None, None), bm.arange(self.coupling_var.size),)
else:
indices = (bm.arange(self.coupling_var.size),)
f = vmap(lambda i: delay_var(self.delay_steps[:, i], *indices)) # (.., pre.num,)
delays = f(bm.arange(self.coupling_var.size).value) # (..., post.num, pre.num)
f = vmap(lambda steps: delay_var(steps, *indices), in_axes=1) # (.., pre.num,)
delays = f(self.delay_steps) # (..., post.num, pre.num)
additive = (self.conn_mat * bm.moveaxis(delays, axis - 1, axis)).sum(axis=axis - 1)
elif self.delay_type == 'int':
delayed_var = delay_var(self.delay_steps) # (..., pre.num)
Expand Down
11 changes: 8 additions & 3 deletions brainpy/initialize/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,24 +98,29 @@ def variable(
return bm.Variable(data(new_shape), batch_axis=batch_axis)
elif batch_size_or_mode in (None, False):
return bm.Variable(data(var_shape))
else:
elif isinstance(batch_size_or_mode, int):
new_shape = var_shape[:batch_axis] + (int(batch_size_or_mode),) + var_shape[batch_axis:]
return bm.Variable(data(new_shape), batch_axis=batch_axis)
else:
raise ValueError('Unknown batch_size_or_mode.')

else:
if var_shape is not None:
if bm.shape(data) != var_shape:
raise ValueError(f'The shape of "data" {bm.shape(data)} does not match with "var_shape" {var_shape}')
if isinstance(batch_size_or_mode, NormalMode):
return bm.Variable(data(var_shape))
return bm.Variable(data)
elif isinstance(batch_size_or_mode, BatchingMode):
return bm.Variable(bm.expand_dims(data, axis=batch_axis), batch_axis=batch_axis)
elif batch_size_or_mode in (None, False):
return bm.Variable(data)
else:
elif isinstance(batch_size_or_mode, int):
return bm.Variable(bm.repeat(bm.expand_dims(data, axis=batch_axis),
int(batch_size_or_mode),
axis=batch_axis),
batch_axis=batch_axis)
else:
raise ValueError('Unknown batch_size_or_mode.')


def noise(
Expand Down
5 changes: 3 additions & 2 deletions brainpy/math/delayvars.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,7 +327,7 @@ def reset(
# delay data
if self.data is None:
if batch_axis is None:
if hasattr(delay_target, 'batch_axis') and (delay_target.batch_axis is not None):
if isinstance(delay_target, Variable) and (delay_target.batch_axis is not None):
batch_axis = delay_target.batch_axis + 1
self.data = Variable(jnp.zeros((self.num_delay_step,) + delay_target.shape,
dtype=delay_target.dtype),
Expand All @@ -348,7 +348,8 @@ def reset(

def _check_delay(self, delay_len):
raise ValueError(f'The request delay length should be less than the '
f'maximum delay {self.num_delay_step}. But we got {delay_len}')
f'maximum delay {self.num_delay_step}. '
f'But we got {delay_len}')

def __call__(self, delay_len, *indices):
# check
Expand Down
80 changes: 80 additions & 0 deletions brainpy/math/remove_vmap.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
# -*- coding: utf-8 -*-

from brainpy.math.numpy_ops import any, all
from jax.core import Primitive
from jax.interpreters import batching, mlir, xla
from jax.abstract_arrays import ShapedArray
import jax.numpy as jnp


__all__ = [
'remove_vmap'
]


def remove_vmap(x, op='any'):
if op == 'any':
return _any_without_vmap(x)
elif op == 'all':
return _all_without_vmap(x)
else:
raise ValueError(f'Do not support type: {op}')


_any_no_vmap_prim = Primitive('any_no_vmap')


def _any_without_vmap(x):
return _any_no_vmap_prim.bind(x)


def _any_without_vmap_imp(x):
return any(x)


def _any_without_vmap_abs(x):
return ShapedArray(shape=(), dtype=jnp.bool_)


def _any_without_vmap_batch(x, batch_axes):
(x, ) = x
return _any_without_vmap(x), batching.not_mapped


_any_no_vmap_prim.def_impl(_any_without_vmap_imp)
_any_no_vmap_prim.def_abstract_eval(_any_without_vmap_abs)
batching.primitive_batchers[_any_no_vmap_prim] = _any_without_vmap_batch
if hasattr(xla, "lower_fun"):
xla.register_translation(_any_no_vmap_prim,
xla.lower_fun(_any_without_vmap_imp, multiple_results=False, new_style=True))
mlir.register_lowering(_any_no_vmap_prim, mlir.lower_fun(_any_without_vmap_imp, multiple_results=False))


_all_no_vmap_prim = Primitive('all_no_vmap')


def _all_without_vmap(x):
return _all_no_vmap_prim.bind(x)


def _all_without_vmap_imp(x):
return all(x)


def _all_without_vmap_abs(x):
return ShapedArray(shape=(), dtype=jnp.bool_)


def _all_without_vmap_batch(x, batch_axes):
(x, ) = x
return _all_without_vmap(x), batching.not_mapped


_all_no_vmap_prim.def_impl(_all_without_vmap_imp)
_all_no_vmap_prim.def_abstract_eval(_all_without_vmap_abs)
batching.primitive_batchers[_all_no_vmap_prim] = _all_without_vmap_batch
if hasattr(xla, "lower_fun"):
xla.register_translation(_all_no_vmap_prim,
xla.lower_fun(_all_without_vmap_imp, multiple_results=False, new_style=True))
mlir.register_lowering(_all_no_vmap_prim, mlir.lower_fun(_all_without_vmap_imp, multiple_results=False))

19 changes: 7 additions & 12 deletions brainpy/tools/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,27 +9,22 @@
]


def _make_err_func(f):
f2 = lambda arg, transforms: f(arg)

def err_f(x):
id_tap(f2, x)
return
return err_f


def check_error_in_jit(pred, err_f, err_arg=None):
def check_error_in_jit(pred, err_fun, err_arg=None):
"""Check errors in a jit function.
Parameters
----------
pred: bool
The boolean prediction.
err_f: callable
err_fun: callable
The error function, which raise errors.
err_arg: any
The arguments which passed into `err_f`.
"""
cond(pred, _make_err_func(err_f), lambda _: None, err_arg)
from brainpy.math.remove_vmap import remove_vmap

def err_f(x):
id_tap(lambda arg, transforms: err_fun(arg), x)
return
cond(remove_vmap(pred), err_f, lambda _: None, err_arg)

2 changes: 1 addition & 1 deletion docs/auto_generater.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,7 @@ def generate_datasets_docs(path='apis/auto/datasets/'):
header='Chaotic Systems')
write_module(module_name='brainpy.datasets.vision',
filename=os.path.join(path, 'vision.rst'),
header='Chaotic Systems')
header='Vision Datasets')


def generate_dyn_docs(path='apis/auto/dyn/'):
Expand Down
4 changes: 0 additions & 4 deletions docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@ The code of BrainPy is open-sourced at GitHub:

quickstart/installation
quickstart/simulation
quickstart/rate_model
quickstart/training
quickstart/analysis

Expand Down Expand Up @@ -66,9 +65,6 @@ The code of BrainPy is open-sourced at GitHub:
tutorial_toolbox/synaptic_connections
tutorial_toolbox/synaptic_weights
tutorial_toolbox/optimizers
tutorial_toolbox/runners
tutorial_toolbox/inputs
tutorial_toolbox/monitors
tutorial_toolbox/saving_and_loading


Expand Down
Loading

0 comments on commit b65b766

Please sign in to comment.