Skip to content

Commit

Permalink
msgpack checkpointing support brainstate.State (#19)
Browse files Browse the repository at this point in the history
  • Loading branch information
chaoming0625 authored Nov 15, 2024
1 parent 84daae7 commit f8a1a66
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 14 deletions.
42 changes: 31 additions & 11 deletions braintools/file/msg_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,12 @@ def _restore_list(xs, state_dict: Dict[str, Any]) -> List[Any]:
return ys


register_serialization_state(list, _list_state_dict, _restore_list)
register_serialization_state(tuple,
_list_state_dict,
lambda xs, state_dict: tuple(_restore_list(list(xs), state_dict)))


def _dict_state_dict(xs: Dict[str, Any]) -> Dict[str, Any]:
str_keys = set(str(k) for k in xs.keys())
if len(str_keys) != len(xs):
Expand All @@ -297,6 +303,9 @@ def _restore_dict(xs, states: Dict[str, Any]) -> Dict[str, Any]:
for key, value in xs.items()}


register_serialization_state(dict, _dict_state_dict, _restore_dict)


def _namedtuple_state_dict(nt) -> Dict[str, Any]:
return {key: to_state_dict(getattr(nt, key)) for key in nt._fields}

Expand All @@ -320,25 +329,36 @@ def _restore_namedtuple(xs, state_dict: Dict[str, Any]):
return type(xs)(**fields)


register_serialization_state(_NamedTuple,
_namedtuple_state_dict,
_restore_namedtuple)


def _quantity_dict_state(x: u.Quantity) -> Dict[str, jax.Array]:
return {'mantissa': x.mantissa, 'scale': x.unit.scale, 'base': x.unit.base, 'dim': x.unit.dim._dims}


def _restore_quantity(x: u.Quantity, state_dict: Dict) -> u.Quantity:
x.update_mantissa(state_dict['mantissa'])
assert x.unit == u.Unit(dim=u.Dimension(state_dict['dim']), scale=state_dict['scale'], base=state_dict['base'])
return x
unit = u.Unit(dim=u.Dimension(state_dict['dim']), scale=state_dict['scale'], base=state_dict['base'])
assert x.unit == unit
return u.Quantity(state_dict['mantissa'], unit=unit)


register_serialization_state(u.Quantity, _quantity_dict_state, _restore_quantity)
register_serialization_state(dict, _dict_state_dict, _restore_dict)
register_serialization_state(list, _list_state_dict, _restore_list)
register_serialization_state(tuple,
_list_state_dict,
lambda xs, state_dict: tuple(_restore_list(list(xs), state_dict)))
register_serialization_state(_NamedTuple,
_namedtuple_state_dict,
_restore_namedtuple)


def _brainstate_dict_state(x: bst.State) -> Dict[str, Any]:
return to_state_dict(x.value)


def _restore_brainstate(x: bst.State, state_dict: Dict) -> bst.State:
x.value = from_state_dict(x.value, state_dict)
return x


register_serialization_state(bst.State, _brainstate_dict_state, _restore_brainstate)


register_serialization_state(
jax.tree_util.Partial,
lambda x: {"args": to_state_dict(x.args),
Expand Down
28 changes: 26 additions & 2 deletions braintools/file/msg_checkpoint_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@


class TestMsgCheckpoint(unittest.TestCase):
def test_msg_checkpoint(self):
def test_checkpoint_quantity(self):
data = {
"name": bst.random.rand(3) * u.ms,
}
Expand All @@ -44,4 +44,28 @@ def test_msg_checkpoint(self):
data['name'] += 1 * u.ms

data2 = bts.file.msgpack_load(filename, target=data)
self.assertEqual(data, data2)
self.assertTrue('name' in data2)
self.assertTrue(isinstance(data2['name'], u.Quantity))
self.assertTrue(not u.math.allclose(data['name'], data2['name']))

def test_checkpoint_state(self):
data = {
"a": bst.State(bst.random.rand(1)),
"b": bst.ShortTermState(bst.random.rand(2)),
"c": bst.ParamState(bst.random.rand(3)),
}

with TemporaryDirectory() as tmpdirname:
filename = tmpdirname + "/test_msg_checkpoint.msg"
bts.file.msgpack_save(filename, data)

data2 = bts.file.msgpack_load(filename, target=data)
self.assertTrue('a' in data2)
self.assertTrue('b' in data2)
self.assertTrue('c' in data2)
self.assertTrue(isinstance(data2['a'], bst.State))
self.assertTrue(isinstance(data2['b'], bst.ShortTermState))
self.assertTrue(isinstance(data2['c'], bst.ParamState))
self.assertTrue(u.math.allclose(data['a'].value, data2['a'].value))
self.assertTrue(u.math.allclose(data['b'].value, data2['b'].value))
self.assertTrue(u.math.allclose(data['c'].value, data2['c'].value))
2 changes: 1 addition & 1 deletion braintools/visualize/style.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,5 +34,5 @@ def exclude(rc: RcParams, keys: list):
plt.style.core.update_nested_dict(plt.style.library, {'notebook2': style})
plt.style.core.available[:] = sorted(plt.style.library.keys())

except (ImportError, ModuleNotFoundError):
except Exception:
scienceplots = None

0 comments on commit f8a1a66

Please sign in to comment.