Skip to content

Commit

Permalink
[nnx] add cloudpickle support
Browse files Browse the repository at this point in the history
  • Loading branch information
cgarciae committed Oct 4, 2024
1 parent 2d64500 commit eb7ae72
Show file tree
Hide file tree
Showing 6 changed files with 57 additions and 18 deletions.
7 changes: 7 additions & 0 deletions flax/nnx/rnglib.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,13 @@ def __len__(self) -> int:
def __contains__(self, name: tp.Any) -> bool:
return name in vars(self)

# pickle support
def __getstate__(self):
return vars(self).copy()

def __setstate__(self, state):
vars(self).update(state)


class ForkStates(tp.NamedTuple):
split_keys: State
Expand Down
7 changes: 7 additions & 0 deletions flax/nnx/tracers.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,3 +62,10 @@ def __eq__(self, other):
return isinstance(other, TraceState) and self._jax_trace is other._jax_trace

return isinstance(other, TraceState) and self._jax_trace == other._jax_trace

# pickle support
def __getstate__(self):
return {}

def __setstate__(self, state):
self._jax_trace = current_jax_trace()
7 changes: 7 additions & 0 deletions flax/nnx/variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,6 +404,13 @@ def on_remove_axis(
def __jax_array__(self):
return self.value

# pickle support
def __getstate__(self):
return vars(self).copy()

def __setstate__(self, state):
vars(self).update(state)

# --------------------------------------------
# proxy methods
# --------------------------------------------
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ testing = [
"tensorflow>=2.12.0", # to fix Numpy np.bool8 deprecation error
"torch",
"treescope>=0.1.1; python_version>='3.10'",
"cloudpickle>=3.0.0",
]
docs = [
"sphinx>=3.3.1",
Expand Down
31 changes: 31 additions & 0 deletions tests/nnx/module_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,12 @@

from copy import deepcopy
import dataclasses
import pickle
import tempfile
from typing import TypeVar

from absl.testing import absltest
import cloudpickle
from flax import nnx, errors
import jax
import jax.numpy as jnp
Expand Down Expand Up @@ -512,6 +515,34 @@ def __init__(self, din, dout, *, rngs: nnx.Rngs):
raise_if_not_found=False,
)

def test_cloud_pickle(self):
class Model(nnx.Module):
def __init__(self, din, dmid, dout, rngs: nnx.Rngs):
self.linear = nnx.Linear(din, dmid, rngs=rngs)
self.bn = nnx.BatchNorm(dmid, rngs=rngs)
self.dropout = nnx.Dropout(0.1, rngs=rngs)
self.linear_out = nnx.Linear(dmid, dout, rngs=rngs)

def __call__(self, x):
x = nnx.relu(self.dropout(self.bn(self.linear(x))))
return self.linear_out(x)

model = Model(2, 64, 3, rngs=nnx.Rngs(0)) # eager initialization
model.eval()

y1 = model(jnp.ones((5, 2)))
with tempfile.TemporaryDirectory() as tmpdir:
path = f'{tmpdir}/model.pkl'
with open(path, 'wb') as f:
cloudpickle.dump(model, f)
del model
with open(path, 'rb') as f:
model = pickle.load(f)

self.assertIsInstance(model, Model)
y2 = model(jnp.ones((5, 2)))
np.testing.assert_allclose(y1, y2)


class TestModulePytree:
def test_tree_map(self):
Expand Down
22 changes: 4 additions & 18 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit eb7ae72

Please sign in to comment.