From f16d7e3de2d52b30de62bce667435710c4ef000e Mon Sep 17 00:00:00 2001 From: chaoming Date: Mon, 10 Apr 2023 20:48:11 +0800 Subject: [PATCH 1/3] update tests --- brainpy/_src/checkpoints/tests/test_io.py | 100 +++++++++++----------- 1 file changed, 50 insertions(+), 50 deletions(-) diff --git a/brainpy/_src/checkpoints/tests/test_io.py b/brainpy/_src/checkpoints/tests/test_io.py index 666482b07..5abbe967e 100644 --- a/brainpy/_src/checkpoints/tests/test_io.py +++ b/brainpy/_src/checkpoints/tests/test_io.py @@ -12,7 +12,7 @@ def __init__(self, *args, **kwargs): rng = bm.random.RandomState() - class IO1(bp.dyn.DynamicalSystem): + class IO1(bp.DynamicalSystem): def __init__(self): super(IO1, self).__init__() @@ -21,7 +21,7 @@ def __init__(self): self.c = bm.Variable(bm.ones((3, 4))) self.d = bm.Variable(bm.ones((2, 3, 4))) - class IO2(bp.dyn.DynamicalSystem): + class IO2(bp.DynamicalSystem): def __init__(self): super(IO2, self).__init__() @@ -35,59 +35,59 @@ def __init__(self): io2.a2 = io1.a io2.b2 = io2.b - self.net = bp.dyn.Container(io1, io2) + self.net = bp.Container(io1, io2) print(self.net.vars().keys()) print(self.net.vars().unique().keys()) def test_h5(self): - bp.base.save_as_h5('io_test_tmp.h5', self.net.vars()) - bp.base.load_by_h5('io_test_tmp.h5', self.net, verbose=True) + bp.checkpoints.io.save_as_h5('io_test_tmp.h5', self.net.vars()) + bp.checkpoints.io.load_by_h5('io_test_tmp.h5', self.net, verbose=True) - bp.base.save_as_h5('io_test_tmp.hdf5', self.net.vars()) - bp.base.load_by_h5('io_test_tmp.hdf5', self.net, verbose=True) + bp.checkpoints.io.save_as_h5('io_test_tmp.hdf5', self.net.vars()) + bp.checkpoints.io.load_by_h5('io_test_tmp.hdf5', self.net, verbose=True) def test_h5_postfix(self): with self.assertRaises(ValueError): - bp.base.save_as_h5('io_test_tmp.h52', self.net.vars()) + bp.checkpoints.io.save_as_h5('io_test_tmp.h52', self.net.vars()) with self.assertRaises(ValueError): - bp.base.load_by_h5('io_test_tmp.h52', self.net, verbose=True) + bp.checkpoints.io.load_by_h5('io_test_tmp.h52', self.net, verbose=True) def test_npz(self): - bp.base.save_as_npz('io_test_tmp.npz', self.net.vars()) - bp.base.load_by_npz('io_test_tmp.npz', self.net, verbose=True) + bp.checkpoints.io.save_as_npz('io_test_tmp.npz', self.net.vars()) + bp.checkpoints.io.load_by_npz('io_test_tmp.npz', self.net, verbose=True) - bp.base.save_as_npz('io_test_tmp_compressed.npz', self.net.vars(), compressed=True) - bp.base.load_by_npz('io_test_tmp_compressed.npz', self.net, verbose=True) + bp.checkpoints.io.save_as_npz('io_test_tmp_compressed.npz', self.net.vars(), compressed=True) + bp.checkpoints.io.load_by_npz('io_test_tmp_compressed.npz', self.net, verbose=True) def test_npz_postfix(self): with self.assertRaises(ValueError): - bp.base.save_as_npz('io_test_tmp.npz2', self.net.vars()) + bp.checkpoints.io.save_as_npz('io_test_tmp.npz2', self.net.vars()) with self.assertRaises(ValueError): - bp.base.load_by_npz('io_test_tmp.npz2', self.net, verbose=True) + bp.checkpoints.io.load_by_npz('io_test_tmp.npz2', self.net, verbose=True) def test_pkl(self): - bp.base.save_as_pkl('io_test_tmp.pkl', self.net.vars()) - bp.base.load_by_pkl('io_test_tmp.pkl', self.net, verbose=True) + bp.checkpoints.io.save_as_pkl('io_test_tmp.pkl', self.net.vars()) + bp.checkpoints.io.load_by_pkl('io_test_tmp.pkl', self.net, verbose=True) - bp.base.save_as_pkl('io_test_tmp.pickle', self.net.vars()) - bp.base.load_by_pkl('io_test_tmp.pickle', self.net, verbose=True) + bp.checkpoints.io.save_as_pkl('io_test_tmp.pickle', self.net.vars()) + bp.checkpoints.io.load_by_pkl('io_test_tmp.pickle', self.net, verbose=True) def test_pkl_postfix(self): with self.assertRaises(ValueError): - bp.base.save_as_pkl('io_test_tmp.pkl2', self.net.vars()) + bp.checkpoints.io.save_as_pkl('io_test_tmp.pkl2', self.net.vars()) with self.assertRaises(ValueError): - bp.base.load_by_pkl('io_test_tmp.pkl2', self.net, verbose=True) + bp.checkpoints.io.load_by_pkl('io_test_tmp.pkl2', self.net, verbose=True) def test_mat(self): - bp.base.save_as_mat('io_test_tmp.mat', self.net.vars()) - bp.base.load_by_mat('io_test_tmp.mat', self.net, verbose=True) + bp.checkpoints.io.save_as_mat('io_test_tmp.mat', self.net.vars()) + bp.checkpoints.io.load_by_mat('io_test_tmp.mat', self.net, verbose=True) def test_mat_postfix(self): with self.assertRaises(ValueError): - bp.base.save_as_mat('io_test_tmp.mat2', self.net.vars()) + bp.checkpoints.io.save_as_mat('io_test_tmp.mat2', self.net.vars()) with self.assertRaises(ValueError): - bp.base.load_by_mat('io_test_tmp.mat2', self.net, verbose=True) + bp.checkpoints.io.load_by_mat('io_test_tmp.mat2', self.net, verbose=True) class TestIO2(unittest.TestCase): @@ -96,7 +96,7 @@ def __init__(self, *args, **kwargs): rng = bm.random.RandomState() - class IO1(bp.dyn.DynamicalSystem): + class IO1(bp.DynamicalSystem): def __init__(self): super(IO1, self).__init__() @@ -105,7 +105,7 @@ def __init__(self): self.c = bm.Variable(bm.ones((3, 4))) self.d = bm.Variable(bm.ones((2, 3, 4))) - class IO2(bp.dyn.DynamicalSystem): + class IO2(bp.DynamicalSystem): def __init__(self): super(IO2, self).__init__() @@ -115,56 +115,56 @@ def __init__(self): io1 = IO1() io2 = IO2() - self.net = bp.dyn.Container(io1, io2) + self.net = bp.Container(io1, io2) print(self.net.vars().keys()) print(self.net.vars().unique().keys()) def test_h5(self): - bp.base.save_as_h5('io_test_tmp.h5', self.net.vars()) - bp.base.load_by_h5('io_test_tmp.h5', self.net, verbose=True) + bp.checkpoints.io.save_as_h5('io_test_tmp.h5', self.net.vars()) + bp.checkpoints.io.load_by_h5('io_test_tmp.h5', self.net, verbose=True) - bp.base.save_as_h5('io_test_tmp.hdf5', self.net.vars()) - bp.base.load_by_h5('io_test_tmp.hdf5', self.net, verbose=True) + bp.checkpoints.io.save_as_h5('io_test_tmp.hdf5', self.net.vars()) + bp.checkpoints.io.load_by_h5('io_test_tmp.hdf5', self.net, verbose=True) def test_h5_postfix(self): with self.assertRaises(ValueError): - bp.base.save_as_h5('io_test_tmp.h52', self.net.vars()) + bp.checkpoints.io.save_as_h5('io_test_tmp.h52', self.net.vars()) with self.assertRaises(ValueError): - bp.base.load_by_h5('io_test_tmp.h52', self.net, verbose=True) + bp.checkpoints.io.load_by_h5('io_test_tmp.h52', self.net, verbose=True) def test_npz(self): - bp.base.save_as_npz('io_test_tmp.npz', self.net.vars()) - bp.base.load_by_npz('io_test_tmp.npz', self.net, verbose=True) + bp.checkpoints.io.save_as_npz('io_test_tmp.npz', self.net.vars()) + bp.checkpoints.io.load_by_npz('io_test_tmp.npz', self.net, verbose=True) - bp.base.save_as_npz('io_test_tmp_compressed.npz', self.net.vars(), compressed=True) - bp.base.load_by_npz('io_test_tmp_compressed.npz', self.net, verbose=True) + bp.checkpoints.io.save_as_npz('io_test_tmp_compressed.npz', self.net.vars(), compressed=True) + bp.checkpoints.io.load_by_npz('io_test_tmp_compressed.npz', self.net, verbose=True) def test_npz_postfix(self): with self.assertRaises(ValueError): - bp.base.save_as_npz('io_test_tmp.npz2', self.net.vars()) + bp.checkpoints.io.save_as_npz('io_test_tmp.npz2', self.net.vars()) with self.assertRaises(ValueError): - bp.base.load_by_npz('io_test_tmp.npz2', self.net, verbose=True) + bp.checkpoints.io.load_by_npz('io_test_tmp.npz2', self.net, verbose=True) def test_pkl(self): - bp.base.save_as_pkl('io_test_tmp.pkl', self.net.vars()) - bp.base.load_by_pkl('io_test_tmp.pkl', self.net, verbose=True) + bp.checkpoints.io.save_as_pkl('io_test_tmp.pkl', self.net.vars()) + bp.checkpoints.io.load_by_pkl('io_test_tmp.pkl', self.net, verbose=True) - bp.base.save_as_pkl('io_test_tmp.pickle', self.net.vars()) - bp.base.load_by_pkl('io_test_tmp.pickle', self.net, verbose=True) + bp.checkpoints.io.save_as_pkl('io_test_tmp.pickle', self.net.vars()) + bp.checkpoints.io.load_by_pkl('io_test_tmp.pickle', self.net, verbose=True) def test_pkl_postfix(self): with self.assertRaises(ValueError): - bp.base.save_as_pkl('io_test_tmp.pkl2', self.net.vars()) + bp.checkpoints.io.save_as_pkl('io_test_tmp.pkl2', self.net.vars()) with self.assertRaises(ValueError): - bp.base.load_by_pkl('io_test_tmp.pkl2', self.net, verbose=True) + bp.checkpoints.io.load_by_pkl('io_test_tmp.pkl2', self.net, verbose=True) def test_mat(self): - bp.base.save_as_mat('io_test_tmp.mat', self.net.vars()) - bp.base.load_by_mat('io_test_tmp.mat', self.net, verbose=True) + bp.checkpoints.io.save_as_mat('io_test_tmp.mat', self.net.vars()) + bp.checkpoints.io.load_by_mat('io_test_tmp.mat', self.net, verbose=True) def test_mat_postfix(self): with self.assertRaises(ValueError): - bp.base.save_as_mat('io_test_tmp.mat2', self.net.vars()) + bp.checkpoints.io.save_as_mat('io_test_tmp.mat2', self.net.vars()) with self.assertRaises(ValueError): - bp.base.load_by_mat('io_test_tmp.mat2', self.net, verbose=True) + bp.checkpoints.io.load_by_mat('io_test_tmp.mat2', self.net, verbose=True) From db2a16c1020c9f270a7224ea127789b3bff1ef0c Mon Sep 17 00:00:00 2001 From: chaoming Date: Tue, 11 Apr 2023 11:03:41 +0800 Subject: [PATCH 2/3] internal changes --- brainpy/__init__.py | 2 +- brainpy/_src/math/object_transform/autograd.py | 10 +++++----- brainpy/_src/math/object_transform/controls.py | 16 ++++++++-------- brainpy/_src/math/object_transform/jit.py | 16 +++++++++++----- brainpy/math/object_base.py | 3 --- brainpy/math/object_transform.py | 13 +++++-------- docs/auto_generater.py | 8 ++++---- 7 files changed, 34 insertions(+), 34 deletions(-) diff --git a/brainpy/__init__.py b/brainpy/__init__.py index 6bb9f592f..336a2d3da 100644 --- a/brainpy/__init__.py +++ b/brainpy/__init__.py @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- -__version__ = "2.3.8" +__version__ = "2.4.0" # fundamental supporting modules diff --git a/brainpy/_src/math/object_transform/autograd.py b/brainpy/_src/math/object_transform/autograd.py index deae97500..d5db0ace8 100644 --- a/brainpy/_src/math/object_transform/autograd.py +++ b/brainpy/_src/math/object_transform/autograd.py @@ -17,11 +17,11 @@ from brainpy import tools, check from brainpy._src.math.ndarray import Array -from brainpy._src.math.object_transform.variables import Variable -from brainpy._src.math.object_transform.base import BrainPyObject, ObjectTransform -from brainpy._src.math.object_transform._tools import (dynvar_deprecation, - node_deprecation, - evaluate_dyn_vars) +from .variables import Variable +from .base import BrainPyObject, ObjectTransform +from ._tools import (dynvar_deprecation, + node_deprecation, + evaluate_dyn_vars) __all__ = [ 'grad', # gradient of scalar function diff --git a/brainpy/_src/math/object_transform/controls.py b/brainpy/_src/math/object_transform/controls.py index 3dd582464..3ddc5d1e8 100644 --- a/brainpy/_src/math/object_transform/controls.py +++ b/brainpy/_src/math/object_transform/controls.py @@ -12,14 +12,14 @@ from brainpy import errors, tools, check from brainpy._src.math.interoperability import as_jax from brainpy._src.math.ndarray import (Array, ) -from brainpy._src.math.object_transform._tools import (evaluate_dyn_vars, - dynvar_deprecation, - node_deprecation, - abstract) -from brainpy._src.math.object_transform.variables import (Variable, VariableStack) -from brainpy._src.math.object_transform.naming import (get_unique_name, - get_stack_cache, - cache_stack) +from ._tools import (evaluate_dyn_vars, + dynvar_deprecation, + node_deprecation, + abstract) +from .variables import (Variable, VariableStack) +from .naming import (get_unique_name, + get_stack_cache, + cache_stack) from ._utils import infer_dyn_vars from .base import BrainPyObject, ArrayCollector, ObjectTransform diff --git a/brainpy/_src/math/object_transform/jit.py b/brainpy/_src/math/object_transform/jit.py index b1c7a2ab8..1a77ba9f0 100644 --- a/brainpy/_src/math/object_transform/jit.py +++ b/brainpy/_src/math/object_transform/jit.py @@ -13,7 +13,7 @@ import jax from brainpy import tools, check -from brainpy._src.math.object_transform.naming import get_stack_cache, cache_stack +from .naming import get_stack_cache, cache_stack from ._tools import dynvar_deprecation, node_deprecation, evaluate_dyn_vars, abstract from .base import BrainPyObject, ObjectTransform from .variables import Variable, VariableStack @@ -151,7 +151,6 @@ def __repr__(self): ''' - def jit( func: Callable = None, @@ -169,7 +168,8 @@ def jit( dyn_vars: Optional[Union[Variable, Sequence[Variable], Dict[str, Variable]]] = None, child_objs: Optional[Union[BrainPyObject, Sequence[BrainPyObject], Dict[str, BrainPyObject]]] = None, ) -> JITTransform: - """JIT (Just-In-Time) compilation for BrainPy computation. + """ + JIT (Just-In-Time) compilation for BrainPy computation. This function has the same ability to just-in-time compile a pure function, but it can also JIT compile a :py:class:`brainpy.DynamicalSystem`, or a @@ -220,7 +220,7 @@ def jit( ------- func : JITTransform A callable jitted function, set up for just-in-time compilation. - """.format(jit_par=_jit_par.strip()) + """ dynvar_deprecation(dyn_vars) node_deprecation(child_objs) @@ -255,6 +255,9 @@ def jit( abstracted_axes=abstracted_axes) +jit.__doc__ = jit.__doc__.format(jit_par=_jit_par.strip()) + + def cls_jit( func: Callable = None, static_argnums: Union[int, Iterable[int], None] = None, @@ -297,7 +300,7 @@ def cls_jit( ------- func : JITTransform A callable jitted function, set up for just-in-time compilation. - """.format(jit_pars=_jit_par) + """ if func is None: return lambda f: _make_jit_fun(fun=f, static_argnums=static_argnums, @@ -316,6 +319,9 @@ def cls_jit( abstracted_axes=abstracted_axes) +cls_jit.__doc__ = cls_jit.__doc__.format(jit_pars=_jit_par) + + def _make_jit_fun( fun: Callable, static_argnums: Union[int, Iterable[int], None] = None, diff --git a/brainpy/math/object_base.py b/brainpy/math/object_base.py index 982ae84e1..1f19459c0 100644 --- a/brainpy/math/object_base.py +++ b/brainpy/math/object_base.py @@ -2,9 +2,6 @@ from brainpy._src.math.object_transform.base import (BrainPyObject as BrainPyObject, FunAsObject as FunAsObject) -from brainpy._src.math.object_transform.function import ( - Partial as Partial, -) from brainpy._src.math.object_transform.base import (NodeList as NodeList, NodeDict as NodeDict,) from brainpy._src.math.object_transform.variables import (Variable as Variable, diff --git a/brainpy/math/object_transform.py b/brainpy/math/object_transform.py index b64761a68..9bcf8c763 100644 --- a/brainpy/math/object_transform.py +++ b/brainpy/math/object_transform.py @@ -9,10 +9,6 @@ hessian as hessian, ) -from brainpy._src.math.object_transform.base import ( - ObjectTransform as ObjectTransform -) - from brainpy._src.math.object_transform.controls import ( make_loop as make_loop, make_while as make_while, @@ -23,13 +19,14 @@ while_loop as while_loop, ) -from brainpy._src.math.object_transform.function import ( - to_object as to_object, - function as function, -) from brainpy._src.math.object_transform.jit import ( jit as jit, cls_jit, ) + +from brainpy._src.math.object_transform.function import ( + to_object as to_object, + function as function, +) \ No newline at end of file diff --git a/docs/auto_generater.py b/docs/auto_generater.py index d2e42c634..70f297eb1 100644 --- a/docs/auto_generater.py +++ b/docs/auto_generater.py @@ -552,12 +552,12 @@ def generate_math_docs(): 'brainpy.math', 'apis/auto/math.rst', subsections={ - 'object_base': ('Basis for Object-oriented Transformations', 'brainpy.math'), + 'object_base': ('Objects and Variables', 'brainpy.math'), 'object_transform': ('Object-oriented Transformations', 'brainpy.math'), 'environment': ('Environment Settings', 'brainpy.math'), - 'compat_numpy': ('Array Operators with NumPy Syntax', 'brainpy.math'), - 'compat_pytorch': ('Array Operators with PyTorch Syntax', 'brainpy.math'), - 'compat_tensorflow': ('Array Operators with TensorFlow Syntax', 'brainpy.math'), + 'compat_numpy': ('Dense Operators with NumPy Syntax', 'brainpy.math'), + 'compat_pytorch': ('Dense Operators with PyTorch Syntax', 'brainpy.math'), + 'compat_tensorflow': ('Dense Operators with TensorFlow Syntax', 'brainpy.math'), 'interoperability': ('Array Interoperability', 'brainpy.math'), 'event_ops': ('Operators for Event-driven Computation', 'brainpy.math'), 'jitconn_ops': ('Operators for Just-In-Time Connectivity', 'brainpy.math'), From 77cbd768a2733aa946d7fdcf4ccf884cdfc707ad Mon Sep 17 00:00:00 2001 From: chaoming Date: Tue, 11 Apr 2023 12:53:11 +0800 Subject: [PATCH 3/3] support syntax with `disable_jit` --- .../_src/math/object_transform/controls.py | 48 +++++++++++-------- brainpy/_src/math/object_transform/jit.py | 5 ++ .../object_transform/tests/test_controls.py | 26 ++++++++++ .../math/object_transform/tests/test_jit.py | 47 +++++++++++++++++- 4 files changed, 104 insertions(+), 22 deletions(-) diff --git a/brainpy/_src/math/object_transform/controls.py b/brainpy/_src/math/object_transform/controls.py index 3ddc5d1e8..71f764bae 100644 --- a/brainpy/_src/math/object_transform/controls.py +++ b/brainpy/_src/math/object_transform/controls.py @@ -483,19 +483,15 @@ def cond( operands = (operands,) # dyn vars - if dyn_vars is None: - dyn_vars = evaluate_dyn_vars(true_fun, *operands) - dyn_vars += evaluate_dyn_vars(false_fun, *operands) + dynvar_deprecation(dyn_vars) + node_deprecation(child_objs) + + if jax.config.jax_disable_jit: + dyn_vars = VariableStack() else: - dynvar_deprecation(dyn_vars) - node_deprecation(child_objs) - dyn_vars = check.is_all_vars(dyn_vars, out_as='dict') - dyn_vars = ArrayCollector(dyn_vars) - dyn_vars.update(infer_dyn_vars(true_fun)) - dyn_vars.update(infer_dyn_vars(false_fun)) - for obj in check.is_all_objs(child_objs, out_as='tuple'): - dyn_vars.update(obj.vars().unique()) + dyn_vars = evaluate_dyn_vars(true_fun, *operands) + dyn_vars += evaluate_dyn_vars(false_fun, *operands) # TODO: cache mechanism? if len(dyn_vars) > 0: @@ -746,14 +742,19 @@ def for_loop( if not isinstance(operands, (list, tuple)): operands = (operands,) - # TODO: better cache mechanism? dyn_vars = get_stack_cache(body_fun) - if dyn_vars is None: - with jax.ensure_compile_time_eval(): - op_vals = jax.tree_util.tree_map(_loop_abstractify, operands) - with VariableStack() as dyn_vars: - _ = jax.eval_shape(body_fun, *op_vals) - cache_stack(body_fun, dyn_vars) # cache + if not jit: + if dyn_vars is None: + dyn_vars = VariableStack() + + else: + # TODO: better cache mechanism? + if dyn_vars is None: + with jax.ensure_compile_time_eval(): + op_vals = jax.tree_util.tree_map(_loop_abstractify, operands) + with VariableStack() as dyn_vars: + _ = jax.eval_shape(body_fun, *op_vals) + cache_stack(body_fun, dyn_vars) # cache # functions def fun2scan(carry, x): @@ -762,7 +763,8 @@ def fun2scan(carry, x): results = body_fun(*x) return dyn_vars.dict_data(), results - if remat: fun2scan = jax.checkpoint(fun2scan) + if remat: + fun2scan = jax.checkpoint(fun2scan) # TODO: cache mechanism? with jax.disable_jit(not jit): @@ -851,8 +853,12 @@ def while_loop( if not isinstance(operands, (list, tuple)): operands = (operands,) - dyn_vars = evaluate_dyn_vars(body_fun, *operands) - dyn_vars += evaluate_dyn_vars(cond_fun, *operands) + if jax.config.jax_disable_jit: + dyn_vars = VariableStack() + + else: + dyn_vars = evaluate_dyn_vars(body_fun, *operands) + dyn_vars += evaluate_dyn_vars(cond_fun, *operands) def _body_fun(op): dyn_vals, old_vals = op diff --git a/brainpy/_src/math/object_transform/jit.py b/brainpy/_src/math/object_transform/jit.py index 1a77ba9f0..9f6d85646 100644 --- a/brainpy/_src/math/object_transform/jit.py +++ b/brainpy/_src/math/object_transform/jit.py @@ -78,6 +78,9 @@ def _transform_function(self, variable_data: Dict, *args, **kwargs): return out, changes def __call__(self, *args, **kwargs): + if jax.config.jax_disable_jit: + return self.fun(*args, **kwargs) + if self._transform is None: self._dyn_vars = evaluate_dyn_vars(self.fun, *args, **kwargs) self._transform = jax.jit( @@ -334,6 +337,8 @@ def _make_jit_fun( @wraps(fun) def call_fun(self, *args, **kwargs): fun2 = partial(fun, self) + if jax.config.jax_disable_jit: + return fun2(*args, **kwargs) cache = get_stack_cache(fun2) # TODO: better cache mechanism if cache is None: with jax.ensure_compile_time_eval(): diff --git a/brainpy/_src/math/object_transform/tests/test_controls.py b/brainpy/_src/math/object_transform/tests/test_controls.py index f55d09bae..bd6c09f90 100644 --- a/brainpy/_src/math/object_transform/tests/test_controls.py +++ b/brainpy/_src/math/object_transform/tests/test_controls.py @@ -3,6 +3,7 @@ import unittest from functools import partial +import jax from absl.testing import parameterized from jax._src import test_util as jtu @@ -199,3 +200,28 @@ def body(x, y): res = bm.while_loop(body, cond, operands=(1., 1.)) print() print(res) + + def test2(self): + a = bm.Variable(bm.zeros(1)) + b = bm.Variable(bm.ones(1)) + + def cond(x, y): + return x < 6. + + def body(x, y): + a.value += x + b.value *= y + return x + b[0], y + 1. + + res = bm.while_loop(body, cond, operands=(1., 1.)) + print() + print(res) + + with jax.disable_jit(): + a = bm.Variable(bm.zeros(1)) + b = bm.Variable(bm.ones(1)) + + res2 = bm.while_loop(body, cond, operands=(1., 1.)) + self.assertTrue(bm.array_equal(res2[0], res[0])) + self.assertTrue(bm.array_equal(res2[1], res[1])) + diff --git a/brainpy/_src/math/object_transform/tests/test_jit.py b/brainpy/_src/math/object_transform/tests/test_jit.py index a8588068e..f8691c80d 100644 --- a/brainpy/_src/math/object_transform/tests/test_jit.py +++ b/brainpy/_src/math/object_transform/tests/test_jit.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- - +import jax import brainpy as bp import brainpy.math as bm @@ -23,6 +23,25 @@ def __call__(self, *args, **kwargs): b_out = bm.jit(program)() self.assertTrue(bm.array_equal(b_out, program.b)) + def test_jaxarray_inside_jit1_disable(self): + class SomeProgram(bp.BrainPyObject): + def __init__(self): + super(SomeProgram, self).__init__() + self.a = bm.zeros(2) + self.b = bm.Variable(bm.ones(2)) + + def __call__(self, *args, **kwargs): + a = bm.random.uniform(size=2) + a = a.at[0].set(1.) + self.b += a + return self.b.value + + program = SomeProgram() + with jax.disable_jit(): + b_out = bm.jit(program)() + self.assertTrue(bm.array_equal(b_out, program.b)) + print(b_out) + def test_class_jit1(self): class SomeProgram(bp.BrainPyObject): def __init__(self): @@ -47,5 +66,31 @@ def update(self, x): program.update(1.) self.assertTrue(bm.allclose(new_b + 1., program.b)) + def test_class_jit1_with_disable(self): + class SomeProgram(bp.BrainPyObject): + def __init__(self): + super(SomeProgram, self).__init__() + self.a = bm.zeros(2) + self.b = bm.Variable(bm.ones(2)) + + @bm.cls_jit + def __call__(self): + a = bm.random.uniform(size=2) + a = a.at[0].set(1.) + self.b += a + return self.b.value + + @bm.cls_jit(inline=True) + def update(self, x): + self.b += x + + program = SomeProgram() + with jax.disable_jit(): + new_b = program() + self.assertTrue(bm.allclose(new_b, program.b)) + with jax.disable_jit(): + program.update(1.) + self.assertTrue(bm.allclose(new_b + 1., program.b)) +