Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(jax): checkpoint I/O #4236

Merged
merged 18 commits into from
Oct 24, 2024
17 changes: 12 additions & 5 deletions deepmd/backend/jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,14 +32,13 @@ class JAXBackend(Backend):
name = "JAX"
"""The formal name of the backend."""
features: ClassVar[Backend.Feature] = (
Backend.Feature(0)
Backend.Feature.IO
# Backend.Feature.ENTRY_POINT
# | Backend.Feature.DEEP_EVAL
# | Backend.Feature.NEIGHBOR_STAT
# | Backend.Feature.IO
)
"""The features of the backend."""
suffixes: ClassVar[list[str]] = []
suffixes: ClassVar[list[str]] = [".jax"]
"""The suffixes of the backend."""

def is_available(self) -> bool:
Expand Down Expand Up @@ -94,7 +93,11 @@ def serialize_hook(self) -> Callable[[str], dict]:
Callable[[str], dict]
The serialize hook of the backend.
"""
raise NotImplementedError
from deepmd.jax.utils.serialization import (
serialize_from_file,
)

return serialize_from_file

@property
def deserialize_hook(self) -> Callable[[str, dict], None]:
Expand All @@ -105,4 +108,8 @@ def deserialize_hook(self) -> Callable[[str, dict], None]:
Callable[[str, dict], None]
The deserialize hook of the backend.
"""
raise NotImplementedError
from deepmd.jax.utils.serialization import (
deserialize_to_file,
)

return deserialize_to_file
86 changes: 86 additions & 0 deletions deepmd/jax/utils/serialization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from pathlib import (
Path,
)

import orbax.checkpoint as ocp

from deepmd.jax.env import (
jax,
nnx,
)
from deepmd.jax.model.model import (
BaseModel,
get_model,
)
from deepmd.jax.utils.network import (
ArrayAPIParam,
)


def deserialize_to_file(model_file: str, data: dict) -> None:
"""Deserialize the dictionary to a model file.

Parameters
----------
model_file : str
The model file to be saved.
data : dict
The dictionary to be deserialized.
"""
if model_file.endswith(".jax"):
model = BaseModel.deserialize(data["model"])
model_def_script = data["model_def_script"]
state = nnx.state(model, ArrayAPIParam)
with ocp.Checkpointer(
ocp.CompositeCheckpointHandler("state", "model_def_script")
) as checkpointer:
checkpointer.save(
Path(model_file).absolute(),
ocp.args.Composite(
state=ocp.args.StandardSave(state),
model_def_script=ocp.args.JsonSave(model_def_script),
),
)
else:
raise ValueError("JAX backend only supports converting .jax directory")


def serialize_from_file(model_file: str) -> dict:
"""Serialize the model file to a dictionary.

Parameters
----------
model_file : str
The model file to be serialized.

Returns
-------
dict
The serialized model data.
"""
if model_file.endswith(".jax"):
with ocp.Checkpointer(
ocp.CompositeCheckpointHandler("state", "model_def_script")
) as checkpointer:
data = checkpointer.restore(
Path(model_file).absolute(),
ocp.args.Composite(
state=ocp.args.StandardRestore(),
model_def_script=ocp.args.JsonRestore(),
),
)
state = data.state
njzjz marked this conversation as resolved.
Show resolved Hide resolved
Fixed Show fixed Hide fixed
model_def_script = data.model_def_script
model = get_model(model_def_script)
model_dict = model.serialize()
data = {
"backend": "JAX",
"jax_version": jax.__version__,
"model": model_dict,
"model_def_script": model_def_script,
"@variables": {},
}
return data
else:
raise ValueError("JAX backend only supports converting .jax directory")
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,8 @@ cu12 = [
jax = [
'jax>=0.4.33;python_version>="3.10"',
'flax>=0.8.0;python_version>="3.10"',
'orbax-checkpoint;python_version>="3.10"',
'jax-ai-stack;python_version>="3.10"',
]

[tool.deepmd_build_backend.scripts]
Expand Down Expand Up @@ -402,6 +404,7 @@ banned-module-level-imports = [
# Also ignore `E402` in all `__init__.py` files.
"deepmd/tf/**" = ["TID253"]
"deepmd/pt/**" = ["TID253"]
"deepmd/jax/**" = ["TID253"]
"source/tests/tf/**" = ["TID253"]
"source/tests/pt/**" = ["TID253"]
"source/tests/universal/pt/**" = ["TID253"]
Expand Down
8 changes: 6 additions & 2 deletions source/tests/consistent/io/test_io.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import copy
import shutil
import unittest
from pathlib import (
Path,
Expand Down Expand Up @@ -60,12 +61,14 @@ def save_data_to_model(self, model_file: str, data: dict) -> None:
def tearDown(self):
prefix = "test_consistent_io_" + self.__class__.__name__.lower()
for ii in Path(".").glob(prefix + ".*"):
if Path(ii).exists():
if Path(ii).is_file():
Path(ii).unlink()
elif Path(ii).is_dir():
shutil.rmtree(ii)

def test_data_equal(self):
prefix = "test_consistent_io_" + self.__class__.__name__.lower()
for backend_name in ("tensorflow", "pytorch", "dpmodel"):
for backend_name in ("tensorflow", "pytorch", "dpmodel", "jax"):
with self.subTest(backend_name=backend_name):
backend = Backend.get_backend(backend_name)()
if not backend.is_available:
Expand All @@ -80,6 +83,7 @@ def test_data_equal(self):
"backend",
"tf_version",
"pt_version",
"jax_version",
"@variables",
# dpmodel only
"software",
Expand Down
Loading