Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

New OO transforms support jax.disable_jit mode #359

Merged
merged 3 commits into from
Apr 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion brainpy/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-

__version__ = "2.3.8"
__version__ = "2.4.0"


# fundamental supporting modules
Expand Down
100 changes: 50 additions & 50 deletions brainpy/_src/checkpoints/tests/test_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ def __init__(self, *args, **kwargs):

rng = bm.random.RandomState()

class IO1(bp.dyn.DynamicalSystem):
class IO1(bp.DynamicalSystem):
def __init__(self):
super(IO1, self).__init__()

Expand All @@ -21,7 +21,7 @@ def __init__(self):
self.c = bm.Variable(bm.ones((3, 4)))
self.d = bm.Variable(bm.ones((2, 3, 4)))

class IO2(bp.dyn.DynamicalSystem):
class IO2(bp.DynamicalSystem):
def __init__(self):
super(IO2, self).__init__()

Expand All @@ -35,59 +35,59 @@ def __init__(self):
io2.a2 = io1.a
io2.b2 = io2.b

self.net = bp.dyn.Container(io1, io2)
self.net = bp.Container(io1, io2)

print(self.net.vars().keys())
print(self.net.vars().unique().keys())

def test_h5(self):
bp.base.save_as_h5('io_test_tmp.h5', self.net.vars())
bp.base.load_by_h5('io_test_tmp.h5', self.net, verbose=True)
bp.checkpoints.io.save_as_h5('io_test_tmp.h5', self.net.vars())
bp.checkpoints.io.load_by_h5('io_test_tmp.h5', self.net, verbose=True)

bp.base.save_as_h5('io_test_tmp.hdf5', self.net.vars())
bp.base.load_by_h5('io_test_tmp.hdf5', self.net, verbose=True)
bp.checkpoints.io.save_as_h5('io_test_tmp.hdf5', self.net.vars())
bp.checkpoints.io.load_by_h5('io_test_tmp.hdf5', self.net, verbose=True)

def test_h5_postfix(self):
with self.assertRaises(ValueError):
bp.base.save_as_h5('io_test_tmp.h52', self.net.vars())
bp.checkpoints.io.save_as_h5('io_test_tmp.h52', self.net.vars())
with self.assertRaises(ValueError):
bp.base.load_by_h5('io_test_tmp.h52', self.net, verbose=True)
bp.checkpoints.io.load_by_h5('io_test_tmp.h52', self.net, verbose=True)

def test_npz(self):
bp.base.save_as_npz('io_test_tmp.npz', self.net.vars())
bp.base.load_by_npz('io_test_tmp.npz', self.net, verbose=True)
bp.checkpoints.io.save_as_npz('io_test_tmp.npz', self.net.vars())
bp.checkpoints.io.load_by_npz('io_test_tmp.npz', self.net, verbose=True)

bp.base.save_as_npz('io_test_tmp_compressed.npz', self.net.vars(), compressed=True)
bp.base.load_by_npz('io_test_tmp_compressed.npz', self.net, verbose=True)
bp.checkpoints.io.save_as_npz('io_test_tmp_compressed.npz', self.net.vars(), compressed=True)
bp.checkpoints.io.load_by_npz('io_test_tmp_compressed.npz', self.net, verbose=True)

def test_npz_postfix(self):
with self.assertRaises(ValueError):
bp.base.save_as_npz('io_test_tmp.npz2', self.net.vars())
bp.checkpoints.io.save_as_npz('io_test_tmp.npz2', self.net.vars())
with self.assertRaises(ValueError):
bp.base.load_by_npz('io_test_tmp.npz2', self.net, verbose=True)
bp.checkpoints.io.load_by_npz('io_test_tmp.npz2', self.net, verbose=True)

def test_pkl(self):
bp.base.save_as_pkl('io_test_tmp.pkl', self.net.vars())
bp.base.load_by_pkl('io_test_tmp.pkl', self.net, verbose=True)
bp.checkpoints.io.save_as_pkl('io_test_tmp.pkl', self.net.vars())
bp.checkpoints.io.load_by_pkl('io_test_tmp.pkl', self.net, verbose=True)

