diff --git a/flake.nix b/flake.nix index f770095..48e3046 100644 --- a/flake.nix +++ b/flake.nix @@ -68,6 +68,12 @@ inherit (python-final) jaxlib; }; + orbax-checkpoint = callPackage ./nix/orbax-checkpoint.nix { + inherit (python-final) jax; + inherit (python-final) jaxlib; + inherit tensorstore; + }; + flax = callPackage ./nix/flax.nix { inherit (python-final) jax; inherit (python-final) jaxlib; @@ -86,6 +92,7 @@ inherit (python-final) optax; inherit (python-final) flax; inherit (python-final) mujoco; + inherit (python-final) orbax-checkpoint; inherit dm_env; inherit pytinyrenderer; inherit trimesh; diff --git a/nix/brax.nix b/nix/brax.nix index c5608c1..23dcfc2 100644 --- a/nix/brax.nix +++ b/nix/brax.nix @@ -21,6 +21,7 @@ , typing-extensions , flax , mujoco +, orbax-checkpoint }: @@ -29,8 +30,8 @@ buildPythonPackage rec { src = fetchFromGitHub { owner = "google"; repo = "brax"; - rev = "v0.9.1"; - hash = "sha256-tFoTsz+EEd35nO39/owBBKbJG1LnAGUZBoOJkYVuwlI="; + rev = "v0.10.5"; + hash = "sha256-Ek1j/tghkNOny6uPWM+WHlTB3eZI5yl3oXq4DdIEJv4="; }; nativeBuildInputs = [ @@ -67,5 +68,6 @@ buildPythonPackage rec { optax mujoco flax + orbax-checkpoint ]; } diff --git a/nix/orbax-checkpoint.nix b/nix/orbax-checkpoint.nix new file mode 100644 index 0000000..9742338 --- /dev/null +++ b/nix/orbax-checkpoint.nix @@ -0,0 +1,48 @@ +{ buildPythonPackage +, lib +, absl-py +, cached-property +, etils +, fetchPypi +, flit +, jax +, jaxlib +, msgpack +, nest-asyncio +, numpy +, pyyaml +, tensorflow +, tensorstore +}: + +buildPythonPackage rec { + pname = "orbax-checkpoint"; + version = "0.1.6"; + + src = fetchPypi { + inherit pname version; + sha256 = "sha256-lnh2eAr54Dk8C9hnW/nZP0vrz/9Vqvwo5FMfrkJqFsA="; + }; + format = "pyproject"; + + propagatedBuildInputs = [ + absl-py + cached-property + etils + flit + jax + jaxlib + msgpack + nest-asyncio + numpy + pyyaml + tensorflow + tensorstore + ]; + + meta = with lib; { + description = "Checkpointing library for JAX-based models"; + license = licenses.asl20; + homepage = "https://github.com/google/orbax"; + }; +} diff --git a/nix/tensorstore.nix b/nix/tensorstore.nix index b387f5f..b1f9f77 100644 --- a/nix/tensorstore.nix +++ b/nix/tensorstore.nix @@ -5,14 +5,21 @@ buildPythonPackage rec { name = "tensorstore"; + src = fetchFromGitHub { owner = "google"; repo = "tensorstore"; - rev = "v0.1.35"; - hash = "sha256-VmJHDoU+lDS3PT4cEDZVDY+VYTa0F2X9aUWIEZW29vM="; + rev = "v0.1.60"; + hash = "sha256-rT0R1x51xHAElPwernUjBIIneRhncnsohMRAIhXyaYk="; }; format = "pyproject"; + propagatedBuildInputs = [ setuptools ]; + + preConfigure = '' + export HOME=$PWD + ''; + }