Skip to content

Commit

Permalink
feat(jax): checkpoint I/O (#4236)
Browse files Browse the repository at this point in the history
Implement a JAX checkpoint format. I name it `*.jax` as I don't find
existing conventions.

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

## Release Notes

- **New Features**
- Introduced serialization and deserialization functionalities for JAX
backend models.
- Added support for the `.jax` file suffix in the backend configuration.
- Enhanced attribute handling logic across various classes to ensure
proper processing of non-null values.

- **Bug Fixes**
	- Enhanced cleanup processes in the test suite to improve reliability.

- **Chores**
- Updated dependencies in the project configuration for better JAX
compatibility.
	- Adjusted linting rules to accommodate JAX-related code.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
njzjz and pre-commit-ci[bot] authored Oct 24, 2024
1 parent c870ccf commit 4d50048
Show file tree
Hide file tree
Showing 11 changed files with 156 additions and 10 deletions.
17 changes: 12 additions & 5 deletions deepmd/backend/jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,14 +32,13 @@ class JAXBackend(Backend):
name = "JAX"
"""The formal name of the backend."""
features: ClassVar[Backend.Feature] = (
Backend.Feature(0)
Backend.Feature.IO
# Backend.Feature.ENTRY_POINT
# | Backend.Feature.DEEP_EVAL
# | Backend.Feature.NEIGHBOR_STAT
# | Backend.Feature.IO
)
"""The features of the backend."""
suffixes: ClassVar[list[str]] = []
suffixes: ClassVar[list[str]] = [".jax"]
"""The suffixes of the backend."""

def is_available(self) -> bool:
Expand Down Expand Up @@ -94,7 +93,11 @@ def serialize_hook(self) -> Callable[[str], dict]:
Callable[[str], dict]
The serialize hook of the backend.
"""
raise NotImplementedError
from deepmd.jax.utils.serialization import (
serialize_from_file,
)

return serialize_from_file

@property
def deserialize_hook(self) -> Callable[[str, dict], None]:
Expand All @@ -105,4 +108,8 @@ def deserialize_hook(self) -> Callable[[str, dict], None]:
Callable[[str, dict], None]
The deserialize hook of the backend.
"""
raise NotImplementedError
from deepmd.jax.utils.serialization import (
deserialize_to_file,
)

return deserialize_to_file
3 changes: 3 additions & 0 deletions deepmd/jax/atomic_model/base_atomic_model.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from deepmd.jax.common import (
ArrayAPIVariable,
to_jax_array,
)
from deepmd.jax.utils.exclude_mask import (
Expand All @@ -11,6 +12,8 @@
def base_atomic_model_set_attr(name, value):
if name in {"out_bias", "out_std"}:
value = to_jax_array(value)
if value is not None:
value = ArrayAPIVariable(value)
elif name == "pair_excl" and value is not None:
value = PairExcludeMask(value.ntypes, value.exclude_types)
elif name == "atom_excl" and value is not None:
Expand Down
14 changes: 14 additions & 0 deletions deepmd/jax/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,3 +81,17 @@ def __setattr__(self, name: str, value: Any) -> None:
return super().__setattr__(name, value)

return FlaxModule


class ArrayAPIVariable(nnx.Variable):
def __array__(self, *args, **kwargs):
return self.value.__array__(*args, **kwargs)

def __array_namespace__(self, *args, **kwargs):
return self.value.__array_namespace__(*args, **kwargs)

def __dlpack__(self, *args, **kwargs):
return self.value.__dlpack__(*args, **kwargs)

def __dlpack_device__(self, *args, **kwargs):
return self.value.__dlpack_device__(*args, **kwargs)
3 changes: 3 additions & 0 deletions deepmd/jax/descriptor/dpa1.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
NeighborGatedAttentionLayer as NeighborGatedAttentionLayerDP,
)
from deepmd.jax.common import (
ArrayAPIVariable,
flax_module,
to_jax_array,
)
Expand Down Expand Up @@ -65,6 +66,8 @@ class DescrptBlockSeAtten(DescrptBlockSeAttenDP):
def __setattr__(self, name: str, value: Any) -> None:
if name in {"mean", "stddev"}:
value = to_jax_array(value)
if value is not None:
value = ArrayAPIVariable(value)
elif name in {"embeddings", "embeddings_strip"}:
if value is not None:
value = NetworkCollection.deserialize(value.serialize())
Expand Down
3 changes: 3 additions & 0 deletions deepmd/jax/descriptor/se_e2_a.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from deepmd.dpmodel.descriptor.se_e2_a import DescrptSeAArrayAPI as DescrptSeADP
from deepmd.jax.common import (
ArrayAPIVariable,
flax_module,
to_jax_array,
)
Expand All @@ -26,6 +27,8 @@ class DescrptSeA(DescrptSeADP):
def __setattr__(self, name: str, value: Any) -> None:
if name in {"dstd", "davg"}:
value = to_jax_array(value)
if value is not None:
value = ArrayAPIVariable(value)
elif name in {"embeddings"}:
if value is not None:
value = NetworkCollection.deserialize(value.serialize())
Expand Down
3 changes: 3 additions & 0 deletions deepmd/jax/fitting/fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from deepmd.dpmodel.fitting.dos_fitting import DOSFittingNet as DOSFittingNetDP
from deepmd.dpmodel.fitting.ener_fitting import EnergyFittingNet as EnergyFittingNetDP
from deepmd.jax.common import (
ArrayAPIVariable,
flax_module,
to_jax_array,
)
Expand All @@ -29,6 +30,8 @@ def setattr_for_general_fitting(name: str, value: Any) -> Any:
"aparam_inv_std",
}:
value = to_jax_array(value)
if value is not None:
value = ArrayAPIVariable(value)
elif name == "emask":
value = AtomExcludeMask(value.ntypes, value.exclude_types)
elif name == "nets":
Expand Down
5 changes: 5 additions & 0 deletions deepmd/jax/utils/exclude_mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from deepmd.dpmodel.utils.exclude_mask import AtomExcludeMask as AtomExcludeMaskDP
from deepmd.dpmodel.utils.exclude_mask import PairExcludeMask as PairExcludeMaskDP
from deepmd.jax.common import (
ArrayAPIVariable,
flax_module,
to_jax_array,
)
Expand All @@ -16,6 +17,8 @@ class AtomExcludeMask(AtomExcludeMaskDP):
def __setattr__(self, name: str, value: Any) -> None:
if name in {"type_mask"}:
value = to_jax_array(value)
if value is not None:
value = ArrayAPIVariable(value)
return super().__setattr__(name, value)


