From c55cd4486069bf44e904c0ef344a791cb571fdba Mon Sep 17 00:00:00 2001 From: Cristian Garcia Date: Fri, 4 Oct 2024 18:07:32 +0100 Subject: [PATCH] [nnx] add cloudpickle support --- flax/nnx/rnglib.py | 7 +++++ flax/nnx/tracers.py | 7 +++++ flax/nnx/variables.py | 7 +++++ pyproject.toml | 1 + tests/nnx/module_test.py | 31 +++++++++++++++++++ uv.lock | 65 +++++++++++++++++----------------------- 6 files changed, 81 insertions(+), 37 deletions(-) diff --git a/flax/nnx/rnglib.py b/flax/nnx/rnglib.py index 01ad06983d..25b2eea4fb 100644 --- a/flax/nnx/rnglib.py +++ b/flax/nnx/rnglib.py @@ -230,6 +230,13 @@ def __len__(self) -> int: def __contains__(self, name: tp.Any) -> bool: return name in vars(self) + # pickle support + def __getstate__(self): + return vars(self).copy() + + def __setstate__(self, state): + vars(self).update(state) + class ForkStates(tp.NamedTuple): split_keys: State diff --git a/flax/nnx/tracers.py b/flax/nnx/tracers.py index cc78597395..ba56ba6cf4 100644 --- a/flax/nnx/tracers.py +++ b/flax/nnx/tracers.py @@ -62,3 +62,10 @@ def __eq__(self, other): return isinstance(other, TraceState) and self._jax_trace is other._jax_trace return isinstance(other, TraceState) and self._jax_trace == other._jax_trace + + # pickle support + def __getstate__(self): + return {} + + def __setstate__(self, state): + self._jax_trace = current_jax_trace() diff --git a/flax/nnx/variables.py b/flax/nnx/variables.py index ee6c8a003b..882eeb4a6c 100644 --- a/flax/nnx/variables.py +++ b/flax/nnx/variables.py @@ -404,6 +404,13 @@ def on_remove_axis( def __jax_array__(self): return self.value + # pickle support + def __getstate__(self): + return vars(self).copy() + + def __setstate__(self, state): + vars(self).update(state) + # -------------------------------------------- # proxy methods # -------------------------------------------- diff --git a/pyproject.toml b/pyproject.toml index 0b21a5c277..edcc4484f5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -62,6 +62,7 @@ testing = [ "tensorflow>=2.12.0", # to fix Numpy np.bool8 deprecation error "torch", "treescope>=0.1.1; python_version>='3.10'", + "cloudpickle>=3.0.0", ] docs = [ "sphinx>=3.3.1", diff --git a/tests/nnx/module_test.py b/tests/nnx/module_test.py index 2aff69a144..9bd583939b 100644 --- a/tests/nnx/module_test.py +++ b/tests/nnx/module_test.py @@ -14,9 +14,12 @@ from copy import deepcopy import dataclasses +import pickle +import tempfile from typing import TypeVar from absl.testing import absltest +import cloudpickle from flax import nnx, errors import jax import jax.numpy as jnp @@ -512,6 +515,34 @@ def __init__(self, din, dout, *, rngs: nnx.Rngs): raise_if_not_found=False, ) + def test_cloud_pickle(self): + class Model(nnx.Module): + def __init__(self, din, dmid, dout, rngs: nnx.Rngs): + self.linear = nnx.Linear(din, dmid, rngs=rngs) + self.bn = nnx.BatchNorm(dmid, rngs=rngs) + self.dropout = nnx.Dropout(0.1, rngs=rngs) + self.linear_out = nnx.Linear(dmid, dout, rngs=rngs) + + def __call__(self, x): + x = nnx.relu(self.dropout(self.bn(self.linear(x)))) + return self.linear_out(x) + + model = Model(2, 64, 3, rngs=nnx.Rngs(0)) # eager initialization + model.eval() + + y1 = model(jnp.ones((5, 2))) + with tempfile.TemporaryDirectory() as tmpdir: + path = f'{tmpdir}/model.pkl' + with open(path, 'wb') as f: + cloudpickle.dump(model, f) + del model + with open(path, 'rb') as f: + model = pickle.load(f) + + self.assertIsInstance(model, Model) + y2 = model(jnp.ones((5, 2))) + np.testing.assert_allclose(y1, y2) + class TestModulePytree: def test_tree_map(self): diff --git a/uv.lock b/uv.lock index 5dbc9e8070..bbd346cb34 100644 --- a/uv.lock +++ b/uv.lock @@ -504,7 +504,7 @@ wheels = [ [package.optional-dependencies] toml = [ - { name = "tomli", marker = "python_full_version == '3.11'" }, + { name = "tomli", marker = "python_full_version <= '3.11'" }, ] [[package]] @@ -767,7 +767,7 @@ wheels = [ [[package]] name = "flax" -version = "0.8.6" +version = "0.9.0" source = { editable = "." } dependencies = [ { name = "jax" }, @@ -807,6 +807,7 @@ docs = [ { name = "sphinx-design" }, ] testing = [ + { name = "cloudpickle" }, { name = "clu" }, { name = "einops" }, { name = "gymnasium", extra = ["accept-rom-license", "atari"] }, @@ -831,6 +832,7 @@ testing = [ [package.metadata] requires-dist = [ + { name = "cloudpickle", marker = "extra == 'testing'", specifier = ">=3.0.0" }, { name = "clu", marker = "python_full_version < '3.10' and extra == 'testing'", specifier = "<=0.0.9" }, { name = "clu", marker = "extra == 'testing'" }, { name = "dm-haiku", marker = "extra == 'docs'" }, @@ -1217,7 +1219,7 @@ wheels = [ [[package]] name = "jax" -version = "0.4.31" +version = "0.4.34" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "jaxlib" }, @@ -1226,14 +1228,14 @@ dependencies = [ { name = "opt-einsum" }, { name = "scipy" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/73/e4/c1a4c0e7dafbc53fff9f42f9c1bf5918dabd1f91325512d6b382bea8750b/jax-0.4.31.tar.gz", hash = "sha256:fd2d470643a0073d822737f0788f71391656af7e62cc5b2e7995ee390ceac287", size = 1743359 } +sdist = { url = "https://files.pythonhosted.org/packages/19/6a/cacfcdf77841a4562e555ef35e0dbc5f8ca79c9f1010aaa4cf3973e79c69/jax-0.4.34.tar.gz", hash = "sha256:44196854f40c5f9cea3142824b9f1051f85afc3fcf7593ec5479fc8db01c58db", size = 1848472 } wheels = [ - { url = "https://files.pythonhosted.org/packages/7e/cf/5f51b43bd692e90585c0ef6e8d1b0db5d254fe0224a6570daa59a1be014f/jax-0.4.31-py3-none-any.whl", hash = "sha256:5688703735133d0dc537e99a1d646198a49c9d472d4715fde4bd437c44151bd7", size = 2038969 }, + { url = "https://files.pythonhosted.org/packages/06/f3/c499d358dd7f267a63d7d38ef54aadad82e28d2c28bafff15360c3091946/jax-0.4.34-py3-none-any.whl", hash = "sha256:b957ca1fc91f7343f91a186af9f19c7f342c946f95a8c11c7f1e5cdfe2e58d9e", size = 2144294 }, ] [[package]] name = "jaxlib" -version = "0.4.31" +version = "0.4.34" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "ml-dtypes" }, @@ -1241,21 +1243,26 @@ dependencies = [ { name = "scipy" }, ] wheels = [ - { url = "https://files.pythonhosted.org/packages/b7/72/12c267f6775aac7e3ca6ed882c9816883cce44d73169d25d0e0b0f1f6972/jaxlib-0.4.31-cp310-cp310-macosx_10_14_x86_64.whl", hash = "sha256:48ea73cb78341bd4aabbb15e1a076ed61505ec80ab8eb4810e2d34758c400f80", size = 88767265 }, - { url = "https://files.pythonhosted.org/packages/b2/c9/0a6a964a852b66cff6108b8d8bc17115b69171fa6a22a916bc911d9f0a61/jaxlib-0.4.31-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:bacb86012f9104dd71706266420fd1e5d179d826d0635c95fe31506d605b4537", size = 70040016 }, - { url = "https://files.pythonhosted.org/packages/ae/4d/71e6286f88bf2c516e8af26a4245b8a68b12fcf1bbb42a4b3b7958575407/jaxlib-0.4.31-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:d019023f71dba65127a3016ddc755de4b30f5bc9bd5b632a716a5fb3b00c5e53", size = 73050144 }, - { url = "https://files.pythonhosted.org/packages/cd/d7/918ac5477d1c32329c43bc2eb40473baa1c244851c825904430e8911f15a/jaxlib-0.4.31-cp310-cp310-manylinux2014_x86_64.whl", hash = "sha256:1b8e9e6970ecc08bd8b4d80c03d882f4dcd4ac119cb2959811ebc58fce1c263d", size = 88131641 }, - { url = "https://files.pythonhosted.org/packages/ed/ea/2ba944ba4365cf8f043ff34cdb9704e29a37478b75592d03672fbba4d0df/jaxlib-0.4.31-cp310-cp310-win_amd64.whl", hash = "sha256:d3540a557c188d23ef93760da482b158ca910124a0445263c3b17c09c114538a", size = 56281724 }, - { url = "https://files.pythonhosted.org/packages/46/d0/100199575992545940afc17e62dea5a79c15ef738c1ae9784a1838962aa4/jaxlib-0.4.31-cp311-cp311-macosx_10_14_x86_64.whl", hash = "sha256:1fd838ff91ea58ec2bdc7b4ecbb921ad501a318fafdeae120e6e7f88f5c20b17", size = 88768971 }, - { url = "https://files.pythonhosted.org/packages/18/ea/eddfae920bf689314aa0302a4c841cfac01b6cfd77f60f1a3f3dd355fddc/jaxlib-0.4.31-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:86340df8b37729f6fc5742f17761857bb9e59c418c9453e9b090f49f6194cdf9", size = 70038216 }, - { url = "https://files.pythonhosted.org/packages/a6/ce/ce7d3ba4790e18f67cfcb4552056dd04350085116f4754333f481516d97c/jaxlib-0.4.31-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:2d2639d210b0b1918dfaabbcc504fc668326e1a6fd1f0eb427c40b039188bbce", size = 73050770 }, - { url = "https://files.pythonhosted.org/packages/32/33/6d30bf3ec7d590a8dc0f1e30ea4c531b6f6a33116eb2066e354b485066de/jaxlib-0.4.31-cp311-cp311-manylinux2014_x86_64.whl", hash = "sha256:1db6f8ea35b884f9e7761b006ee9c60ed05be6c75d2e527551f74579cbe11677", size = 88130221 }, - { url = "https://files.pythonhosted.org/packages/9f/e2/5b7d20ed550d156311587eee6e44c48971fe6c3b43f39e82dacda3875396/jaxlib-0.4.31-cp311-cp311-win_amd64.whl", hash = "sha256:ceec494df08aaf65b8bbcbd40dd21a6579fa76ca5b851cce46fd7ce0388c0449", size = 56279795 }, - { url = "https://files.pythonhosted.org/packages/fa/27/3eee15d1b168d434498c388780114d7629f715e19c2d08754ab4be82ad2d/jaxlib-0.4.31-cp312-cp312-macosx_10_14_x86_64.whl", hash = "sha256:185fb615ab6bd95315fbcbd951d84e71f9835d603db8c03c91faee98ce95ff4d", size = 88818529 }, - { url = "https://files.pythonhosted.org/packages/68/cf/28895a4a89d88d18592507d7a35218b6bb2d8bced13615065c9f925f2ae1/jaxlib-0.4.31-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:c9f89c185287e40ee8173a7142d6495311e772cd139a93dca93f0d99c1872832", size = 70079551 }, - { url = "https://files.pythonhosted.org/packages/e0/af/10b49f8de2acc7abc871478823579d7241be52ca0d6bb0d2b2c476cc1b68/jaxlib-0.4.31-cp312-cp312-manylinux2014_aarch64.whl", hash = "sha256:4d867a1a0565b31cfdaabbec81e0302c6461bb2ac4b92c04670328d795819803", size = 73053401 }, - { url = "https://files.pythonhosted.org/packages/b1/09/58d35465d48c8bee1d9a4e7a3c5db2edaabfc7ac94f4576c9f8c51b83e70/jaxlib-0.4.31-cp312-cp312-manylinux2014_x86_64.whl", hash = "sha256:1f1afa5fd58a60f67f0ca586e26714aece62eaa2c8334c24d0e8285afc4a7ccd", size = 88162291 }, - { url = "https://files.pythonhosted.org/packages/c8/13/1bb2bcb4d9f4719dd5f3d98f5c2fc2235f961ced576366b040372eebdb17/jaxlib-0.4.31-cp312-cp312-win_amd64.whl", hash = "sha256:c4bfd15315e30525514b7262d555bea00745b09ac9818bb14c20ef8afbbab072", size = 56299104 }, + { url = "https://files.pythonhosted.org/packages/24/31/2e254fe2fc23201775a7d0ccd1bcde892cfa349eb805744b81b15e0dcf74/jaxlib-0.4.34-cp310-cp310-macosx_10_14_x86_64.whl", hash = "sha256:b7a212a3cb5c6acc201c32ae4f4b5f5a9ac09457fbb77ba8db5ce7e7d4adc214", size = 87399257 }, + { url = "https://files.pythonhosted.org/packages/1e/67/6a344c357caad33e84b871925cd043b4218fc13a427266d1a1dedcb1c095/jaxlib-0.4.34-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:45d719a2ce0ebf21255a277b71d756f3609b7b5be70cddc5d88fd58c35219de0", size = 67617952 }, + { url = "https://files.pythonhosted.org/packages/dd/ea/12c836126419ca80248228f2236831617eedb1e3640c34c942606f33bb08/jaxlib-0.4.34-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:3e60bc826933082e99b19b87c21818a8d26fcdb01f418d47cedff554746fd6cc", size = 69391770 }, + { url = "https://files.pythonhosted.org/packages/e4/b0/a5bd34643c070e50829beec217189eab1acdfea334df1f9ddb4e5f8bec0f/jaxlib-0.4.34-cp310-cp310-manylinux2014_x86_64.whl", hash = "sha256:d840e64b85f8865404d6d225b9bb340e158df1457152a361b05680e24792b232", size = 86094116 }, + { url = "https://files.pythonhosted.org/packages/d8/c9/35a4233fe74ddd5aabe89aac1b3992b0e463982564252d21fd263d4d9992/jaxlib-0.4.34-cp310-cp310-win_amd64.whl", hash = "sha256:b0001c8f0e2b1c7bc99e4f314b524a340d25653505c1a1484d4041a9d3617f6f", size = 55206389 }, + { url = "https://files.pythonhosted.org/packages/bf/14/00a3385532d72ab51bd8e9f8c3e19a2e257667955565e9fc10236771dd06/jaxlib-0.4.34-cp311-cp311-macosx_10_14_x86_64.whl", hash = "sha256:8ee3f93836e53c86556ccd9449a4ea43516ee05184d031a71dd692e81259f7d9", size = 87420889 }, + { url = "https://files.pythonhosted.org/packages/66/78/d1535ee73fe505dc6c8831c19c4846afdce7df5acefb9f8ee885aa73d700/jaxlib-0.4.34-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:c9d3adcae43a33aad4332be9c2aedc5ef751d1e755f917a5afb30c7872eacaa8", size = 67635880 }, + { url = "https://files.pythonhosted.org/packages/aa/06/3e09e794acf308e170905d732eca0d041449503c47505cc22e8ef78a989d/jaxlib-0.4.34-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:571ef03259835458111596a71a2f4a6fabf4ec34595df4cea555035362ac5bf0", size = 69421901 }, + { url = "https://files.pythonhosted.org/packages/c7/d0/6bc81c0b1d507f403e6085ce76a429e6d7f94749d742199252e299dd1424/jaxlib-0.4.34-cp311-cp311-manylinux2014_x86_64.whl", hash = "sha256:3bcfa639ca3cfaf86c8ceebd5fc0d47300fd98a078014a1d0cc03133e1523d5f", size = 86114491 }, + { url = "https://files.pythonhosted.org/packages/9d/5d/7e71019af5f6fdebe6c10eab97d01f44b931d94609330da9e142cb155f8c/jaxlib-0.4.34-cp311-cp311-win_amd64.whl", hash = "sha256:133070d4fec5525ffea4dc72956398c1cf647a04dcb37f8a935ee82af78d9965", size = 55241262 }, + { url = "https://files.pythonhosted.org/packages/bc/42/5038983664494dfb50f8669a662d965d7ea62f9250e40d8cd36dcf9ac3dd/jaxlib-0.4.34-cp312-cp312-macosx_10_14_x86_64.whl", hash = "sha256:c7b3e724a30426a856070aba0192b5d199e95b4411070e7ad96ad8b196877b10", size = 87473956 }, + { url = "https://files.pythonhosted.org/packages/87/2e/8a75d3107c019c370c50c01acc205da33f9d6fba830950401a772a8e9f6d/jaxlib-0.4.34-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:096f0ca309d41fa692a9d1f2f9baab1c5c8ca0749876ebb3f748e738a27c7ff4", size = 67650276 }, + { url = "https://files.pythonhosted.org/packages/af/09/cceae2d251a506b4297679d10ee9f5e905a6b992b0687d553c9470ffd1db/jaxlib-0.4.34-cp312-cp312-manylinux2014_aarch64.whl", hash = "sha256:1a30771d85fa77f9ab8f18e63240f455ab3a3f87660ed7b8d5eea6ceecbe5c1e", size = 69431284 }, + { url = "https://files.pythonhosted.org/packages/e7/0d/4faf839e3c8ce2a5b615df64427be3e870899c72c0ebfb5859348150aba1/jaxlib-0.4.34-cp312-cp312-manylinux2014_x86_64.whl", hash = "sha256:48272e9034ff868d4328cf0055a07882fd2be93f59dfb6283af7de491f9d1290", size = 86151183 }, + { url = "https://files.pythonhosted.org/packages/a4/bc/a38f99071fca6cc31ae949e508a23b0de5de559da594443bb625a1adb8f3/jaxlib-0.4.34-cp312-cp312-win_amd64.whl", hash = "sha256:901cb4040ed24eae40071d8114ea8d10dff436277fa74a1a5b9e7206f641151c", size = 55278745 }, + { url = "https://files.pythonhosted.org/packages/21/4e/fab0606683af7aa9284a32d2b188ff132cffb0ee3ea04d941a547eb776d1/jaxlib-0.4.34-cp313-cp313-macosx_10_14_x86_64.whl", hash = "sha256:72e22e99a5dc890a64443c3fc12f13f20091f578c405a76de077ba42b4c62cd7", size = 87474367 }, + { url = "https://files.pythonhosted.org/packages/3e/1b/709be16d543a3db5b471ee5e7d089c57484c386b08499923e43bd8da5d0b/jaxlib-0.4.34-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:c303f5acaf6c56ce5ff133a923c9b6247bdebedde15bd2c893c24be4d8f71306", size = 67651281 }, + { url = "https://files.pythonhosted.org/packages/85/9e/f3801096cd4a2c764af7a1f6b683c769706602ea72b27ec35bacfcc4cd4f/jaxlib-0.4.34-cp313-cp313-manylinux2014_aarch64.whl", hash = "sha256:7be673a876ebd1aef440fb7e3ebaf99a91abeb550c9728c644b7d7c7b5d7c108", size = 69432987 }, + { url = "https://files.pythonhosted.org/packages/e6/79/61301f55b24c3a898ef9bc4e13600b66e3f838623fc6f87648ac1ccbca01/jaxlib-0.4.34-cp313-cp313-manylinux2014_x86_64.whl", hash = "sha256:87f25a477cd279840e53718403f97092eba0e8a945fcab47bcf435b6f9119dda", size = 86152550 }, + { url = "https://files.pythonhosted.org/packages/16/b0/e682d02126e0062b58dec0f0851048592396f74c24b4a4412dce4ddbbadb/jaxlib-0.4.34-cp313-cp313-win_amd64.whl", hash = "sha256:6b43a974c5d91a19912d138f2658dd8dbb7d30dcdff5c961d896c673e872b611", size = 55279410 }, ] [[package]] @@ -2007,7 +2014,6 @@ version = "12.1.3.1" source = { registry = "https://pypi.org/simple" } wheels = [ { url = "https://files.pythonhosted.org/packages/37/6d/121efd7382d5b0284239f4ab1fc1590d86d34ed4a4a2fdb13b30ca8e5740/nvidia_cublas_cu12-12.1.3.1-py3-none-manylinux1_x86_64.whl", hash = "sha256:ee53ccca76a6fc08fb9701aa95b6ceb242cdaab118c3bb152af4e579af792728", size = 410594774 }, - { url = "https://files.pythonhosted.org/packages/c5/ef/32a375b74bea706c93deea5613552f7c9104f961b21df423f5887eca713b/nvidia_cublas_cu12-12.1.3.1-py3-none-win_amd64.whl", hash = "sha256:2b964d60e8cf11b5e1073d179d85fa340c120e99b3067558f3cf98dd69d02906", size = 439918445 }, ] [[package]] @@ -2016,7 +2022,6 @@ version = "12.1.105" source = { registry = "https://pypi.org/simple" } wheels = [ { url = "https://files.pythonhosted.org/packages/7e/00/6b218edd739ecfc60524e585ba8e6b00554dd908de2c9c66c1af3e44e18d/nvidia_cuda_cupti_cu12-12.1.105-py3-none-manylinux1_x86_64.whl", hash = "sha256:e54fde3983165c624cb79254ae9818a456eb6e87a7fd4d56a2352c24ee542d7e", size = 14109015 }, - { url = "https://files.pythonhosted.org/packages/d0/56/0021e32ea2848c24242f6b56790bd0ccc8bf99f973ca790569c6ca028107/nvidia_cuda_cupti_cu12-12.1.105-py3-none-win_amd64.whl", hash = "sha256:bea8236d13a0ac7190bd2919c3e8e6ce1e402104276e6f9694479e48bb0eb2a4", size = 10154340 }, ] [[package]] @@ -2025,7 +2030,6 @@ version = "12.1.105" source = { registry = "https://pypi.org/simple" } wheels = [ { url = "https://files.pythonhosted.org/packages/b6/9f/c64c03f49d6fbc56196664d05dba14e3a561038a81a638eeb47f4d4cfd48/nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl", hash = "sha256:339b385f50c309763ca65456ec75e17bbefcbbf2893f462cb8b90584cd27a1c2", size = 23671734 }, - { url = "https://files.pythonhosted.org/packages/ad/1d/f76987c4f454eb86e0b9a0e4f57c3bf1ac1d13ad13cd1a4da4eb0e0c0ce9/nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-win_amd64.whl", hash = "sha256:0a98a522d9ff138b96c010a65e145dc1b4850e9ecb75a0172371793752fd46ed", size = 19331863 }, ] [[package]] @@ -2034,7 +2038,6 @@ version = "12.1.105" source = { registry = "https://pypi.org/simple" } wheels = [ { url = "https://files.pythonhosted.org/packages/eb/d5/c68b1d2cdfcc59e72e8a5949a37ddb22ae6cade80cd4a57a84d4c8b55472/nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl", hash = "sha256:6e258468ddf5796e25f1dc591a31029fa317d97a0a94ed93468fc86301d61e40", size = 823596 }, - { url = "https://files.pythonhosted.org/packages/9f/e2/7a2b4b5064af56ea8ea2d8b2776c0f2960d95c88716138806121ae52a9c9/nvidia_cuda_runtime_cu12-12.1.105-py3-none-win_amd64.whl", hash = "sha256:dfb46ef84d73fababab44cf03e3b83f80700d27ca300e537f85f636fac474344", size = 821226 }, ] [[package]] @@ -2046,7 +2049,6 @@ dependencies = [ ] wheels = [ { url = "https://files.pythonhosted.org/packages/9f/fd/713452cd72343f682b1c7b9321e23829f00b842ceaedcda96e742ea0b0b3/nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl", hash = "sha256:165764f44ef8c61fcdfdfdbe769d687e06374059fbb388b6c89ecb0e28793a6f", size = 664752741 }, - { url = "https://files.pythonhosted.org/packages/3f/d0/f90ee6956a628f9f04bf467932c0a25e5a7e706a684b896593c06c82f460/nvidia_cudnn_cu12-9.1.0.70-py3-none-win_amd64.whl", hash = "sha256:6278562929433d68365a07a4a1546c237ba2849852c0d4b2262a486e805b977a", size = 679925892 }, ] [[package]] @@ -2055,7 +2057,6 @@ version = "11.0.2.54" source = { registry = "https://pypi.org/simple" } wheels = [ { url = "https://files.pythonhosted.org/packages/86/94/eb540db023ce1d162e7bea9f8f5aa781d57c65aed513c33ee9a5123ead4d/nvidia_cufft_cu12-11.0.2.54-py3-none-manylinux1_x86_64.whl", hash = "sha256:794e3948a1aa71fd817c3775866943936774d1c14e7628c74f6f7417224cdf56", size = 121635161 }, - { url = "https://files.pythonhosted.org/packages/f7/57/7927a3aa0e19927dfed30256d1c854caf991655d847a4e7c01fe87e3d4ac/nvidia_cufft_cu12-11.0.2.54-py3-none-win_amd64.whl", hash = "sha256:d9ac353f78ff89951da4af698f80870b1534ed69993f10a4cf1d96f21357e253", size = 121344196 }, ] [[package]] @@ -2064,7 +2065,6 @@ version = "10.3.2.106" source = { registry = "https://pypi.org/simple" } wheels = [ { url = "https://files.pythonhosted.org/packages/44/31/4890b1c9abc496303412947fc7dcea3d14861720642b49e8ceed89636705/nvidia_curand_cu12-10.3.2.106-py3-none-manylinux1_x86_64.whl", hash = "sha256:9d264c5036dde4e64f1de8c50ae753237c12e0b1348738169cd0f8a536c0e1e0", size = 56467784 }, - { url = "https://files.pythonhosted.org/packages/5c/97/4c9c7c79efcdf5b70374241d48cf03b94ef6707fd18ea0c0f53684931d0b/nvidia_curand_cu12-10.3.2.106-py3-none-win_amd64.whl", hash = "sha256:75b6b0c574c0037839121317e17fd01f8a69fd2ef8e25853d826fec30bdba74a", size = 55995813 }, ] [[package]] @@ -2078,7 +2078,6 @@ dependencies = [ ] wheels = [ { url = "https://files.pythonhosted.org/packages/bc/1d/8de1e5c67099015c834315e333911273a8c6aaba78923dd1d1e25fc5f217/nvidia_cusolver_cu12-11.4.5.107-py3-none-manylinux1_x86_64.whl", hash = "sha256:8a7ec542f0412294b15072fa7dab71d31334014a69f953004ea7a118206fe0dd", size = 124161928 }, - { url = "https://files.pythonhosted.org/packages/b8/80/8fca0bf819122a631c3976b6fc517c1b10741b643b94046bd8dd451522c5/nvidia_cusolver_cu12-11.4.5.107-py3-none-win_amd64.whl", hash = "sha256:74e0c3a24c78612192a74fcd90dd117f1cf21dea4822e66d89e8ea80e3cd2da5", size = 121643081 }, ] [[package]] @@ -2090,7 +2089,6 @@ dependencies = [ ] wheels = [ { url = "https://files.pythonhosted.org/packages/65/5b/cfaeebf25cd9fdec14338ccb16f6b2c4c7fa9163aefcf057d86b9cc248bb/nvidia_cusparse_cu12-12.1.0.106-py3-none-manylinux1_x86_64.whl", hash = "sha256:f3b50f42cf363f86ab21f720998517a659a48131e8d538dc02f8768237bd884c", size = 195958278 }, - { url = "https://files.pythonhosted.org/packages/0f/95/48fdbba24c93614d1ecd35bc6bdc6087bd17cbacc3abc4b05a9c2a1ca232/nvidia_cusparse_cu12-12.1.0.106-py3-none-win_amd64.whl", hash = "sha256:b798237e81b9719373e8fae8d4f091b70a0cf09d9d85c95a557e11df2d8e9a5a", size = 195414588 }, ] [[package]] @@ -2109,7 +2107,6 @@ source = { registry = "https://pypi.org/simple" } wheels = [ { url = "https://files.pythonhosted.org/packages/81/b3/e456a1b2d499bb84bdc6670bfbcf41ff3bac58bd2fae6880d62834641558/nvidia_nvjitlink_cu12-12.6.20-py3-none-manylinux2014_aarch64.whl", hash = "sha256:84fb38465a5bc7c70cbc320cfd0963eb302ee25a5e939e9f512bbba55b6072fb", size = 19252608 }, { url = "https://files.pythonhosted.org/packages/59/65/7ff0569494fbaea45ad2814972cc88da843d53cc96eb8554fcd0908941d9/nvidia_nvjitlink_cu12-12.6.20-py3-none-manylinux2014_x86_64.whl", hash = "sha256:562ab97ea2c23164823b2a89cb328d01d45cb99634b8c65fe7cd60d14562bd79", size = 19724950 }, - { url = "https://files.pythonhosted.org/packages/cb/ef/8f96c82e1cfcf6d5b770f7b043c3cc24841fc247b37629a7cc643dbf72a1/nvidia_nvjitlink_cu12-12.6.20-py3-none-win_amd64.whl", hash = "sha256:ed3c43a17f37b0c922a919203d2d36cbef24d41cc3e6b625182f8b58203644f6", size = 162012830 }, ] [[package]] @@ -2118,7 +2115,6 @@ version = "12.1.105" source = { registry = "https://pypi.org/simple" } wheels = [ { url = "https://files.pythonhosted.org/packages/da/d3/8057f0587683ed2fcd4dbfbdfdfa807b9160b809976099d36b8f60d08f03/nvidia_nvtx_cu12-12.1.105-py3-none-manylinux1_x86_64.whl", hash = "sha256:dc21cf308ca5691e7c04d962e213f8a4aa9bbfa23d95412f452254c2caeb09e5", size = 99138 }, - { url = "https://files.pythonhosted.org/packages/b8/d7/bd7cb2d95ac6ac6e8d05bfa96cdce69619f1ef2808e072919044c2d47a8c/nvidia_nvtx_cu12-12.1.105-py3-none-win_amd64.whl", hash = "sha256:65f4d98982b31b60026e0e6de73fbdfc09d08a96f4656dd3665ca616a11e1e82", size = 66307 }, ] [[package]] @@ -3408,9 +3404,7 @@ dependencies = [ { name = "tensorflow", marker = "platform_system != 'Darwin'" }, ] wheels = [ - { url = "https://files.pythonhosted.org/packages/a2/e3/33fc5957790cf4710e0a9116cf37c0a881eda673e5f8b569bfff5654a48c/tensorflow_text-2.17.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:8eba0b5804235519b571c827c97337c332de270107f06af6d2171cdefdc4c6a0", size = 6109587 }, { url = "https://files.pythonhosted.org/packages/61/59/2090318555d98dc9dc868b3c585ada2e1139be538d954340726aa3d3899a/tensorflow_text-2.17.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:89f04c3f478f1885ad4c7380643a768a72a3de79e1f8f40d50b48cc1fbf73893", size = 5205819 }, - { url = "https://files.pythonhosted.org/packages/92/65/e2d3d9300173a0927e8b7e3cf9a35f9539e9269786c1e1d9d945223fe21a/tensorflow_text-2.17.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:a9b9f9c8a06878714a14f4e086fa8122beb2e141f82d0aa5a8f6b8f9b694db51", size = 6109684 }, { url = "https://files.pythonhosted.org/packages/de/32/182ecf4eb1432942876d9b0b089625564084c5ed4d03c02ddf2872177e95/tensorflow_text-2.17.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:161c09380b090774ed721cdcce973194458708250d7dfbac7cb9ea8a3e9ac762", size = 5205866 }, ] @@ -3587,9 +3581,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/45/27/14cc3101409b9b4b9241d2ba7deaa93535a217a211c86c4cc7151fb12181/triton-3.0.0-1-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:e1efef76935b2febc365bfadf74bcb65a6f959a9872e5bddf44cc9e0adce1e1a", size = 209376304 }, { url = "https://files.pythonhosted.org/packages/33/3e/a2f59384587eff6aeb7d37b6780de7fedd2214935e27520430ca9f5b7975/triton-3.0.0-1-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:5ce8520437c602fb633f1324cc3871c47bee3b67acf9756c1a66309b60e3216c", size = 209438883 }, { url = "https://files.pythonhosted.org/packages/fe/7b/7757205dee3628f75e7991021d15cd1bd0c9b044ca9affe99b50879fc0e1/triton-3.0.0-1-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:34e509deb77f1c067d8640725ef00c5cbfcb2052a1a3cb6a6d343841f92624eb", size = 209464695 }, - { url = "https://files.pythonhosted.org/packages/15/67/84e5a4b7b45bdeb11da26a67dfa2b988c512abbcbcad8cbc30aa579051b2/triton-3.0.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:39b052da883351fdf6be3d93cedae6db3b8e3988d3b09ed221bccecfa9612230", size = 209380247 }, - { url = "https://files.pythonhosted.org/packages/ea/6b/1d72cc8a7379822dadf050474add7d8b73b02c35057446b6f17d27cb9ea2/triton-3.0.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cd34f19a8582af96e6291d4afce25dac08cb2a5d218c599163761e8e0827208e", size = 209442823 }, - { url = "https://files.pythonhosted.org/packages/ae/b2/048c9ecfdba0e6b0ae3c02eed2d9dd3e9e990a6d46da98555cf0c2232168/triton-3.0.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0d5e10de8c011adeb7c878c6ce0dd6073b14367749e34467f1cff2bde1b78253", size = 209468633 }, ] [[package]]