Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[math] fix brainpy.math.scan #604

Merged
merged 1 commit into from
Jan 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion brainpy/_src/math/object_transform/controls.py
Original file line number Diff line number Diff line change
Expand Up @@ -940,6 +940,8 @@ def scan(
):
"""``scan`` control flow with :py:class:`~.Variable`.

Similar to ``jax.lax.scan``.

.. versionadded:: 2.4.7

All returns in body function will be gathered
Expand Down Expand Up @@ -999,7 +1001,7 @@ def scan(
rets = jax.eval_shape(transform, init, operands)
cache_stack(body_fun, dyn_vars) # cache
if current_transform_number():
return rets[1]
return rets[0][1], rets[1]
del rets

transform = _get_scan_transform(body_fun, dyn_vars, bar, progress_bar, remat, reverse, unroll)
Expand Down
28 changes: 20 additions & 8 deletions brainpy/_src/math/object_transform/tests/test_controls.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,11 @@
# -*- coding: utf-8 -*-
import sys
import tempfile
import unittest
from functools import partial

import jax
from jax import vmap

from absl.testing import parameterized
from jax._src import test_util as jtu
from jax import vmap

import brainpy as bp
import brainpy.math as bm
Expand Down Expand Up @@ -147,6 +144,25 @@ def f(carray, x):
expected = bm.expand_dims(expected, axis=-1)
self.assertTrue(bm.allclose(outs, expected))

def test2(self):
a = bm.Variable(1)

def f(carray, x):
carray += x
a.value += 1.
return carray, a

@bm.jit
def f_outer(carray, x):
carry, outs = bm.scan(f, carray, x, unroll=2)
return carry, outs

carry, outs = f_outer(bm.zeros(2), bm.arange(10))
self.assertTrue(bm.allclose(carry, 45.))
expected = bm.arange(1, 11).astype(outs.dtype)
expected = bm.expand_dims(expected, axis=-1)
self.assertTrue(bm.allclose(outs, expected))


class TestCond(unittest.TestCase):
def test1(self):
Expand Down Expand Up @@ -234,7 +250,6 @@ def F2(x):
self.assertTrue(bm.grad(F2)(9.0) == 18.)
self.assertTrue(bm.grad(F2)(11.0) == 1.)


def test_grad2(self):
def F3(x):
return bm.ifelse(conditions=(x >= 10, x >= 0),
Expand Down Expand Up @@ -519,6 +534,3 @@ def body(a):
file.seek(0)
out6 = file.read().strip()
self.assertTrue(out5 == out6)



1 change: 1 addition & 0 deletions docs/apis/brainpy.math.oo_transform.rst
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ Object-oriented Transformations
ifelse
for_loop
while_loop
scan
jit
cls_jit
to_object
Expand Down
Loading