Expand All @@ -24,4 +27,6 @@ class PairExcludeMask(PairExcludeMaskDP):
def __setattr__(self, name: str, value: Any) -> None:
if name in {"type_mask"}:
value = to_jax_array(value)
if value is not None:
value = ArrayAPIVariable(value)
return super().__setattr__(name, value)
97 changes: 97 additions & 0 deletions deepmd/jax/utils/serialization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from pathlib import (
Path,
)

import orbax.checkpoint as ocp

from deepmd.jax.env import (
jax,
nnx,
)
from deepmd.jax.model.model import (
BaseModel,
get_model,
)


def deserialize_to_file(model_file: str, data: dict) -> None:
"""Deserialize the dictionary to a model file.
Parameters
----------
model_file : str
The model file to be saved.
data : dict
The dictionary to be deserialized.
"""
if model_file.endswith(".jax"):
model = BaseModel.deserialize(data["model"])
model_def_script = data["model_def_script"]
_, state = nnx.split(model)
with ocp.Checkpointer(
ocp.CompositeCheckpointHandler("state", "model_def_script")
) as checkpointer:
checkpointer.save(
Path(model_file).absolute(),
ocp.args.Composite(
state=ocp.args.StandardSave(state.to_pure_dict()),
model_def_script=ocp.args.JsonSave(model_def_script),
),
)
else:
raise ValueError("JAX backend only supports converting .jax directory")


def serialize_from_file(model_file: str) -> dict:
"""Serialize the model file to a dictionary.
Parameters
----------
model_file : str
The model file to be serialized.
Returns
-------
dict
The serialized model data.
"""
if model_file.endswith(".jax"):
with ocp.Checkpointer(
ocp.CompositeCheckpointHandler("state", "model_def_script")
) as checkpointer:
data = checkpointer.restore(
Path(model_file).absolute(),
ocp.args.Composite(
state=ocp.args.StandardRestore(),
model_def_script=ocp.args.JsonRestore(),
),
)
state = data.state

