From eb7ae72650be15c39e00059634864f9f5c54496e Mon Sep 17 00:00:00 2001 From: Cristian Garcia Date: Fri, 4 Oct 2024 18:07:32 +0100 Subject: [PATCH] [nnx] add cloudpickle support --- flax/nnx/rnglib.py | 7 +++++++ flax/nnx/tracers.py | 7 +++++++ flax/nnx/variables.py | 7 +++++++ pyproject.toml | 1 + tests/nnx/module_test.py | 31 +++++++++++++++++++++++++++++++ uv.lock | 22 ++++------------------ 6 files changed, 57 insertions(+), 18 deletions(-) diff --git a/flax/nnx/rnglib.py b/flax/nnx/rnglib.py index 01ad06983d..25b2eea4fb 100644 --- a/flax/nnx/rnglib.py +++ b/flax/nnx/rnglib.py @@ -230,6 +230,13 @@ def __len__(self) -> int: def __contains__(self, name: tp.Any) -> bool: return name in vars(self) + # pickle support + def __getstate__(self): + return vars(self).copy() + + def __setstate__(self, state): + vars(self).update(state) + class ForkStates(tp.NamedTuple): split_keys: State diff --git a/flax/nnx/tracers.py b/flax/nnx/tracers.py index cc78597395..ba56ba6cf4 100644 --- a/flax/nnx/tracers.py +++ b/flax/nnx/tracers.py @@ -62,3 +62,10 @@ def __eq__(self, other): return isinstance(other, TraceState) and self._jax_trace is other._jax_trace return isinstance(other, TraceState) and self._jax_trace == other._jax_trace + + # pickle support + def __getstate__(self): + return {} + + def __setstate__(self, state): + self._jax_trace = current_jax_trace() diff --git a/flax/nnx/variables.py b/flax/nnx/variables.py index ee6c8a003b..882eeb4a6c 100644 --- a/flax/nnx/variables.py +++ b/flax/nnx/variables.py @@ -404,6 +404,13 @@ def on_remove_axis( def __jax_array__(self): return self.value + # pickle support + def __getstate__(self): + return vars(self).copy() + + def __setstate__(self, state): + vars(self).update(state) + # -------------------------------------------- # proxy methods # -------------------------------------------- diff --git a/pyproject.toml b/pyproject.toml index 0b21a5c277..edcc4484f5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -62,6 +62,7 @@ testing = [ "tensorflow>=2.12.0", # to fix Numpy np.bool8 deprecation error "torch", "treescope>=0.1.1; python_version>='3.10'", + "cloudpickle>=3.0.0", ] docs = [ "sphinx>=3.3.1", diff --git a/tests/nnx/module_test.py b/tests/nnx/module_test.py index 2aff69a144..9bd583939b 100644 --- a/tests/nnx/module_test.py +++ b/tests/nnx/module_test.py @@ -14,9 +14,12 @@ from copy import deepcopy import dataclasses +import pickle +import tempfile from typing import TypeVar from absl.testing import absltest +import cloudpickle from flax import nnx, errors import jax import jax.numpy as jnp @@ -512,6 +515,34 @@ def __init__(self, din, dout, *, rngs: nnx.Rngs): raise_if_not_found=False, ) + def test_cloud_pickle(self): + class Model(nnx.Module): + def __init__(self, din, dmid, dout, rngs: nnx.Rngs): + self.linear = nnx.Linear(din, dmid, rngs=rngs) + self.bn = nnx.BatchNorm(dmid, rngs=rngs) + self.dropout = nnx.Dropout(0.1, rngs=rngs) + self.linear_out = nnx.Linear(dmid, dout, rngs=rngs) + + def __call__(self, x): + x = nnx.relu(self.dropout(self.bn(self.linear(x)))) + return self.linear_out(x) + + model = Model(2, 64, 3, rngs=nnx.Rngs(0)) # eager initialization + model.eval() + + y1 = model(jnp.ones((5, 2))) + with tempfile.TemporaryDirectory() as tmpdir: + path = f'{tmpdir}/model.pkl' + with open(path, 'wb') as f: + cloudpickle.dump(model, f) + del model + with open(path, 'rb') as f: + model = pickle.load(f) + + self.assertIsInstance(model, Model) + y2 = model(jnp.ones((5, 2))) + np.testing.assert_allclose(y1, y2) + class TestModulePytree: def test_tree_map(self): diff --git a/uv.lock b/uv.lock index 5dbc9e8070..d4e3a1ea4a 100644 --- a/uv.lock +++ b/uv.lock @@ -504,7 +504,7 @@ wheels = [ [package.optional-dependencies] toml = [ - { name = "tomli", marker = "python_full_version == '3.11'" }, + { name = "tomli", marker = "python_full_version <= '3.11'" }, ] [[package]] @@ -767,7 +767,7 @@ wheels = [ [[package]] name = "flax" -version = "0.8.6" +version = "0.9.0" source = { editable = "." } dependencies = [ { name = "jax" }, @@ -807,6 +807,7 @@ docs = [ { name = "sphinx-design" }, ] testing = [ + { name = "cloudpickle" }, { name = "clu" }, { name = "einops" }, { name = "gymnasium", extra = ["accept-rom-license", "atari"] }, @@ -831,6 +832,7 @@ testing = [ [package.metadata] requires-dist = [ + { name = "cloudpickle", marker = "extra == 'testing'", specifier = ">=3.0.0" }, { name = "clu", marker = "python_full_version < '3.10' and extra == 'testing'", specifier = "<=0.0.9" }, { name = "clu", marker = "extra == 'testing'" }, { name = "dm-haiku", marker = "extra == 'docs'" }, @@ -2007,7 +2009,6 @@ version = "12.1.3.1" source = { registry = "https://pypi.org/simple" } wheels = [ { url = "https://files.pythonhosted.org/packages/37/6d/121efd7382d5b0284239f4ab1fc1590d86d34ed4a4a2fdb13b30ca8e5740/nvidia_cublas_cu12-12.1.3.1-py3-none-manylinux1_x86_64.whl", hash = "sha256:ee53ccca76a6fc08fb9701aa95b6ceb242cdaab118c3bb152af4e579af792728", size = 410594774 }, - { url = "https://files.pythonhosted.org/packages/c5/ef/32a375b74bea706c93deea5613552f7c9104f961b21df423f5887eca713b/nvidia_cublas_cu12-12.1.3.1-py3-none-win_amd64.whl", hash = "sha256:2b964d60e8cf11b5e1073d179d85fa340c120e99b3067558f3cf98dd69d02906", size = 439918445 }, ] [[package]] @@ -2016,7 +2017,6 @@ version = "12.1.105" source = { registry = "https://pypi.org/simple" } wheels = [ { url = "https://files.pythonhosted.org/packages/7e/00/6b218edd739ecfc60524e585ba8e6b00554dd908de2c9c66c1af3e44e18d/nvidia_cuda_cupti_cu12-12.1.105-py3-none-manylinux1_x86_64.whl", hash = "sha256:e54fde3983165c624cb79254ae9818a456eb6e87a7fd4d56a2352c24ee542d7e", size = 14109015 }, - { url = "https://files.pythonhosted.org/packages/d0/56/0021e32ea2848c24242f6b56790bd0ccc8bf99f973ca790569c6ca028107/nvidia_cuda_cupti_cu12-12.1.105-py3-none-win_amd64.whl", hash = "sha256:bea8236d13a0ac7190bd2919c3e8e6ce1e402104276e6f9694479e48bb0eb2a4", size = 10154340 }, ] [[package]] @@ -2025,7 +2025,6 @@ version = "12.1.105" source = { registry = "https://pypi.org/simple" } wheels = [ { url = "https://files.pythonhosted.org/packages/b6/9f/c64c03f49d6fbc56196664d05dba14e3a561038a81a638eeb47f4d4cfd48/nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl", hash = "sha256:339b385f50c309763ca65456ec75e17bbefcbbf2893f462cb8b90584cd27a1c2", size = 23671734 }, - { url = "https://files.pythonhosted.org/packages/ad/1d/f76987c4f454eb86e0b9a0e4f57c3bf1ac1d13ad13cd1a4da4eb0e0c0ce9/nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-win_amd64.whl", hash = "sha256:0a98a522d9ff138b96c010a65e145dc1b4850e9ecb75a0172371793752fd46ed", size = 19331863 }, ] [[package]] @@ -2034,7 +2033,6 @@ version = "12.1.105" source = { registry = "https://pypi.org/simple" } wheels = [ { url = "https://files.pythonhosted.org/packages/eb/d5/c68b1d2cdfcc59e72e8a5949a37ddb22ae6cade80cd4a57a84d4c8b55472/nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl", hash = "sha256:6e258468ddf5796e25f1dc591a31029fa317d97a0a94ed93468fc86301d61e40", size = 823596 }, - { url = "https://files.pythonhosted.org/packages/9f/e2/7a2b4b5064af56ea8ea2d8b2776c0f2960d95c88716138806121ae52a9c9/nvidia_cuda_runtime_cu12-12.1.105-py3-none-win_amd64.whl", hash = "sha256:dfb46ef84d73fababab44cf03e3b83f80700d27ca300e537f85f636fac474344", size = 821226 }, ] [[package]] @@ -2046,7 +2044,6 @@ dependencies = [ ] wheels = [ { url = "https://files.pythonhosted.org/packages/9f/fd/713452cd72343f682b1c7b9321e23829f00b842ceaedcda96e742ea0b0b3/nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl", hash = "sha256:165764f44ef8c61fcdfdfdbe769d687e06374059fbb388b6c89ecb0e28793a6f", size = 664752741 }, - { url = "https://files.pythonhosted.org/packages/3f/d0/f90ee6956a628f9f04bf467932c0a25e5a7e706a684b896593c06c82f460/nvidia_cudnn_cu12-9.1.0.70-py3-none-win_amd64.whl", hash = "sha256:6278562929433d68365a07a4a1546c237ba2849852c0d4b2262a486e805b977a", size = 679925892 }, ] [[package]] @@ -2055,7 +2052,6 @@ version = "11.0.2.54" source = { registry = "https://pypi.org/simple" } wheels = [ { url = "https://files.pythonhosted.org/packages/86/94/eb540db023ce1d162e7bea9f8f5aa781d57c65aed513c33ee9a5123ead4d/nvidia_cufft_cu12-11.0.2.54-py3-none-manylinux1_x86_64.whl", hash = "sha256:794e3948a1aa71fd817c3775866943936774d1c14e7628c74f6f7417224cdf56", size = 121635161 }, - { url = "https://files.pythonhosted.org/packages/f7/57/7927a3aa0e19927dfed30256d1c854caf991655d847a4e7c01fe87e3d4ac/nvidia_cufft_cu12-11.0.2.54-py3-none-win_amd64.whl", hash = "sha256:d9ac353f78ff89951da4af698f80870b1534ed69993f10a4cf1d96f21357e253", size = 121344196 }, ] [[package]] @@ -2064,7 +2060,6 @@ version = "10.3.2.106" source = { registry = "https://pypi.org/simple" } wheels = [ { url = "https://files.pythonhosted.org/packages/44/31/4890b1c9abc496303412947fc7dcea3d14861720642b49e8ceed89636705/nvidia_curand_cu12-10.3.2.106-py3-none-manylinux1_x86_64.whl", hash = "sha256:9d264c5036dde4e64f1de8c50ae753237c12e0b1348738169cd0f8a536c0e1e0", size = 56467784 }, - { url = "https://files.pythonhosted.org/packages/5c/97/4c9c7c79efcdf5b70374241d48cf03b94ef6707fd18ea0c0f53684931d0b/nvidia_curand_cu12-10.3.2.106-py3-none-win_amd64.whl", hash = "sha256:75b6b0c574c0037839121317e17fd01f8a69fd2ef8e25853d826fec30bdba74a", size = 55995813 }, ] [[package]] @@ -2078,7 +2073,6 @@ dependencies = [ ] wheels = [ { url = "https://files.pythonhosted.org/packages/bc/1d/8de1e5c67099015c834315e333911273a8c6aaba78923dd1d1e25fc5f217/nvidia_cusolver_cu12-11.4.5.107-py3-none-manylinux1_x86_64.whl", hash = "sha256:8a7ec542f0412294b15072fa7dab71d31334014a69f953004ea7a118206fe0dd", size = 124161928 }, - { url = "https://files.pythonhosted.org/packages/b8/80/8fca0bf819122a631c3976b6fc517c1b10741b643b94046bd8dd451522c5/nvidia_cusolver_cu12-11.4.5.107-py3-none-win_amd64.whl", hash = "sha256:74e0c3a24c78612192a74fcd90dd117f1cf21dea4822e66d89e8ea80e3cd2da5", size = 121643081 }, ] [[package]] @@ -2090,7 +2084,6 @@ dependencies = [ ] wheels = [ { url = "https://files.pythonhosted.org/packages/65/5b/cfaeebf25cd9fdec14338ccb16f6b2c4c7fa9163aefcf057d86b9cc248bb/nvidia_cusparse_cu12-12.1.0.106-py3-none-manylinux1_x86_64.whl", hash = "sha256:f3b50f42cf363f86ab21f720998517a659a48131e8d538dc02f8768237bd884c", size = 195958278 }, - { url = "https://files.pythonhosted.org/packages/0f/95/48fdbba24c93614d1ecd35bc6bdc6087bd17cbacc3abc4b05a9c2a1ca232/nvidia_cusparse_cu12-12.1.0.106-py3-none-win_amd64.whl", hash = "sha256:b798237e81b9719373e8fae8d4f091b70a0cf09d9d85c95a557e11df2d8e9a5a", size = 195414588 }, ] [[package]] @@ -2109,7 +2102,6 @@ source = { registry = "https://pypi.org/simple" } wheels = [ { url = "https://files.pythonhosted.org/packages/81/b3/e456a1b2d499bb84bdc6670bfbcf41ff3bac58bd2fae6880d62834641558/nvidia_nvjitlink_cu12-12.6.20-py3-none-manylinux2014_aarch64.whl", hash = "sha256:84fb38465a5bc7c70cbc320cfd0963eb302ee25a5e939e9f512bbba55b6072fb", size = 19252608 }, { url = "https://files.pythonhosted.org/packages/59/65/7ff0569494fbaea45ad2814972cc88da843d53cc96eb8554fcd0908941d9/nvidia_nvjitlink_cu12-12.6.20-py3-none-manylinux2014_x86_64.whl", hash = "sha256:562ab97ea2c23164823b2a89cb328d01d45cb99634b8c65fe7cd60d14562bd79", size = 19724950 }, - { url = "https://files.pythonhosted.org/packages/cb/ef/8f96c82e1cfcf6d5b770f7b043c3cc24841fc247b37629a7cc643dbf72a1/nvidia_nvjitlink_cu12-12.6.20-py3-none-win_amd64.whl", hash = "sha256:ed3c43a17f37b0c922a919203d2d36cbef24d41cc3e6b625182f8b58203644f6", size = 162012830 }, ] [[package]] @@ -2118,7 +2110,6 @@ version = "12.1.105" source = { registry = "https://pypi.org/simple" } wheels = [ { url = "https://files.pythonhosted.org/packages/da/d3/8057f0587683ed2fcd4dbfbdfdfa807b9160b809976099d36b8f60d08f03/nvidia_nvtx_cu12-12.1.105-py3-none-manylinux1_x86_64.whl", hash = "sha256:dc21cf308ca5691e7c04d962e213f8a4aa9bbfa23d95412f452254c2caeb09e5", size = 99138 }, - { url = "https://files.pythonhosted.org/packages/b8/d7/bd7cb2d95ac6ac6e8d05bfa96cdce69619f1ef2808e072919044c2d47a8c/nvidia_nvtx_cu12-12.1.105-py3-none-win_amd64.whl", hash = "sha256:65f4d98982b31b60026e0e6de73fbdfc09d08a96f4656dd3665ca616a11e1e82", size = 66307 }, ] [[package]] @@ -3408,9 +3399,7 @@ dependencies = [ { name = "tensorflow", marker = "platform_system != 'Darwin'" }, ] wheels = [ - { url = "https://files.pythonhosted.org/packages/a2/e3/33fc5957790cf4710e0a9116cf37c0a881eda673e5f8b569bfff5654a48c/tensorflow_text-2.17.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:8eba0b5804235519b571c827c97337c332de270107f06af6d2171cdefdc4c6a0", size = 6109587 }, { url = "https://files.pythonhosted.org/packages/61/59/2090318555d98dc9dc868b3c585ada2e1139be538d954340726aa3d3899a/tensorflow_text-2.17.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:89f04c3f478f1885ad4c7380643a768a72a3de79e1f8f40d50b48cc1fbf73893", size = 5205819 }, - { url = "https://files.pythonhosted.org/packages/92/65/e2d3d9300173a0927e8b7e3cf9a35f9539e9269786c1e1d9d945223fe21a/tensorflow_text-2.17.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:a9b9f9c8a06878714a14f4e086fa8122beb2e141f82d0aa5a8f6b8f9b694db51", size = 6109684 }, { url = "https://files.pythonhosted.org/packages/de/32/182ecf4eb1432942876d9b0b089625564084c5ed4d03c02ddf2872177e95/tensorflow_text-2.17.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:161c09380b090774ed721cdcce973194458708250d7dfbac7cb9ea8a3e9ac762", size = 5205866 }, ] @@ -3587,9 +3576,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/45/27/14cc3101409b9b4b9241d2ba7deaa93535a217a211c86c4cc7151fb12181/triton-3.0.0-1-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:e1efef76935b2febc365bfadf74bcb65a6f959a9872e5bddf44cc9e0adce1e1a", size = 209376304 }, { url = "https://files.pythonhosted.org/packages/33/3e/a2f59384587eff6aeb7d37b6780de7fedd2214935e27520430ca9f5b7975/triton-3.0.0-1-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:5ce8520437c602fb633f1324cc3871c47bee3b67acf9756c1a66309b60e3216c", size = 209438883 }, { url = "https://files.pythonhosted.org/packages/fe/7b/7757205dee3628f75e7991021d15cd1bd0c9b044ca9affe99b50879fc0e1/triton-3.0.0-1-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:34e509deb77f1c067d8640725ef00c5cbfcb2052a1a3cb6a6d343841f92624eb", size = 209464695 }, - { url = "https://files.pythonhosted.org/packages/15/67/84e5a4b7b45bdeb11da26a67dfa2b988c512abbcbcad8cbc30aa579051b2/triton-3.0.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:39b052da883351fdf6be3d93cedae6db3b8e3988d3b09ed221bccecfa9612230", size = 209380247 }, - { url = "https://files.pythonhosted.org/packages/ea/6b/1d72cc8a7379822dadf050474add7d8b73b02c35057446b6f17d27cb9ea2/triton-3.0.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cd34f19a8582af96e6291d4afce25dac08cb2a5d218c599163761e8e0827208e", size = 209442823 }, - { url = "https://files.pythonhosted.org/packages/ae/b2/048c9ecfdba0e6b0ae3c02eed2d9dd3e9e990a6d46da98555cf0c2232168/triton-3.0.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0d5e10de8c011adeb7c878c6ce0dd6073b14367749e34467f1cff2bde1b78253", size = 209468633 }, ] [[package]]