Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(jax): checkpoint I/O #4236

Merged
merged 18 commits into from
Oct 24, 2024
17 changes: 12 additions & 5 deletions deepmd/backend/jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,14 +32,13 @@ class JAXBackend(Backend):
name = "JAX"
"""The formal name of the backend."""
features: ClassVar[Backend.Feature] = (
Backend.Feature(0)
Backend.Feature.IO
# Backend.Feature.ENTRY_POINT
# | Backend.Feature.DEEP_EVAL
# | Backend.Feature.NEIGHBOR_STAT
# | Backend.Feature.IO
)
"""The features of the backend."""
suffixes: ClassVar[list[str]] = []
suffixes: ClassVar[list[str]] = [".jax"]
"""The suffixes of the backend."""

def is_available(self) -> bool:
Expand Down Expand Up @@ -94,7 +93,11 @@ def serialize_hook(self) -> Callable[[str], dict]:
Callable[[str], dict]
The serialize hook of the backend.
"""
raise NotImplementedError
from deepmd.jax.utils.serialization import (
serialize_from_file,
)

return serialize_from_file

@property
def deserialize_hook(self) -> Callable[[str, dict], None]:
Expand All @@ -105,4 +108,8 @@ def deserialize_hook(self) -> Callable[[str, dict], None]:
Callable[[str, dict], None]
The deserialize hook of the backend.
"""
raise NotImplementedError
from deepmd.jax.utils.serialization import (
deserialize_to_file,
)

return deserialize_to_file
3 changes: 3 additions & 0 deletions deepmd/jax/atomic_model/base_atomic_model.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from deepmd.jax.common import (
ArrayAPIVariable,
to_jax_array,
)
from deepmd.jax.utils.exclude_mask import (
Expand All @@ -11,6 +12,8 @@
def base_atomic_model_set_attr(name, value):
if name in {"out_bias", "out_std"}:
value = to_jax_array(value)
if value is not None:
value = ArrayAPIVariable(value)
elif name == "pair_excl" and value is not None:
value = PairExcludeMask(value.ntypes, value.exclude_types)
elif name == "atom_excl" and value is not None:
Expand Down
14 changes: 14 additions & 0 deletions deepmd/jax/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,3 +81,17 @@
return super().__setattr__(name, value)

return FlaxModule


class ArrayAPIVariable(nnx.Variable):
def __array__(self, *args, **kwargs):
return self.value.__array__(*args, **kwargs)

def __array_namespace__(self, *args, **kwargs):
return self.value.__array_namespace__(*args, **kwargs)

def __dlpack__(self, *args, **kwargs):
return self.value.__dlpack__(*args, **kwargs)

Check warning on line 94 in deepmd/jax/common.py

View check run for this annotation

Codecov / codecov/patch

deepmd/jax/common.py#L94

Added line #L94 was not covered by tests

def __dlpack_device__(self, *args, **kwargs):
return self.value.__dlpack_device__(*args, **kwargs)

Check warning on line 97 in deepmd/jax/common.py

View check run for this annotation

Codecov / codecov/patch

deepmd/jax/common.py#L97

Added line #L97 was not covered by tests
3 changes: 3 additions & 0 deletions deepmd/jax/descriptor/dpa1.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
NeighborGatedAttentionLayer as NeighborGatedAttentionLayerDP,
)
from deepmd.jax.common import (
ArrayAPIVariable,
flax_module,
to_jax_array,
)
Expand Down Expand Up @@ -65,6 +66,8 @@ class DescrptBlockSeAtten(DescrptBlockSeAttenDP):
def __setattr__(self, name: str, value: Any) -> None:
if name in {"mean", "stddev"}:
value = to_jax_array(value)
if value is not None:
value = ArrayAPIVariable(value)
elif name in {"embeddings", "embeddings_strip"}:
if value is not None:
value = NetworkCollection.deserialize(value.serialize())
Expand Down
3 changes: 3 additions & 0 deletions deepmd/jax/descriptor/se_e2_a.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from deepmd.dpmodel.descriptor.se_e2_a import DescrptSeAArrayAPI as DescrptSeADP
from deepmd.jax.common import (
ArrayAPIVariable,
flax_module,
to_jax_array,
)
Expand All @@ -26,6 +27,8 @@ class DescrptSeA(DescrptSeADP):
def __setattr__(self, name: str, value: Any) -> None:
if name in {"dstd", "davg"}:
value = to_jax_array(value)
if value is not None:
value = ArrayAPIVariable(value)
elif name in {"embeddings"}:
if value is not None:
value = NetworkCollection.deserialize(value.serialize())
Expand Down
3 changes: 3 additions & 0 deletions deepmd/jax/fitting/fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from deepmd.dpmodel.fitting.dos_fitting import DOSFittingNet as DOSFittingNetDP
from deepmd.dpmodel.fitting.ener_fitting import EnergyFittingNet as EnergyFittingNetDP
from deepmd.jax.common import (
ArrayAPIVariable,
flax_module,
to_jax_array,
)
Expand All @@ -29,6 +30,8 @@ def setattr_for_general_fitting(name: str, value: Any) -> Any:
"aparam_inv_std",
}:
value = to_jax_array(value)
if value is not None:
value = ArrayAPIVariable(value)
elif name == "emask":
value = AtomExcludeMask(value.ntypes, value.exclude_types)
elif name == "nets":
Expand Down
5 changes: 5 additions & 0 deletions deepmd/jax/utils/exclude_mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from deepmd.dpmodel.utils.exclude_mask import AtomExcludeMask as AtomExcludeMaskDP
from deepmd.dpmodel.utils.exclude_mask import PairExcludeMask as PairExcludeMaskDP
from deepmd.jax.common import (
ArrayAPIVariable,
flax_module,
to_jax_array,
)
Expand All @@ -16,6 +17,8 @@ class AtomExcludeMask(AtomExcludeMaskDP):
def __setattr__(self, name: str, value: Any) -> None:
if name in {"type_mask"}:
value = to_jax_array(value)
if value is not None:
value = ArrayAPIVariable(value)
return super().__setattr__(name, value)