bp.base.save_as_pkl('io_test_tmp.pickle', self.net.vars())
bp.base.load_by_pkl('io_test_tmp.pickle', self.net, verbose=True)
bp.checkpoints.io.save_as_pkl('io_test_tmp.pickle', self.net.vars())
bp.checkpoints.io.load_by_pkl('io_test_tmp.pickle', self.net, verbose=True)

def test_pkl_postfix(self):
with self.assertRaises(ValueError):
bp.base.save_as_pkl('io_test_tmp.pkl2', self.net.vars())
bp.checkpoints.io.save_as_pkl('io_test_tmp.pkl2', self.net.vars())
with self.assertRaises(ValueError):
bp.base.load_by_pkl('io_test_tmp.pkl2', self.net, verbose=True)
bp.checkpoints.io.load_by_pkl('io_test_tmp.pkl2', self.net, verbose=True)

def test_mat(self):
bp.base.save_as_mat('io_test_tmp.mat', self.net.vars())
bp.base.load_by_mat('io_test_tmp.mat', self.net, verbose=True)
bp.checkpoints.io.save_as_mat('io_test_tmp.mat', self.net.vars())
bp.checkpoints.io.load_by_mat('io_test_tmp.mat', self.net, verbose=True)

def test_mat_postfix(self):
with self.assertRaises(ValueError):
bp.base.save_as_mat('io_test_tmp.mat2', self.net.vars())
bp.checkpoints.io.save_as_mat('io_test_tmp.mat2', self.net.vars())
with self.assertRaises(ValueError):
bp.base.load_by_mat('io_test_tmp.mat2', self.net, verbose=True)
bp.checkpoints.io.load_by_mat('io_test_tmp.mat2', self.net, verbose=True)


class TestIO2(unittest.TestCase):
Expand All @@ -96,7 +96,7 @@ def __init__(self, *args, **kwargs):

rng = bm.random.RandomState()

class IO1(bp.dyn.DynamicalSystem):
class IO1(bp.DynamicalSystem):
def __init__(self):
super(IO1, self).__init__()

Expand All @@ -105,7 +105,7 @@ def __init__(self):
self.c = bm.Variable(bm.ones((3, 4)))
self.d = bm.Variable(bm.ones((2, 3, 4)))

class IO2(bp.dyn.DynamicalSystem):
class IO2(bp.DynamicalSystem):
def __init__(self):
super(IO2, self).__init__()

Expand All @@ -115,56 +115,56 @@ def __init__(self):
io1 = IO1()
io2 = IO2()

self.net = bp.dyn.Container(io1, io2)
self.net = bp.Container(io1, io2)

print(self.net.vars().keys())
print(self.net.vars().unique().keys())

def test_h5(self):
bp.base.save_as_h5('io_test_tmp.h5', self.net.vars())
bp.base.load_by_h5('io_test_tmp.h5', self.net, verbose=True)
bp.checkpoints.io.save_as_h5('io_test_tmp.h5', self.net.vars())
bp.checkpoints.io.load_by_h5('io_test_tmp.h5', self.net, verbose=True)

bp.base.save_as_h5('io_test_tmp.hdf5', self.net.vars())
bp.base.load_by_h5('io_test_tmp.hdf5', self.net, verbose=True)
bp.checkpoints.io.save_as_h5('io_test_tmp.hdf5', self.net.vars())
bp.checkpoints.io.load_by_h5('io_test_tmp.hdf5', self.net, verbose=True)

def test_h5_postfix(self):
with self.assertRaises(ValueError):
bp.base.save_as_h5('io_test_tmp.h52', self.net.vars())
bp.checkpoints.io.save_as_h5('io_test_tmp.h52', self.net.vars())
with self.assertRaises(ValueError):
bp.base.load_by_h5('io_test_tmp.h52', self.net, verbose=True)
bp.checkpoints.io.load_by_h5('io_test_tmp.h52', self.net, verbose=True)

