From 4d5004830384577922664ca70431aef1c818e799 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Thu, 24 Oct 2024 04:43:14 -0400 Subject: [PATCH] feat(jax): checkpoint I/O (#4236) Implement a JAX checkpoint format. I name it `*.jax` as I don't find existing conventions. ## Summary by CodeRabbit ## Release Notes - **New Features** - Introduced serialization and deserialization functionalities for JAX backend models. - Added support for the `.jax` file suffix in the backend configuration. - Enhanced attribute handling logic across various classes to ensure proper processing of non-null values. - **Bug Fixes** - Enhanced cleanup processes in the test suite to improve reliability. - **Chores** - Updated dependencies in the project configuration for better JAX compatibility. - Adjusted linting rules to accommodate JAX-related code. --------- Signed-off-by: Jinzhe Zeng Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- deepmd/backend/jax.py | 17 +++- deepmd/jax/atomic_model/base_atomic_model.py | 3 + deepmd/jax/common.py | 14 +++ deepmd/jax/descriptor/dpa1.py | 3 + deepmd/jax/descriptor/se_e2_a.py | 3 + deepmd/jax/fitting/fitting.py | 3 + deepmd/jax/utils/exclude_mask.py | 5 + deepmd/jax/utils/serialization.py | 97 ++++++++++++++++++++ deepmd/jax/utils/type_embed.py | 3 + pyproject.toml | 6 +- source/tests/consistent/io/test_io.py | 12 ++- 11 files changed, 156 insertions(+), 10 deletions(-) create mode 100644 deepmd/jax/utils/serialization.py diff --git a/deepmd/backend/jax.py b/deepmd/backend/jax.py index db92d6bed1..bb2fba5a7c 100644 --- a/deepmd/backend/jax.py +++ b/deepmd/backend/jax.py @@ -32,14 +32,13 @@ class JAXBackend(Backend): name = "JAX" """The formal name of the backend.""" features: ClassVar[Backend.Feature] = ( - Backend.Feature(0) + Backend.Feature.IO # Backend.Feature.ENTRY_POINT # | Backend.Feature.DEEP_EVAL # | Backend.Feature.NEIGHBOR_STAT - # | Backend.Feature.IO ) """The features of the backend.""" - suffixes: ClassVar[list[str]] = [] + suffixes: ClassVar[list[str]] = [".jax"] """The suffixes of the backend.""" def is_available(self) -> bool: @@ -94,7 +93,11 @@ def serialize_hook(self) -> Callable[[str], dict]: Callable[[str], dict] The serialize hook of the backend. """ - raise NotImplementedError + from deepmd.jax.utils.serialization import ( + serialize_from_file, + ) + + return serialize_from_file @property def deserialize_hook(self) -> Callable[[str, dict], None]: @@ -105,4 +108,8 @@ def deserialize_hook(self) -> Callable[[str, dict], None]: Callable[[str, dict], None] The deserialize hook of the backend. """ - raise NotImplementedError + from deepmd.jax.utils.serialization import ( + deserialize_to_file, + ) + + return deserialize_to_file diff --git a/deepmd/jax/atomic_model/base_atomic_model.py b/deepmd/jax/atomic_model/base_atomic_model.py index 90920879c2..ffd58daf5e 100644 --- a/deepmd/jax/atomic_model/base_atomic_model.py +++ b/deepmd/jax/atomic_model/base_atomic_model.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: LGPL-3.0-or-later from deepmd.jax.common import ( + ArrayAPIVariable, to_jax_array, ) from deepmd.jax.utils.exclude_mask import ( @@ -11,6 +12,8 @@ def base_atomic_model_set_attr(name, value): if name in {"out_bias", "out_std"}: value = to_jax_array(value) + if value is not None: + value = ArrayAPIVariable(value) elif name == "pair_excl" and value is not None: value = PairExcludeMask(value.ntypes, value.exclude_types) elif name == "atom_excl" and value is not None: diff --git a/deepmd/jax/common.py b/deepmd/jax/common.py index 9c144a41d1..f372e97eb5 100644 --- a/deepmd/jax/common.py +++ b/deepmd/jax/common.py @@ -81,3 +81,17 @@ def __setattr__(self, name: str, value: Any) -> None: return super().__setattr__(name, value) return FlaxModule + + +class ArrayAPIVariable(nnx.Variable): + def __array__(self, *args, **kwargs): + return self.value.__array__(*args, **kwargs) + + def __array_namespace__(self, *args, **kwargs): + return self.value.__array_namespace__(*args, **kwargs) + + def __dlpack__(self, *args, **kwargs): + return self.value.__dlpack__(*args, **kwargs) + + def __dlpack_device__(self, *args, **kwargs): + return self.value.__dlpack_device__(*args, **kwargs) diff --git a/deepmd/jax/descriptor/dpa1.py b/deepmd/jax/descriptor/dpa1.py index 0528e4bb93..fef9bd5448 100644 --- a/deepmd/jax/descriptor/dpa1.py +++ b/deepmd/jax/descriptor/dpa1.py @@ -13,6 +13,7 @@ NeighborGatedAttentionLayer as NeighborGatedAttentionLayerDP, ) from deepmd.jax.common import ( + ArrayAPIVariable, flax_module, to_jax_array, ) @@ -65,6 +66,8 @@ class DescrptBlockSeAtten(DescrptBlockSeAttenDP): def __setattr__(self, name: str, value: Any) -> None: if name in {"mean", "stddev"}: value = to_jax_array(value) + if value is not None: + value = ArrayAPIVariable(value) elif name in {"embeddings", "embeddings_strip"}: if value is not None: value = NetworkCollection.deserialize(value.serialize()) diff --git a/deepmd/jax/descriptor/se_e2_a.py b/deepmd/jax/descriptor/se_e2_a.py index d1a6e9a8d9..31c147ad9d 100644 --- a/deepmd/jax/descriptor/se_e2_a.py +++ b/deepmd/jax/descriptor/se_e2_a.py @@ -5,6 +5,7 @@ from deepmd.dpmodel.descriptor.se_e2_a import DescrptSeAArrayAPI as DescrptSeADP from deepmd.jax.common import ( + ArrayAPIVariable, flax_module, to_jax_array, ) @@ -26,6 +27,8 @@ class DescrptSeA(DescrptSeADP): def __setattr__(self, name: str, value: Any) -> None: if name in {"dstd", "davg"}: value = to_jax_array(value) + if value is not None: + value = ArrayAPIVariable(value) elif name in {"embeddings"}: if value is not None: value = NetworkCollection.deserialize(value.serialize()) diff --git a/deepmd/jax/fitting/fitting.py b/deepmd/jax/fitting/fitting.py index f979db4d41..cef1f667b3 100644 --- a/deepmd/jax/fitting/fitting.py +++ b/deepmd/jax/fitting/fitting.py @@ -6,6 +6,7 @@ from deepmd.dpmodel.fitting.dos_fitting import DOSFittingNet as DOSFittingNetDP from deepmd.dpmodel.fitting.ener_fitting import EnergyFittingNet as EnergyFittingNetDP from deepmd.jax.common import ( + ArrayAPIVariable, flax_module, to_jax_array, ) @@ -29,6 +30,8 @@ def setattr_for_general_fitting(name: str, value: Any) -> Any: "aparam_inv_std", }: value = to_jax_array(value) + if value is not None: + value = ArrayAPIVariable(value) elif name == "emask": value = AtomExcludeMask(value.ntypes, value.exclude_types) elif name == "nets": diff --git a/deepmd/jax/utils/exclude_mask.py b/deepmd/jax/utils/exclude_mask.py index a6cf210f94..18d13d9400 100644 --- a/deepmd/jax/utils/exclude_mask.py +++ b/deepmd/jax/utils/exclude_mask.py @@ -6,6 +6,7 @@ from deepmd.dpmodel.utils.exclude_mask import AtomExcludeMask as AtomExcludeMaskDP from deepmd.dpmodel.utils.exclude_mask import PairExcludeMask as PairExcludeMaskDP from deepmd.jax.common import ( + ArrayAPIVariable, flax_module, to_jax_array, ) @@ -16,6 +17,8 @@ class AtomExcludeMask(AtomExcludeMaskDP): def __setattr__(self, name: str, value: Any) -> None: if name in {"type_mask"}: value = to_jax_array(value) + if value is not None: + value = ArrayAPIVariable(value) return super().__setattr__(name, value) @@ -24,4 +27,6 @@ class PairExcludeMask(PairExcludeMaskDP): def __setattr__(self, name: str, value: Any) -> None: if name in {"type_mask"}: value = to_jax_array(value) + if value is not None: + value = ArrayAPIVariable(value) return super().__setattr__(name, value) diff --git a/deepmd/jax/utils/serialization.py b/deepmd/jax/utils/serialization.py new file mode 100644 index 0000000000..43070f8a07 --- /dev/null +++ b/deepmd/jax/utils/serialization.py @@ -0,0 +1,97 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from pathlib import ( + Path, +) + +import orbax.checkpoint as ocp + +from deepmd.jax.env import ( + jax, + nnx, +) +from deepmd.jax.model.model import ( + BaseModel, + get_model, +) + + +def deserialize_to_file(model_file: str, data: dict) -> None: + """Deserialize the dictionary to a model file. + + Parameters + ---------- + model_file : str + The model file to be saved. + data : dict + The dictionary to be deserialized. + """ + if model_file.endswith(".jax"): + model = BaseModel.deserialize(data["model"]) + model_def_script = data["model_def_script"] + _, state = nnx.split(model) + with ocp.Checkpointer( + ocp.CompositeCheckpointHandler("state", "model_def_script") + ) as checkpointer: + checkpointer.save( + Path(model_file).absolute(), + ocp.args.Composite( + state=ocp.args.StandardSave(state.to_pure_dict()), + model_def_script=ocp.args.JsonSave(model_def_script), + ), + ) + else: + raise ValueError("JAX backend only supports converting .jax directory") + + +def serialize_from_file(model_file: str) -> dict: + """Serialize the model file to a dictionary. + + Parameters + ---------- + model_file : str + The model file to be serialized. + + Returns + ------- + dict + The serialized model data. + """ + if model_file.endswith(".jax"): + with ocp.Checkpointer( + ocp.CompositeCheckpointHandler("state", "model_def_script") + ) as checkpointer: + data = checkpointer.restore( + Path(model_file).absolute(), + ocp.args.Composite( + state=ocp.args.StandardRestore(), + model_def_script=ocp.args.JsonRestore(), + ), + ) + state = data.state + + # convert str "1" to int 1 key + def convert_str_to_int_key(item: dict): + for key, value in item.copy().items(): + if isinstance(value, dict): + convert_str_to_int_key(value) + if key.isdigit(): + item[int(key)] = item.pop(key) + + convert_str_to_int_key(state) + + model_def_script = data.model_def_script + abstract_model = get_model(model_def_script) + graphdef, abstract_state = nnx.split(abstract_model) + abstract_state.replace_by_pure_dict(state) + model = nnx.merge(graphdef, abstract_state) + model_dict = model.serialize() + data = { + "backend": "JAX", + "jax_version": jax.__version__, + "model": model_dict, + "model_def_script": model_def_script, + "@variables": {}, + } + return data + else: + raise ValueError("JAX backend only supports converting .jax directory") diff --git a/deepmd/jax/utils/type_embed.py b/deepmd/jax/utils/type_embed.py index 3143460244..30cd9f45a9 100644 --- a/deepmd/jax/utils/type_embed.py +++ b/deepmd/jax/utils/type_embed.py @@ -5,6 +5,7 @@ from deepmd.dpmodel.utils.type_embed import TypeEmbedNet as TypeEmbedNetDP from deepmd.jax.common import ( + ArrayAPIVariable, flax_module, to_jax_array, ) @@ -18,6 +19,8 @@ class TypeEmbedNet(TypeEmbedNetDP): def __setattr__(self, name: str, value: Any) -> None: if name in {"econf_tebd"}: value = to_jax_array(value) + if value is not None: + value = ArrayAPIVariable(value) if name in {"embedding_net"}: value = EmbeddingNet.deserialize(value.serialize()) return super().__setattr__(name, value) diff --git a/pyproject.toml b/pyproject.toml index 4dbff24f13..3bd18d42a6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -137,7 +137,10 @@ cu12 = [ ] jax = [ 'jax>=0.4.33;python_version>="3.10"', - 'flax>=0.8.0;python_version>="3.10"', + 'flax>=0.10.0;python_version>="3.10"', + 'orbax-checkpoint;python_version>="3.10"', + # The pinning of ml_dtypes may conflict with TF + # 'jax-ai-stack;python_version>="3.10"', ] [tool.deepmd_build_backend.scripts] @@ -402,6 +405,7 @@ banned-module-level-imports = [ # Also ignore `E402` in all `__init__.py` files. "deepmd/tf/**" = ["TID253"] "deepmd/pt/**" = ["TID253"] +"deepmd/jax/**" = ["TID253"] "source/tests/tf/**" = ["TID253"] "source/tests/pt/**" = ["TID253"] "source/tests/universal/pt/**" = ["TID253"] diff --git a/source/tests/consistent/io/test_io.py b/source/tests/consistent/io/test_io.py index 71e4002128..feafde234d 100644 --- a/source/tests/consistent/io/test_io.py +++ b/source/tests/consistent/io/test_io.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: LGPL-3.0-or-later import copy +import shutil import unittest from pathlib import ( Path, @@ -60,15 +61,17 @@ def save_data_to_model(self, model_file: str, data: dict) -> None: def tearDown(self): prefix = "test_consistent_io_" + self.__class__.__name__.lower() for ii in Path(".").glob(prefix + ".*"): - if Path(ii).exists(): + if Path(ii).is_file(): Path(ii).unlink() + elif Path(ii).is_dir(): + shutil.rmtree(ii) def test_data_equal(self): prefix = "test_consistent_io_" + self.__class__.__name__.lower() - for backend_name in ("tensorflow", "pytorch", "dpmodel"): + for backend_name in ("tensorflow", "pytorch", "dpmodel", "jax"): with self.subTest(backend_name=backend_name): backend = Backend.get_backend(backend_name)() - if not backend.is_available: + if not backend.is_available(): continue reference_data = copy.deepcopy(self.data) self.save_data_to_model(prefix + backend.suffixes[0], reference_data) @@ -80,6 +83,7 @@ def test_data_equal(self): "backend", "tf_version", "pt_version", + "jax_version", "@variables", # dpmodel only "software", @@ -123,7 +127,7 @@ def test_deep_eval(self): rets = [] for backend_name in ("tensorflow", "pytorch", "dpmodel"): backend = Backend.get_backend(backend_name)() - if not backend.is_available: + if not backend.is_available(): continue reference_data = copy.deepcopy(self.data) self.save_data_to_model(prefix + backend.suffixes[0], reference_data)