Expand All @@ -24,4 +27,6 @@ class PairExcludeMask(PairExcludeMaskDP):
def __setattr__(self, name: str, value: Any) -> None:
if name in {"type_mask"}:
value = to_jax_array(value)
if value is not None:
value = ArrayAPIVariable(value)
njzjz marked this conversation as resolved.
Show resolved Hide resolved
return super().__setattr__(name, value)
97 changes: 97 additions & 0 deletions deepmd/jax/utils/serialization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from pathlib import (
Path,
)

import orbax.checkpoint as ocp

from deepmd.jax.env import (
jax,
nnx,
)
from deepmd.jax.model.model import (
BaseModel,
get_model,
)


def deserialize_to_file(model_file: str, data: dict) -> None:
"""Deserialize the dictionary to a model file.

Parameters
----------
model_file : str
The model file to be saved.
data : dict
The dictionary to be deserialized.
"""
if model_file.endswith(".jax"):
model = BaseModel.deserialize(data["model"])
model_def_script = data["model_def_script"]
_, state = nnx.split(model)
with ocp.Checkpointer(
ocp.CompositeCheckpointHandler("state", "model_def_script")
) as checkpointer:
checkpointer.save(
Path(model_file).absolute(),
ocp.args.Composite(
state=ocp.args.StandardSave(state.to_pure_dict()),
model_def_script=ocp.args.JsonSave(model_def_script),
),
)
else:
raise ValueError("JAX backend only supports converting .jax directory")

Check warning on line 43 in deepmd/jax/utils/serialization.py

View check run for this annotation

Codecov / codecov/patch

deepmd/jax/utils/serialization.py#L43

Added line #L43 was not covered by tests


def serialize_from_file(model_file: str) -> dict:
"""Serialize the model file to a dictionary.

Parameters
----------
model_file : str
The model file to be serialized.

Returns
-------
dict
The serialized model data.
"""
if model_file.endswith(".jax"):
with ocp.Checkpointer(
ocp.CompositeCheckpointHandler("state", "model_def_script")
) as checkpointer:
data = checkpointer.restore(
Path(model_file).absolute(),
ocp.args.Composite(
state=ocp.args.StandardRestore(),
model_def_script=ocp.args.JsonRestore(),
),
)
state = data.state
njzjz marked this conversation as resolved.
Show resolved Hide resolved
Fixed Show fixed Hide fixed

