
Commit 9589ad8: Merge pull request #511 from chaoming0625/master

Compatible with `jax==0.4.16`

Authored by chaoming0625 on Sep 22, 2023
2 parents: e6373e8 + 44adbc4
Showing 14 changed files with 47 additions and 43 deletions.
4 changes: 2 additions & 2 deletions brainpy/_src/dnn/normalization.py
@@ -587,8 +587,8 @@ def update(self, x):
     x = (x - mean) * lax.rsqrt(var + lax.convert_element_type(self.epsilon, x.dtype))
     x = x.reshape(origin_shape)
     if self.affine:
-      x = x * lax.broadcast_to_rank(self.scale, origin_dim)
-      x = x + lax.broadcast_to_rank(self.bias, origin_dim)
+      x = x * lax.broadcast_to_rank(self.scale.value, origin_dim)
+      x = x + lax.broadcast_to_rank(self.bias.value, origin_dim)
     return x
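Reviewer note: `self.scale` and `self.bias` are BrainPy state wrappers rather than raw arrays, and newer jax releases are stricter about passing such wrappers into `lax` primitives, which is presumably what the explicit `.value` unwrapping addresses. A minimal sketch of the pattern, assuming `bm.TrainVar` (not the exact module code):

import brainpy.math as bm
import jax.numpy as jnp

scale = bm.TrainVar(bm.ones(32))  # BrainPy trainable state wrapper
x = jnp.zeros((8, 32))
# Hand jax primitives the underlying jax.Array, not the wrapper:
y = x * jnp.broadcast_to(scale.value, x.shape)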


1 change: 1 addition & 0 deletions brainpy/_src/dyn/ions/calcium.py
@@ -19,6 +19,7 @@


 class Calcium(Ion):
+  """Base class for modeling Calcium ion."""
   pass


1 change: 1 addition & 0 deletions brainpy/_src/dyn/ions/potassium.py
@@ -13,6 +13,7 @@


 class Potassium(Ion):
+  """Base class for modeling Potassium ion."""
   pass


1 change: 1 addition & 0 deletions brainpy/_src/dyn/ions/sodium.py
@@ -13,6 +13,7 @@


 class Sodium(Ion):
+  """Base class for modeling Sodium ion."""
   pass
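The three docstrings make the abstract ion bases self-documenting. For orientation, a concrete model subclasses one of them; everything in the sketch below is hypothetical, since the base-class API is not shown in this diff:

import brainpy as bp

class FixedCalcium(bp.dyn.Calcium):  # export path assumed
  """Toy model: hold the intracellular Ca2+ concentration constant."""
  def update(self, V):
    pass  # real subclasses integrate the ion dynamics here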


10 changes: 5 additions & 5 deletions brainpy/_src/dyn/projections/plasticity.py
@@ -31,11 +31,11 @@ class STDP_Song2000(Projection):
   .. math::

-     \begin{aligned}
-     \frac{dw}{dt} & = & -A_{post}\delta(t-t_{sp}) + A_{pre}\delta(t-t_{sp}), \\
-     \frac{dA_{pre}}{dt} & = & -\frac{A_{pre}}{\tau_s}+A_1\delta(t-t_{sp}), \\
-     \frac{dA_{post}}{dt} & = & -\frac{A_{post}}{\tau_t}+A_2\delta(t-t_{sp}), \\
-     \tag{1}\end{aligned}
+     \begin{aligned}
+     \frac{dw}{dt} & = & -A_{post}\delta(t-t_{sp}) + A_{pre}\delta(t-t_{sp}), \\
+     \frac{dA_{pre}}{dt} & = & -\frac{A_{pre}}{\tau_s}+A_1\delta(t-t_{sp}), \\
+     \frac{dA_{post}}{dt} & = & -\frac{A_{post}}{\tau_t}+A_2\delta(t-t_{sp}), \\
+     \end{aligned}

   where :math:`t_{sp}` denotes the spike time and :math:`A_1` is the increment
   of :math:`A_{pre}`, :math:`A_2` is the increment of :math:`A_{post}` produced by a spike.
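The only change here is dropping the `\tag{1}`; the dynamics are untouched. For readers skimming the diff, the traces integrate to a simple event-driven rule. A discrete-time sketch with an Euler step (parameter values are placeholders, not the model's defaults):

def stdp_step(w, A_pre, A_post, pre_spike, post_spike,
              dt=0.1, tau_s=20., tau_t=20., A1=0.01, A2=0.01):
  A_pre = A_pre * (1. - dt / tau_s) + A1 * pre_spike     # trace of pre spikes
  A_post = A_post * (1. - dt / tau_t) + A2 * post_spike  # trace of post spikes
  # potentiate on a post spike, depress on a pre spike
  w = w + A_pre * post_spike - A_post * pre_spike
  return w, A_pre, A_post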
@@ -136,7 +136,8 @@ def compile_cpu_signature_with_numba(
       input_dimensions,
       output_dtypes,
       output_shapes,
-      multiple_results)
+      multiple_results,
+      debug=True)
   output_layouts = [xla_client.Shape.array_shape(*arg)
                     for arg in zip(output_dtypes, output_shapes, output_layouts)]
   output_layouts = (xla_client.Shape.tuple_shape(output_layouts)
16 changes: 8 additions & 8 deletions docs/tutorial_advanced/integrate_bp_convlstm_into_flax.ipynb
@@ -61,14 +61,14 @@
    "outputs": [],
    "source": [
     "# the recurrent cell with trainable parameters\n",
-    "cell1 = bp.layers.ToFlaxRNNCell(bp.layers.Conv2dLSTMCell((28, 28),\n",
-    "                                in_channels=1,\n",
-    "                                out_channels=32,\n",
-    "                                kernel_size=(3, 3)))\n",
-    "cell2 = bp.layers.ToFlaxRNNCell(bp.layers.Conv2dLSTMCell((14, 14),\n",
-    "                                in_channels=32,\n",
-    "                                out_channels=64,\n",
-    "                                kernel_size=(3, 3)))"
+    "cell1 = bp.dnn.ToFlaxRNNCell(bp.dyn.Conv2dLSTMCell((28, 28),\n",
+    "                             in_channels=1,\n",
+    "                             out_channels=32,\n",
+    "                             kernel_size=(3, 3)))\n",
+    "cell2 = bp.dnn.ToFlaxRNNCell(bp.dyn.Conv2dLSTMCell((14, 14),\n",
+    "                             in_channels=32,\n",
+    "                             out_channels=64,\n",
+    "                             kernel_size=(3, 3)))"
    ]
   },
   {
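The pattern behind this and the following doc and example edits is the move away from the old `bp.layers` aliases: feed-forward layers and adapters are reached through `bp.dnn`, stateful recurrent components through `bp.dyn`. Consolidated, with old names on the left (all pairs taken from this PR's diffs):

import brainpy as bp

# bp.layers.ToFlaxRNNCell  -> bp.dnn.ToFlaxRNNCell
# bp.layers.Conv2dLSTMCell -> bp.dyn.Conv2dLSTMCell
# bp.layers.GRUCell        -> bp.dyn.GRUCell
# bp.layers.Dense          -> bp.dnn.Dense
# bp.layers.FromFlax       -> bp.dnn.FromFlax
# bp.layers.Reservoir      -> bp.dyn.Reservoir
# bp.layers.NVAR           -> bp.dyn.NVAR
cell = bp.dnn.ToFlaxRNNCell(
    bp.dyn.Conv2dLSTMCell((28, 28), in_channels=1, out_channels=32,
                          kernel_size=(3, 3)))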
6 changes: 3 additions & 3 deletions docs/tutorial_advanced/integrate_bp_lif_into_flax.ipynb
@@ -80,9 +80,9 @@
    "outputs": [],
    "source": [
     "# LIF neurons can be viewed as a recurrent cell without trainable parameters\n",
-    "cell1 = bp.layers.ToFlaxRNNCell(bp.neurons.LIF((28, 28, 32), **pars))\n",
-    "cell2 = bp.layers.ToFlaxRNNCell(bp.neurons.LIF((14, 14, 64), **pars))\n",
-    "cell3 = bp.layers.ToFlaxRNNCell(bp.neurons.LIF(256, **pars))"
+    "cell1 = bp.dnn.ToFlaxRNNCell(bp.neurons.LIF((28, 28, 32), **pars))\n",
+    "cell2 = bp.dnn.ToFlaxRNNCell(bp.neurons.LIF((14, 14, 64), **pars))\n",
+    "cell3 = bp.dnn.ToFlaxRNNCell(bp.neurons.LIF(256, **pars))"
    ]
   },
   {
6 changes: 3 additions & 3 deletions docs/tutorial_advanced/integrate_flax_into_brainpy.ipynb
@@ -148,12 +148,12 @@
    "class Network(bp.DynamicalSystemNS):\n",
    "  def __init__(self):\n",
    "    super(Network, self).__init__()\n",
-   "    self.cnn = bp.layers.FromFlax(\n",
+   "    self.cnn = bp.dnn.FromFlax(\n",
    "      CNN(),  # the model\n",
    "      bm.ones([1, 4, 28, 1])  # an example of the input used to initialize the model parameters\n",
    "    )\n",
-   "    self.rnn = bp.layers.GRUCell(256, 100)\n",
-   "    self.linear = bp.layers.Dense(100, 10)\n",
+   "    self.rnn = bp.dyn.GRUCell(256, 100)\n",
+   "    self.linear = bp.dnn.Dense(100, 10)\n",
    "\n",
    "  def update(self, x):\n",
    "    x = self.cnn(x)\n",
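`FromFlax` wraps a Flax module for use inside a BrainPy network; the example tensor (`bm.ones([1, 4, 28, 1])`) is what drives Flax's lazy shape inference at wrap time. A self-contained sketch with a stand-in for the notebook's `CNN` (the layer choices here are illustrative, not the notebook's actual definition):

import brainpy as bp
import brainpy.math as bm
import flax.linen as nn

class CNN(nn.Module):  # stand-in; the real definition appears earlier in the notebook
  @nn.compact
  def __call__(self, x):
    x = nn.relu(nn.Conv(features=32, kernel_size=(3, 3))(x))
    return x.reshape((x.shape[0], -1))

cnn = bp.dnn.FromFlax(CNN(), bm.ones([1, 4, 28, 1]))  # example input fixes the shapes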
30 changes: 15 additions & 15 deletions examples/dynamics_training/echo_state_network.py
@@ -9,17 +9,17 @@
 class ESN(bp.DynamicalSystem):
   def __init__(self, num_in, num_hidden, num_out):
     super(ESN, self).__init__()
-    self.r = bp.layers.Reservoir(num_in,
-                                 num_hidden,
-                                 Win_initializer=bp.init.Uniform(-0.1, 0.1),
-                                 Wrec_initializer=bp.init.Normal(scale=0.1),
-                                 in_connectivity=0.02,
-                                 rec_connectivity=0.02,
-                                 comp_type='dense')
-    self.o = bp.layers.Dense(num_hidden,
-                             num_out,
-                             W_initializer=bp.init.Normal(),
-                             mode=bm.training_mode)
+    self.r = bp.dyn.Reservoir(num_in,
+                              num_hidden,
+                              Win_initializer=bp.init.Uniform(-0.1, 0.1),
+                              Wrec_initializer=bp.init.Normal(scale=0.1),
+                              in_connectivity=0.02,
+                              rec_connectivity=0.02,
+                              comp_type='dense')
+    self.o = bp.dnn.Dense(num_hidden,
+                          num_out,
+                          W_initializer=bp.init.Normal(),
+                          mode=bm.training_mode)

   def update(self, x):
     return x >> self.r >> self.o
@@ -29,10 +29,10 @@ class NGRC(bp.DynamicalSystem):
   def __init__(self, num_in, num_out):
     super(NGRC, self).__init__()

-    self.r = bp.layers.NVAR(num_in, delay=2, order=2)
-    self.o = bp.layers.Dense(self.r.num_out, num_out,
-                             W_initializer=bp.init.Normal(0.1),
-                             mode=bm.training_mode)
+    self.r = bp.dyn.NVAR(num_in, delay=2, order=2)
+    self.o = bp.dnn.Dense(self.r.num_out, num_out,
+                          W_initializer=bp.init.Normal(0.1),
+                          mode=bm.training_mode)

   def update(self, x):
     return x >> self.r >> self.o
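In both classes, `update` uses BrainPy's `>>` operator to pipe the input through the reservoir and then the dense readout. A usage sketch for fitting the readout offline; the trainer name and call signature are assumptions based on BrainPy's training utilities, not part of this diff:

import brainpy as bp
import brainpy.math as bm

model = ESN(num_in=1, num_hidden=100, num_out=1)
X = bm.random.rand(1, 200, 1)  # (batch, time, feature)
Y = bm.random.rand(1, 200, 1)

trainer = bp.train.RidgeTrainer(model, alpha=1e-6)  # assumed API
trainer.fit([X, Y])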
4 changes: 2 additions & 2 deletions requirements-dev.txt
@@ -1,8 +1,8 @@
 numpy
 numba
 brainpylib
-jax>=0.4.1, <0.4.16
-jaxlib>=0.4.1, <0.4.16
+jax
+jaxlib
 matplotlib>=3.4
 msgpack
 tqdm
4 changes: 2 additions & 2 deletions requirements-doc.txt
@@ -2,8 +2,8 @@ numpy
 tqdm
 msgpack
 numba
-jax>=0.4.1, <0.4.16
-jaxlib>=0.4.1, <0.4.16
+jax
+jaxlib
 matplotlib>=3.4
 scipy>=1.1.0
 numba
2 changes: 1 addition & 1 deletion requirements.txt
@@ -1,5 +1,5 @@
 numpy
-jax>=0.4.1, <0.4.16
+jax
 tqdm
 msgpack
 numba
2 changes: 1 addition & 1 deletion setup.py
@@ -51,7 +51,7 @@
   author_email='chao.brain@qq.com',
   packages=packages,
   python_requires='>=3.8',
-  install_requires=['numpy>=1.15', 'jax>=0.4.1, <0.4.16', 'tqdm', 'msgpack', 'numba'],
+  install_requires=['numpy>=1.15', 'jax', 'tqdm', 'msgpack', 'numba'],
   url='https://github.com/brainpy/BrainPy',
   project_urls={
     "Bug Tracker": "https://github.com/brainpy/BrainPy/issues",