# convert str "1" to int 1 key
def convert_str_to_int_key(item: dict):
for key, value in item.copy().items():
if isinstance(value, dict):
convert_str_to_int_key(value)
if key.isdigit():
item[int(key)] = item.pop(key)

convert_str_to_int_key(state)

model_def_script = data.model_def_script
abstract_model = get_model(model_def_script)
graphdef, abstract_state = nnx.split(abstract_model)
abstract_state.replace_by_pure_dict(state)
model = nnx.merge(graphdef, abstract_state)
model_dict = model.serialize()
data = {
"backend": "JAX",
"jax_version": jax.__version__,
"model": model_dict,
"model_def_script": model_def_script,
"@variables": {},
}
return data
else:
raise ValueError("JAX backend only supports converting .jax directory")
3 changes: 3 additions & 0 deletions deepmd/jax/utils/type_embed.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from deepmd.dpmodel.utils.type_embed import TypeEmbedNet as TypeEmbedNetDP
from deepmd.jax.common import (
ArrayAPIVariable,
flax_module,
to_jax_array,
)
Expand All @@ -18,6 +19,8 @@ class TypeEmbedNet(TypeEmbedNetDP):
def __setattr__(self, name: str, value: Any) -> None:
if name in {"econf_tebd"}:
value = to_jax_array(value)
if value is not None:
value = ArrayAPIVariable(value)
if name in {"embedding_net"}:
value = EmbeddingNet.deserialize(value.serialize())
return super().__setattr__(name, value)
6 changes: 5 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,10 @@ cu12 = [
]
jax = [
'jax>=0.4.33;python_version>="3.10"',
'flax>=0.8.0;python_version>="3.10"',
'flax>=0.10.0;python_version>="3.10"',
'orbax-checkpoint;python_version>="3.10"',
# The pinning of ml_dtypes may conflict with TF
# 'jax-ai-stack;python_version>="3.10"',
]

[tool.deepmd_build_backend.scripts]
Expand Down Expand Up @@ -402,6 +405,7 @@ banned-module-level-imports = [
# Also ignore `E402` in all `__init__.py` files.
"deepmd/tf/**" = ["TID253"]
"deepmd/pt/**" = ["TID253"]
"deepmd/jax/**" = ["TID253"]
"source/tests/tf/**" = ["TID253"]
"source/tests/pt/**" = ["TID253"]
"source/tests/universal/pt/**" = ["TID253"]
Expand Down
12 changes: 8 additions & 4 deletions source/tests/consistent/io/test_io.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import copy
import shutil
import unittest
from pathlib import (
Path,
Expand Down Expand Up @@ -60,15 +61,17 @@ def save_data_to_model(self, model_file: str, data: dict) -> None:
def tearDown(self):
prefix = "test_consistent_io_" + self.__class__.__name__.lower()
for ii in Path(".").glob(prefix + ".*"):
if Path(ii).exists():
if Path(ii).is_file():
Path(ii).unlink()
elif Path(ii).is_dir():
shutil.rmtree(ii)

def test_data_equal(self):
prefix = "test_consistent_io_" + self.__class__.__name__.lower()
for backend_name in ("tensorflow", "pytorch", "dpmodel"):
for backend_name in ("tensorflow", "pytorch", "dpmodel", "jax"):
with self.subTest(backend_name=backend_name):
backend = Backend.get_backend(backend_name)()
if not backend.is_available:
if not backend.is_available():
continue
reference_data = copy.deepcopy(self.data)
self.save_data_to_model(prefix + backend.suffixes[0], reference_data)
Expand All @@ -80,6 +83,7 @@ def test_data_equal(self):
"backend",
"tf_version",
"pt_version",
"jax_version",
"@variables",
# dpmodel only
"software",
Expand Down Expand Up @@ -123,7 +127,7 @@ def test_deep_eval(self):
rets = []
for backend_name in ("tensorflow", "pytorch", "dpmodel"):
backend = Backend.get_backend(backend_name)()
if not backend.is_available:
if not backend.is_available():
continue
reference_data = copy.deepcopy(self.data)
self.save_data_to_model(prefix + backend.suffixes[0], reference_data)
Expand Down

0 comments on commit 4d50048

Please sign in to comment.