# convert str "1" to int 1 key
def convert_str_to_int_key(item: dict):
for key, value in item.copy().items():
if isinstance(value, dict):
convert_str_to_int_key(value)
if key.isdigit():
item[int(key)] = item.pop(key)

convert_str_to_int_key(state)

model_def_script = data.model_def_script
abstract_model = get_model(model_def_script)
graphdef, abstract_state = nnx.split(abstract_model)
abstract_state.replace_by_pure_dict(state)
model = nnx.merge(graphdef, abstract_state)
model_dict = model.serialize()
data = {
"backend": "JAX",
"jax_version": jax.__version__,
"model": model_dict,
"model_def_script": model_def_script,
"@variables": {},
}
return data
else:
raise ValueError("JAX backend only supports converting .jax directory")

Check warning on line 97 in deepmd/jax/utils/serialization.py

View check run for this annotation

Codecov / codecov/patch

deepmd/jax/utils/serialization.py#L97

Added line #L97 was not covered by tests
3 changes: 3 additions & 0 deletions deepmd/jax/utils/type_embed.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from deepmd.dpmodel.utils.type_embed import TypeEmbedNet as TypeEmbedNetDP
from deepmd.jax.common import (
ArrayAPIVariable,
flax_module,
to_jax_array,
)
Expand All @@ -18,6 +19,8 @@ class TypeEmbedNet(TypeEmbedNetDP):
def __setattr__(self, name: str, value: Any) -> None:
if name in {"econf_tebd"}:
value = to_jax_array(value)
if value is not None:
value = ArrayAPIVariable(value)
if name in {"embedding_net"}:
value = EmbeddingNet.deserialize(value.serialize())
return super().__setattr__(name, value)
6 changes: 5 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,10 @@ cu12 = [
]
jax = [
'jax>=0.4.33;python_version>="3.10"',
'flax>=0.8.0;python_version>="3.10"',
'flax>=0.10.0;python_version>="3.10"',
'orbax-checkpoint;python_version>="3.10"',
# The pinning of ml_dtypes may conflict with TF
# 'jax-ai-stack;python_version>="3.10"',
]

[tool.deepmd_build_backend.scripts]
Expand Down Expand Up @@ -402,6 +405,7 @@ banned-module-level-imports = [
# Also ignore `E402` in all `__init__.py` files.
"deepmd/tf/**" = ["TID253"]
"deepmd/pt/**" = ["TID253"]
"deepmd/jax/**" = ["TID253"]
"source/tests/tf/**" = ["TID253"]
"source/tests/pt/**" = ["TID253"]
"source/tests/universal/pt/**" = ["TID253"]
Expand Down
12 changes: 8 additions & 4 deletions source/tests/consistent/io/test_io.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import copy
import shutil
import unittest
from pathlib import (
Path,
Expand Down Expand Up @@ -60,15 +61,17 @@ def save_data_to_model(self, model_file: str, data: dict) -> None:
def tearDown(self):
prefix = "test_consistent_io_" + self.__class__.__name__.lower()
for ii in Path(".").glob(prefix + ".*"):
if Path(ii).exists():
if Path(ii).is_file():
Path(ii).unlink()
elif Path(ii).is_dir():
shutil.rmtree(ii)

def test_data_equal(self):
prefix = "test_consistent_io_" + self.__class__.__name__.lower()
for backend_name in ("tensorflow", "pytorch", "dpmodel"):
for backend_name in ("tensorflow", "pytorch", "dpmodel", "jax"):
with self.subTest(backend_name=backend_name):
backend = Backend.get_backend(backend_name)()
if not backend.is_available:
if not backend.is_available():
continue
reference_data = copy.deepcopy(self.data)
self.save_data_to_model(prefix + backend.suffixes[0], reference_data)
Expand All @@ -80,6 +83,7 @@ def test_data_equal(self):
"backend",
"tf_version",
"pt_version",
"jax_version",
"@variables",
# dpmodel only
"software",
Expand Down Expand Up @@ -123,7 +127,7 @@ def test_deep_eval(self):
rets = []
for backend_name in ("tensorflow", "pytorch", "dpmodel"):
backend = Backend.get_backend(backend_name)()
if not backend.is_available:
if not backend.is_available():
continue
reference_data = copy.deepcopy(self.data)
self.save_data_to_model(prefix + backend.suffixes[0], reference_data)
Expand Down