def test_npz(self):
bp.base.save_as_npz('io_test_tmp.npz', self.net.vars())
bp.base.load_by_npz('io_test_tmp.npz', self.net, verbose=True)
bp.checkpoints.io.save_as_npz('io_test_tmp.npz', self.net.vars())
bp.checkpoints.io.load_by_npz('io_test_tmp.npz', self.net, verbose=True)

bp.base.save_as_npz('io_test_tmp_compressed.npz', self.net.vars(), compressed=True)
bp.base.load_by_npz('io_test_tmp_compressed.npz', self.net, verbose=True)
bp.checkpoints.io.save_as_npz('io_test_tmp_compressed.npz', self.net.vars(), compressed=True)
bp.checkpoints.io.load_by_npz('io_test_tmp_compressed.npz', self.net, verbose=True)

def test_npz_postfix(self):
with self.assertRaises(ValueError):
bp.base.save_as_npz('io_test_tmp.npz2', self.net.vars())
bp.checkpoints.io.save_as_npz('io_test_tmp.npz2', self.net.vars())
with self.assertRaises(ValueError):
bp.base.load_by_npz('io_test_tmp.npz2', self.net, verbose=True)
bp.checkpoints.io.load_by_npz('io_test_tmp.npz2', self.net, verbose=True)

def test_pkl(self):
bp.base.save_as_pkl('io_test_tmp.pkl', self.net.vars())
bp.base.load_by_pkl('io_test_tmp.pkl', self.net, verbose=True)
bp.checkpoints.io.save_as_pkl('io_test_tmp.pkl', self.net.vars())
bp.checkpoints.io.load_by_pkl('io_test_tmp.pkl', self.net, verbose=True)

bp.base.save_as_pkl('io_test_tmp.pickle', self.net.vars())
bp.base.load_by_pkl('io_test_tmp.pickle', self.net, verbose=True)
bp.checkpoints.io.save_as_pkl('io_test_tmp.pickle', self.net.vars())
bp.checkpoints.io.load_by_pkl('io_test_tmp.pickle', self.net, verbose=True)

def test_pkl_postfix(self):
with self.assertRaises(ValueError):
bp.base.save_as_pkl('io_test_tmp.pkl2', self.net.vars())
bp.checkpoints.io.save_as_pkl('io_test_tmp.pkl2', self.net.vars())
with self.assertRaises(ValueError):
bp.base.load_by_pkl('io_test_tmp.pkl2', self.net, verbose=True)
bp.checkpoints.io.load_by_pkl('io_test_tmp.pkl2', self.net, verbose=True)

def test_mat(self):
bp.base.save_as_mat('io_test_tmp.mat', self.net.vars())
bp.base.load_by_mat('io_test_tmp.mat', self.net, verbose=True)
bp.checkpoints.io.save_as_mat('io_test_tmp.mat', self.net.vars())
bp.checkpoints.io.load_by_mat('io_test_tmp.mat', self.net, verbose=True)

def test_mat_postfix(self):
with self.assertRaises(ValueError):
bp.base.save_as_mat('io_test_tmp.mat2', self.net.vars())
bp.checkpoints.io.save_as_mat('io_test_tmp.mat2', self.net.vars())
with self.assertRaises(ValueError):
bp.base.load_by_mat('io_test_tmp.mat2', self.net, verbose=True)
bp.checkpoints.io.load_by_mat('io_test_tmp.mat2', self.net, verbose=True)
10 changes: 5 additions & 5 deletions brainpy/_src/math/object_transform/autograd.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,11 @@

from brainpy import tools, check
from brainpy._src.math.ndarray import Array
from brainpy._src.math.object_transform.variables import Variable
from brainpy._src.math.object_transform.base import BrainPyObject, ObjectTransform
from brainpy._src.math.object_transform._tools import (dynvar_deprecation,
node_deprecation,
evaluate_dyn_vars)
from .variables import Variable
from .base import BrainPyObject, ObjectTransform
from ._tools import (dynvar_deprecation,
node_deprecation,
evaluate_dyn_vars)

__all__ = [
'grad', # gradient of scalar function
Expand Down
64 changes: 35 additions & 29 deletions brainpy/_src/math/object_transform/controls.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,14 @@
from brainpy import errors, tools, check
from brainpy._src.math.interoperability import as_jax
from brainpy._src.math.ndarray import (Array, )
from brainpy._src.math.object_transform._tools import (evaluate_dyn_vars,
dynvar_deprecation,
node_deprecation,
abstract)
from brainpy._src.math.object_transform.variables import (Variable, VariableStack)
from brainpy._src.math.object_transform.naming import (get_unique_name,
get_stack_cache,
cache_stack)
from ._tools import (evaluate_dyn_vars,
dynvar_deprecation,
node_deprecation,
abstract)
from .variables import (Variable, VariableStack)
from .naming import (get_unique_name,
get_stack_cache,
cache_stack)
from ._utils import infer_dyn_vars
from .base import BrainPyObject, ArrayCollector, ObjectTransform

Expand Down Expand Up @@ -483,19 +483,15 @@ def cond(
operands = (operands,)

# dyn vars
if dyn_vars is None:
dyn_vars = evaluate_dyn_vars(true_fun, *operands)
dyn_vars += evaluate_dyn_vars(false_fun, *operands)
dynvar_deprecation(dyn_vars)
node_deprecation(child_objs)

if jax.config.jax_disable_jit:
dyn_vars = VariableStack()

else:
dynvar_deprecation(dyn_vars)
node_deprecation(child_objs)
dyn_vars = check.is_all_vars(dyn_vars, out_as='dict')
dyn_vars = ArrayCollector(dyn_vars)
dyn_vars.update(infer_dyn_vars(true_fun))
dyn_vars.update(infer_dyn_vars(false_fun))
for obj in check.is_all_objs(child_objs, out_as='tuple'):
dyn_vars.update(obj.vars().unique())
dyn_vars = evaluate_dyn_vars(true_fun, *operands)
dyn_vars += evaluate_dyn_vars(false_fun, *operands)

# TODO: cache mechanism?
if len(dyn_vars) > 0:
Expand Down Expand Up @@ -746,14 +742,19 @@ def for_loop(
if not isinstance(operands, (list, tuple)):
operands = (operands,)

# TODO: better cache mechanism?
dyn_vars = get_stack_cache(body_fun)
if dyn_vars is None:
with jax.ensure_compile_time_eval():
op_vals = jax.tree_util.tree_map(_loop_abstractify, operands)
with VariableStack() as dyn_vars:
_ = jax.eval_shape(body_fun, *op_vals)
cache_stack(body_fun, dyn_vars) # cache
if not jit:
if dyn_vars is None:
dyn_vars = VariableStack()

else:
# TODO: better cache mechanism?
if dyn_vars is None:
with jax.ensure_compile_time_eval():
op_vals = jax.tree_util.tree_map(_loop_abstractify, operands)
with VariableStack() as dyn_vars:
_ = jax.eval_shape(body_fun, *op_vals)
cache_stack(body_fun, dyn_vars) # cache

# functions
def fun2scan(carry, x):
Expand All @@ -762,7 +763,8 @@ def fun2scan(carry, x):
results = body_fun(*x)
return dyn_vars.dict_data(), results

if remat: fun2scan = jax.checkpoint(fun2scan)
if remat:
fun2scan = jax.checkpoint(fun2scan)

# TODO: cache mechanism?
with jax.disable_jit(not jit):
Expand Down Expand Up @@ -851,8 +853,12 @@ def while_loop(
if not isinstance(operands, (list, tuple)):
operands = (operands,)

dyn_vars = evaluate_dyn_vars(body_fun, *operands)
dyn_vars += evaluate_dyn_vars(cond_fun, *operands)
if jax.config.jax_disable_jit:
dyn_vars = VariableStack()

else:
dyn_vars = evaluate_dyn_vars(body_fun, *operands)
dyn_vars += evaluate_dyn_vars(cond_fun, *operands)

def _body_fun(op):
dyn_vals, old_vals = op
Expand Down
Loading