From f6b30bf2acaef3952883344be579d6e05417c2f3 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Wed, 16 Oct 2024 17:25:35 -0400 Subject: [PATCH 01/17] feat(jax): energy model (no grad support) Signed-off-by: Jinzhe Zeng --- .../dpmodel/atomic_model/base_atomic_model.py | 16 +++-- .../dpmodel/atomic_model/dp_atomic_model.py | 10 +++- deepmd/dpmodel/model/make_model.py | 20 ++++--- deepmd/dpmodel/model/transform_output.py | 4 +- deepmd/jax/atomic_model/__init__.py | 1 + deepmd/jax/atomic_model/base_atomic_model.py | 20 +++++++ deepmd/jax/atomic_model/dp_atomic_model.py | 26 +++++++++ deepmd/jax/descriptor/__init__.py | 11 ++++ deepmd/jax/descriptor/base_descriptor.py | 9 +++ deepmd/jax/descriptor/dpa1.py | 5 ++ deepmd/jax/descriptor/se_e2_a.py | 5 ++ deepmd/jax/fitting/__init__.py | 9 +++ deepmd/jax/fitting/base_fitting.py | 9 +++ deepmd/jax/fitting/fitting.py | 5 ++ deepmd/jax/model/__init__.py | 1 + deepmd/jax/model/base_model.py | 6 ++ deepmd/jax/model/ener_model.py | 20 +++++++ deepmd/jax/model/model.py | 58 +++++++++++++++++++ source/tests/consistent/model/common.py | 23 ++++++++ source/tests/consistent/model/test_ener.py | 26 +++++++++ 20 files changed, 266 insertions(+), 18 deletions(-) create mode 100644 deepmd/jax/atomic_model/__init__.py create mode 100644 deepmd/jax/atomic_model/base_atomic_model.py create mode 100644 deepmd/jax/atomic_model/dp_atomic_model.py create mode 100644 deepmd/jax/descriptor/base_descriptor.py create mode 100644 deepmd/jax/fitting/base_fitting.py create mode 100644 deepmd/jax/model/__init__.py create mode 100644 deepmd/jax/model/base_model.py create mode 100644 deepmd/jax/model/ener_model.py create mode 100644 deepmd/jax/model/model.py diff --git a/deepmd/dpmodel/atomic_model/base_atomic_model.py b/deepmd/dpmodel/atomic_model/base_atomic_model.py index c29a76b3f1..83503dacdd 100644 --- a/deepmd/dpmodel/atomic_model/base_atomic_model.py +++ b/deepmd/dpmodel/atomic_model/base_atomic_model.py @@ -1,13 +1,16 @@ # SPDX-License-Identifier: LGPL-3.0-or-later import copy +import math from typing import ( Optional, ) +import array_api_compat import numpy as np from deepmd.dpmodel.common import ( NativeOP, + to_numpy_array, ) from deepmd.dpmodel.output_def import ( FittingOutputDef, @@ -172,17 +175,18 @@ def forward_common_atomic( ret_dict["mask"][ff,ii] == 0 indicating the ii-th atom of the ff-th frame is virtual. 
""" + xp = array_api_compat.array_namespace(extended_coord, extended_atype, nlist) _, nloc, _ = nlist.shape atype = extended_atype[:, :nloc] if self.pair_excl is not None: pair_mask = self.pair_excl.build_type_exclude_mask(nlist, extended_atype) # exclude neighbors in the nlist - nlist = np.where(pair_mask == 1, nlist, -1) + nlist = xp.where(pair_mask == 1, nlist, -1) ext_atom_mask = self.make_atom_mask(extended_atype) ret_dict = self.forward_atomic( extended_coord, - np.where(ext_atom_mask, extended_atype, 0), + xp.where(ext_atom_mask, extended_atype, 0), nlist, mapping=mapping, fparam=fparam, @@ -191,13 +195,13 @@ def forward_common_atomic( ret_dict = self.apply_out_stat(ret_dict, atype) # nf x nloc - atom_mask = ext_atom_mask[:, :nloc].astype(np.int32) + atom_mask = ext_atom_mask[:, :nloc].astype(xp.int32) if self.atom_excl is not None: atom_mask *= self.atom_excl.build_type_exclude_mask(atype) for kk in ret_dict.keys(): out_shape = ret_dict[kk].shape - out_shape2 = np.prod(out_shape[2:]) + out_shape2 = math.prod(out_shape[2:]) ret_dict[kk] = ( ret_dict[kk].reshape([out_shape[0], out_shape[1], out_shape2]) * atom_mask[:, :, None] @@ -232,8 +236,8 @@ def serialize(self) -> dict: "rcond": self.rcond, "preset_out_bias": self.preset_out_bias, "@variables": { - "out_bias": self.out_bias, - "out_std": self.out_std, + "out_bias": to_numpy_array(self.out_bias), + "out_std": to_numpy_array(self.out_std), }, } diff --git a/deepmd/dpmodel/atomic_model/dp_atomic_model.py b/deepmd/dpmodel/atomic_model/dp_atomic_model.py index 7e576eb484..fe049021fe 100644 --- a/deepmd/dpmodel/atomic_model/dp_atomic_model.py +++ b/deepmd/dpmodel/atomic_model/dp_atomic_model.py @@ -169,14 +169,20 @@ def serialize(self) -> dict: ) return dd + # for subclass overriden + base_descriptor_cls = BaseDescriptor + """The base descriptor class.""" + base_fitting_cls = BaseFitting + """The base fitting class.""" + @classmethod def deserialize(cls, data) -> "DPAtomicModel": data = copy.deepcopy(data) check_version_compatibility(data.pop("@version", 1), 2, 2) data.pop("@class") data.pop("type") - descriptor_obj = BaseDescriptor.deserialize(data.pop("descriptor")) - fitting_obj = BaseFitting.deserialize(data.pop("fitting")) + descriptor_obj = cls.base_descriptor_cls.deserialize(data.pop("descriptor")) + fitting_obj = cls.base_fitting_cls.deserialize(data.pop("fitting")) data["descriptor"] = descriptor_obj data["fitting"] = fitting_obj obj = super().deserialize(data) diff --git a/deepmd/dpmodel/model/make_model.py b/deepmd/dpmodel/model/make_model.py index 8cdb7e1f25..8894b3efb9 100644 --- a/deepmd/dpmodel/model/make_model.py +++ b/deepmd/dpmodel/model/make_model.py @@ -3,6 +3,7 @@ Optional, ) +import array_api_compat import numpy as np from deepmd.dpmodel.atomic_model.base_atomic_model import ( @@ -366,6 +367,7 @@ def _format_nlist( nnei: int, extra_nlist_sort: bool = False, ): + xp = array_api_compat.array_namespace(extended_coord, nlist) n_nf, n_nloc, n_nnei = nlist.shape extended_coord = extended_coord.reshape([n_nf, -1, 3]) nall = extended_coord.shape[1] @@ -373,10 +375,10 @@ def _format_nlist( if n_nnei < nnei: # make a copy before revise - ret = np.concatenate( + ret = xp.concat( [ nlist, - -1 * np.ones([n_nf, n_nloc, nnei - n_nnei], dtype=nlist.dtype), + -1 * xp.ones([n_nf, n_nloc, nnei - n_nnei], dtype=nlist.dtype), ], axis=-1, ) @@ -385,16 +387,16 @@ def _format_nlist( n_nf, n_nloc, n_nnei = nlist.shape # make a copy before revise m_real_nei = nlist >= 0 - ret = np.where(m_real_nei, nlist, 0) + ret = xp.where(m_real_nei, 
nlist, 0) coord0 = extended_coord[:, :n_nloc, :] index = ret.reshape(n_nf, n_nloc * n_nnei, 1).repeat(3, axis=2) - coord1 = np.take_along_axis(extended_coord, index, axis=1) + coord1 = xp.take_along_axis(extended_coord, index, axis=1) coord1 = coord1.reshape(n_nf, n_nloc, n_nnei, 3) - rr = np.linalg.norm(coord0[:, :, None, :] - coord1, axis=-1) - rr = np.where(m_real_nei, rr, float("inf")) - rr, ret_mapping = np.sort(rr, axis=-1), np.argsort(rr, axis=-1) - ret = np.take_along_axis(ret, ret_mapping, axis=2) - ret = np.where(rr > rcut, -1, ret) + rr = xp.linalg.norm(coord0[:, :, None, :] - coord1, axis=-1) + rr = xp.where(m_real_nei, rr, float("inf")) + rr, ret_mapping = xp.sort(rr, axis=-1), xp.argsort(rr, axis=-1) + ret = xp.take_along_axis(ret, ret_mapping, axis=2) + ret = xp.where(rr > rcut, -1, ret) ret = ret[..., :nnei] # not extra_nlist_sort and n_nnei <= nnei: elif n_nnei == nnei: diff --git a/deepmd/dpmodel/model/transform_output.py b/deepmd/dpmodel/model/transform_output.py index 43c275b1be..928c33f3bd 100644 --- a/deepmd/dpmodel/model/transform_output.py +++ b/deepmd/dpmodel/model/transform_output.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +import array_api_compat import numpy as np from deepmd.dpmodel.common import ( @@ -23,6 +24,7 @@ def fit_output_to_model_output( the model output. """ + xp = array_api_compat.get_namespace(coord_ext) model_ret = dict(fit_ret.items()) for kk, vv in fit_ret.items(): vdef = fit_output_def[kk] @@ -31,7 +33,7 @@ def fit_output_to_model_output( if vdef.reducible: kk_redu = get_reduce_name(kk) # cast to energy prec brefore reduction - model_ret[kk_redu] = np.sum( + model_ret[kk_redu] = xp.sum( vv.astype(GLOBAL_ENER_FLOAT_PRECISION), axis=atom_axis ) if vdef.r_differentiable: diff --git a/deepmd/jax/atomic_model/__init__.py b/deepmd/jax/atomic_model/__init__.py new file mode 100644 index 0000000000..6ceb116d85 --- /dev/null +++ b/deepmd/jax/atomic_model/__init__.py @@ -0,0 +1 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later diff --git a/deepmd/jax/atomic_model/base_atomic_model.py b/deepmd/jax/atomic_model/base_atomic_model.py new file mode 100644 index 0000000000..e4a349a78b --- /dev/null +++ b/deepmd/jax/atomic_model/base_atomic_model.py @@ -0,0 +1,20 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from deepmd.jax.common import ( + to_jax_array, +) +from deepmd.jax.utils.exclude_mask import ( + AtomExcludeMask, + PairExcludeMask, +) + + +def base_atomic_model_set_attr(name, value): + if name in {"out_bias", "out_std"}: + value = to_jax_array(value) + elif name == "pair_excl": + if value is not None: + value = PairExcludeMask(value.ntypes, value.exclude_types) + elif name == "atom_excl": + if value is not None: + value = AtomExcludeMask(value.ntypes, value.exclude_types) + return value diff --git a/deepmd/jax/atomic_model/dp_atomic_model.py b/deepmd/jax/atomic_model/dp_atomic_model.py new file mode 100644 index 0000000000..14b6ee2e7b --- /dev/null +++ b/deepmd/jax/atomic_model/dp_atomic_model.py @@ -0,0 +1,26 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + Any, +) + +from deepmd.dpmodel.atomic_model.dp_atomic_model import DPAtomicModel as DPAtomicModelDP +from deepmd.jax.atomic_model.base_atomic_model import ( + base_atomic_model_set_attr, +) +from deepmd.jax.descriptor.base_descriptor import ( + BaseDescriptor, +) +from deepmd.jax.fitting.base_fitting import ( + BaseFitting, +) + + +class DPAtomicModel(DPAtomicModelDP): + base_descriptor_cls = BaseDescriptor + """The base descriptor class.""" + 
base_fitting_cls = BaseFitting + """The base fitting class.""" + + def __setattr__(self, name: str, value: Any) -> None: + value = base_atomic_model_set_attr(name, value) + return super().__setattr__(name, value) diff --git a/deepmd/jax/descriptor/__init__.py b/deepmd/jax/descriptor/__init__.py index 6ceb116d85..ed59493268 100644 --- a/deepmd/jax/descriptor/__init__.py +++ b/deepmd/jax/descriptor/__init__.py @@ -1 +1,12 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +from deepmd.jax.descriptor.dpa1 import ( + DescrptDPA1, +) +from deepmd.jax.descriptor.se_e2_a import ( + DescrptSeA, +) + +__all__ = [ + "DescrptSeA", + "DescrptDPA1", +] diff --git a/deepmd/jax/descriptor/base_descriptor.py b/deepmd/jax/descriptor/base_descriptor.py new file mode 100644 index 0000000000..7dec3cd6d4 --- /dev/null +++ b/deepmd/jax/descriptor/base_descriptor.py @@ -0,0 +1,9 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from deepmd.dpmodel.descriptor.make_base_descriptor import ( + make_base_descriptor, +) +from deepmd.jax.env import ( + jnp, +) + +BaseDescriptor = make_base_descriptor(jnp.ndarray) diff --git a/deepmd/jax/descriptor/dpa1.py b/deepmd/jax/descriptor/dpa1.py index a9b0404970..0528e4bb93 100644 --- a/deepmd/jax/descriptor/dpa1.py +++ b/deepmd/jax/descriptor/dpa1.py @@ -16,6 +16,9 @@ flax_module, to_jax_array, ) +from deepmd.jax.descriptor.base_descriptor import ( + BaseDescriptor, +) from deepmd.jax.utils.exclude_mask import ( PairExcludeMask, ) @@ -76,6 +79,8 @@ def __setattr__(self, name: str, value: Any) -> None: return super().__setattr__(name, value) +@BaseDescriptor.register("dpa1") +@BaseDescriptor.register("se_atten") @flax_module class DescrptDPA1(DescrptDPA1DP): def __setattr__(self, name: str, value: Any) -> None: diff --git a/deepmd/jax/descriptor/se_e2_a.py b/deepmd/jax/descriptor/se_e2_a.py index a60a4e9af1..d1a6e9a8d9 100644 --- a/deepmd/jax/descriptor/se_e2_a.py +++ b/deepmd/jax/descriptor/se_e2_a.py @@ -8,6 +8,9 @@ flax_module, to_jax_array, ) +from deepmd.jax.descriptor.base_descriptor import ( + BaseDescriptor, +) from deepmd.jax.utils.exclude_mask import ( PairExcludeMask, ) @@ -16,6 +19,8 @@ ) +@BaseDescriptor.register("se_e2_a") +@BaseDescriptor.register("se_a") @flax_module class DescrptSeA(DescrptSeADP): def __setattr__(self, name: str, value: Any) -> None: diff --git a/deepmd/jax/fitting/__init__.py b/deepmd/jax/fitting/__init__.py index 6ceb116d85..e72314dcab 100644 --- a/deepmd/jax/fitting/__init__.py +++ b/deepmd/jax/fitting/__init__.py @@ -1 +1,10 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +from deepmd.jax.fitting.fitting import ( + DOSFittingNet, + EnergyFittingNet, +) + +__all__ = [ + "EnergyFittingNet", + "DOSFittingNet", +] diff --git a/deepmd/jax/fitting/base_fitting.py b/deepmd/jax/fitting/base_fitting.py new file mode 100644 index 0000000000..fd9f3a416d --- /dev/null +++ b/deepmd/jax/fitting/base_fitting.py @@ -0,0 +1,9 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from deepmd.dpmodel.fitting.make_base_fitting import ( + make_base_fitting, +) +from deepmd.jax.env import ( + jnp, +) + +BaseFitting = make_base_fitting(jnp.ndarray) diff --git a/deepmd/jax/fitting/fitting.py b/deepmd/jax/fitting/fitting.py index 284213c70a..f979db4d41 100644 --- a/deepmd/jax/fitting/fitting.py +++ b/deepmd/jax/fitting/fitting.py @@ -9,6 +9,9 @@ flax_module, to_jax_array, ) +from deepmd.jax.fitting.base_fitting import ( + BaseFitting, +) from deepmd.jax.utils.exclude_mask import ( AtomExcludeMask, ) @@ -33,6 +36,7 @@ def setattr_for_general_fitting(name: str, value: Any) 
-> Any: return value +@BaseFitting.register("ener") @flax_module class EnergyFittingNet(EnergyFittingNetDP): def __setattr__(self, name: str, value: Any) -> None: @@ -40,6 +44,7 @@ def __setattr__(self, name: str, value: Any) -> None: return super().__setattr__(name, value) +@BaseFitting.register("dos") @flax_module class DOSFittingNet(DOSFittingNetDP): def __setattr__(self, name: str, value: Any) -> None: diff --git a/deepmd/jax/model/__init__.py b/deepmd/jax/model/__init__.py new file mode 100644 index 0000000000..6ceb116d85 --- /dev/null +++ b/deepmd/jax/model/__init__.py @@ -0,0 +1 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later diff --git a/deepmd/jax/model/base_model.py b/deepmd/jax/model/base_model.py new file mode 100644 index 0000000000..fee4855da3 --- /dev/null +++ b/deepmd/jax/model/base_model.py @@ -0,0 +1,6 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from deepmd.dpmodel.model.base_model import ( + make_base_model, +) + +BaseModel = make_base_model() diff --git a/deepmd/jax/model/ener_model.py b/deepmd/jax/model/ener_model.py new file mode 100644 index 0000000000..3cf60fdbde --- /dev/null +++ b/deepmd/jax/model/ener_model.py @@ -0,0 +1,20 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + Any, +) + +from deepmd.dpmodel.model import EnergyModel as EnergyModelDP +from deepmd.jax.atomic_model.dp_atomic_model import ( + DPAtomicModel, +) +from deepmd.jax.model.base_model import ( + BaseModel, +) + + +@BaseModel.register("ener") +class EnergyModel(EnergyModelDP): + def __setattr__(self, name: str, value: Any) -> None: + if name == "atomic_model": + value = DPAtomicModel.deserialize(value.serialize()) + return super().__setattr__(name, value) diff --git a/deepmd/jax/model/model.py b/deepmd/jax/model/model.py new file mode 100644 index 0000000000..3b98ebe25f --- /dev/null +++ b/deepmd/jax/model/model.py @@ -0,0 +1,58 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from deepmd.jax.descriptor.base_descriptor import ( + BaseDescriptor, +) +from deepmd.jax.fitting.base_fitting import ( + BaseFitting, +) +from deepmd.jax.model.base_model import ( + BaseModel, +) + + +def get_standard_model(data: dict): + """Get a Model from a dictionary. + + Parameters + ---------- + data : dict + The data to construct the model. + """ + descriptor_type = data["descriptor"].pop("type") + data["descriptor"]["type_map"] = data["type_map"] + fitting_type = data["fitting_net"].pop("type") + data["fitting_net"]["type_map"] = data["type_map"] + descriptor = BaseDescriptor.get_class_by_type(descriptor_type)( + **data["descriptor"], + ) + fitting = BaseFitting.get_class_by_type(fitting_type)( + ntypes=descriptor.get_ntypes(), + dim_descrpt=descriptor.get_dim_out(), + mixed_types=descriptor.mixed_types(), + **data["fitting_net"], + ) + return BaseModel.get_class_by_type(fitting_type)( + descriptor=descriptor, + fitting=fitting, + type_map=data["type_map"], + atom_exclude_types=data.get("atom_exclude_types", []), + pair_exclude_types=data.get("pair_exclude_types", []), + ) + + +def get_model(data: dict): + """Get a model from a dictionary. + + Parameters + ---------- + data : dict + The data to construct the model. 
+ """ + model_type = data.get("type", "standard") + if model_type == "standard": + if "spin" in data: + raise NotImplementedError("Spin model is not implemented yet.") + else: + return get_standard_model(data) + else: + return BaseModel.get_class_by_type(model_type).get_model(data) diff --git a/source/tests/consistent/model/common.py b/source/tests/consistent/model/common.py index 294edec1d6..4112e09cff 100644 --- a/source/tests/consistent/model/common.py +++ b/source/tests/consistent/model/common.py @@ -6,8 +6,12 @@ from deepmd.common import ( make_default_mesh, ) +from deepmd.dpmodel.common import ( + to_numpy_array, +) from ..common import ( + INSTALLED_JAX, INSTALLED_PT, INSTALLED_TF, ) @@ -20,6 +24,11 @@ GLOBAL_TF_FLOAT_PRECISION, tf, ) +if INSTALLED_JAX: + from deepmd.jax.common import to_jax_array as numpy_to_jax + from deepmd.jax.env import ( + jnp, + ) class ModelTest: @@ -62,3 +71,17 @@ def eval_pt_model(self, pt_obj: Any, natoms, coords, atype, box) -> Any: box=numpy_to_torch(box), ).items() } + + def eval_jax_model(self, jax_obj: Any, natoms, coords, atype, box) -> Any: + def assert_jax_array(arr): + assert isinstance(arr, jnp.ndarray) or arr is None + return arr + + return { + kk: to_numpy_array(assert_jax_array(vv)) + for kk, vv in jax_obj( + numpy_to_jax(coords), + numpy_to_jax(atype), + box=numpy_to_jax(box), + ).items() + } diff --git a/source/tests/consistent/model/test_ener.py b/source/tests/consistent/model/test_ener.py index 692e1287dc..78a2aac703 100644 --- a/source/tests/consistent/model/test_ener.py +++ b/source/tests/consistent/model/test_ener.py @@ -13,6 +13,7 @@ ) from ..common import ( + INSTALLED_JAX, INSTALLED_PT, INSTALLED_TF, CommonTest, @@ -36,6 +37,12 @@ model_args, ) +if INSTALLED_JAX: + from deepmd.jax.model.ener_model import EnergyModel as EnergyModelJAX + from deepmd.jax.model.model import get_model as get_model_jax +else: + EnergyModelJAX = None + @parameterized( ( @@ -84,14 +91,20 @@ def data(self) -> dict: tf_class = EnergyModelTF dp_class = EnergyModelDP pt_class = EnergyModelPT + jax_class = EnergyModelJAX args = model_args() + @property def skip_tf(self): return ( self.data["pair_exclude_types"] != [] or self.data["atom_exclude_types"] != [] ) + @property + def skip_jax(self): + return not INSTALLED_JAX + def pass_data_to_cls(self, cls, data) -> Any: """Pass data to the class.""" data = data.copy() @@ -99,6 +112,8 @@ def pass_data_to_cls(self, cls, data) -> Any: return get_model_dp(data) elif cls is EnergyModelPT: return get_model_pt(data) + elif cls is EnergyModelJAX: + return get_model_jax(data) return cls(**data, **self.addtional_data) def setUp(self): @@ -168,6 +183,15 @@ def eval_pt(self, pt_obj: Any) -> Any: self.box, ) + def eval_jax(self, jax_obj: Any) -> Any: + return self.eval_jax_model( + jax_obj, + self.natoms, + self.coords, + self.atype, + self.box, + ) + def extract_ret(self, ret: Any, backend) -> tuple[np.ndarray, ...]: # shape not matched. ravel... 
if backend is self.RefBackend.DP: @@ -176,4 +200,6 @@ def extract_ret(self, ret: Any, backend) -> tuple[np.ndarray, ...]: return (ret["energy"].ravel(), ret["atom_energy"].ravel()) elif backend is self.RefBackend.TF: return (ret[0].ravel(), ret[1].ravel()) + elif backend is self.RefBackend.JAX: + return (ret["energy_redu"].ravel(), ret["energy"].ravel()) raise ValueError(f"Unknown backend: {backend}") From 3e13a338f4f3b14ce665fa1da9a07fa43a085a31 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Wed, 16 Oct 2024 17:52:00 -0400 Subject: [PATCH 02/17] address comments Signed-off-by: Jinzhe Zeng --- deepmd/jax/atomic_model/base_atomic_model.py | 10 ++++------ deepmd/jax/model/model.py | 5 +++++ 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/deepmd/jax/atomic_model/base_atomic_model.py b/deepmd/jax/atomic_model/base_atomic_model.py index e4a349a78b..90920879c2 100644 --- a/deepmd/jax/atomic_model/base_atomic_model.py +++ b/deepmd/jax/atomic_model/base_atomic_model.py @@ -11,10 +11,8 @@ def base_atomic_model_set_attr(name, value): if name in {"out_bias", "out_std"}: value = to_jax_array(value) - elif name == "pair_excl": - if value is not None: - value = PairExcludeMask(value.ntypes, value.exclude_types) - elif name == "atom_excl": - if value is not None: - value = AtomExcludeMask(value.ntypes, value.exclude_types) + elif name == "pair_excl" and value is not None: + value = PairExcludeMask(value.ntypes, value.exclude_types) + elif name == "atom_excl" and value is not None: + value = AtomExcludeMask(value.ntypes, value.exclude_types) return value diff --git a/deepmd/jax/model/model.py b/deepmd/jax/model/model.py index 3b98ebe25f..7fa3efda6e 100644 --- a/deepmd/jax/model/model.py +++ b/deepmd/jax/model/model.py @@ -1,4 +1,8 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +from copy import ( + deepcopy, +) + from deepmd.jax.descriptor.base_descriptor import ( BaseDescriptor, ) @@ -18,6 +22,7 @@ def get_standard_model(data: dict): data : dict The data to construct the model. 
""" + data = deepcopy(data) descriptor_type = data["descriptor"].pop("type") data["descriptor"]["type_map"] = data["type_map"] fitting_type = data["fitting_net"].pop("type") From fa9a0002ee7376cf41c36898058cb3319c56b03e Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Mon, 21 Oct 2024 17:10:14 -0400 Subject: [PATCH 03/17] export EnergyModel Signed-off-by: Jinzhe Zeng --- deepmd/jax/model/__init__.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/deepmd/jax/model/__init__.py b/deepmd/jax/model/__init__.py index 6ceb116d85..05a60c4ffe 100644 --- a/deepmd/jax/model/__init__.py +++ b/deepmd/jax/model/__init__.py @@ -1 +1,6 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +from .ener_model import ( + EnergyModel, +) + +__all__ = ["EnergyModel"] From d3b5ce3e799921019f378e9c0c30063f8444117a Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Mon, 21 Oct 2024 17:21:25 -0400 Subject: [PATCH 04/17] jax2tf Signed-off-by: Jinzhe Zeng --- deepmd/backend/jax.py | 11 ++++--- deepmd/jax/env.py | 4 +++ deepmd/jax/utils/serialization.py | 55 +++++++++++++++++++++++++++++++ pyproject.toml | 1 + 4 files changed, 67 insertions(+), 4 deletions(-) create mode 100644 deepmd/jax/utils/serialization.py diff --git a/deepmd/backend/jax.py b/deepmd/backend/jax.py index db92d6bed1..07bacdb937 100644 --- a/deepmd/backend/jax.py +++ b/deepmd/backend/jax.py @@ -32,14 +32,13 @@ class JAXBackend(Backend): name = "JAX" """The formal name of the backend.""" features: ClassVar[Backend.Feature] = ( - Backend.Feature(0) + Backend.Feature.IO # Backend.Feature.ENTRY_POINT # | Backend.Feature.DEEP_EVAL # | Backend.Feature.NEIGHBOR_STAT - # | Backend.Feature.IO ) """The features of the backend.""" - suffixes: ClassVar[list[str]] = [] + suffixes: ClassVar[list[str]] = [".saved_model"] """The suffixes of the backend.""" def is_available(self) -> bool: @@ -105,4 +104,8 @@ def deserialize_hook(self) -> Callable[[str, dict], None]: Callable[[str, dict], None] The deserialize hook of the backend. """ - raise NotImplementedError + from deepmd.jax.utils.serialization import ( + deserialize_to_file, + ) + + return deserialize_to_file diff --git a/deepmd/jax/env.py b/deepmd/jax/env.py index 5a5a7f6bf0..2d1c1454ed 100644 --- a/deepmd/jax/env.py +++ b/deepmd/jax/env.py @@ -8,6 +8,9 @@ from flax import ( nnx, ) +from jax.experimental import ( + jax2tf, +) jax.config.update("jax_enable_x64", True) @@ -15,4 +18,5 @@ "jax", "jnp", "nnx", + "jax2tf", ] diff --git a/deepmd/jax/utils/serialization.py b/deepmd/jax/utils/serialization.py new file mode 100644 index 0000000000..cdd1063fb9 --- /dev/null +++ b/deepmd/jax/utils/serialization.py @@ -0,0 +1,55 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import tensorflow as tf + +from deepmd.jax.env import ( + jax2tf, +) +from deepmd.jax.model.model import ( + BaseModel, +) + + +def deserialize_to_file(model_file: str, data: dict) -> None: + """Deserialize the dictionary to a model file. + + Parameters + ---------- + model_file : str + The model file to be saved. + data : dict + The dictionary to be deserialized. 
+ """ + if model_file.endswith(".saved_model"): + model = BaseModel.deserialize(data["model"]) + model_def_script = data.get("model_def_script", "{}") + my_model = tf.Module() + my_model.f = tf.function( + jax2tf.convert( + model, + polymorphic_shapes=[ + "(b, n, 3)", + "(b, n)", + "(b, 3, 3)", + "(b, f)", + "(b, a)", + "()", + ], + ), + autograph=False, + input_signature=[ + tf.TensorSpec([None, None, 3], tf.float64), + tf.TensorSpec([None, None], tf.int64), + tf.TensorSpec([None, 3, 3], tf.float64), + tf.TensorSpec([None, None], tf.float64), + tf.TensorSpec([None, None], tf.float64), + tf.TensorSpec([], tf.bool), + ], + ) + my_model.model_def_script = model_def_script + tf.saved_model.save( + my_model, + model_file, + options=tf.saved_model.SaveOptions(experimental_custom_gradients=True), + ) + else: + raise ValueError("JAX backend only supports converting .pth file") diff --git a/pyproject.toml b/pyproject.toml index b13dceeb07..4a99cfaf8b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -402,6 +402,7 @@ banned-module-level-imports = [ # Also ignore `E402` in all `__init__.py` files. "deepmd/tf/**" = ["TID253"] "deepmd/pt/**" = ["TID253"] +"deepmd/jax/**" = ["TID253"] "source/tests/tf/**" = ["TID253"] "source/tests/pt/**" = ["TID253"] "source/tests/universal/pt/**" = ["TID253"] From 91ec0847164b299c891130da06a2136a05da736d Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Mon, 21 Oct 2024 17:42:06 -0400 Subject: [PATCH 05/17] apply flax_module to energy model Signed-off-by: Jinzhe Zeng --- deepmd/jax/model/ener_model.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/deepmd/jax/model/ener_model.py b/deepmd/jax/model/ener_model.py index 3cf60fdbde..79c5a29e88 100644 --- a/deepmd/jax/model/ener_model.py +++ b/deepmd/jax/model/ener_model.py @@ -7,12 +7,16 @@ from deepmd.jax.atomic_model.dp_atomic_model import ( DPAtomicModel, ) +from deepmd.jax.common import ( + flax_module, +) from deepmd.jax.model.base_model import ( BaseModel, ) @BaseModel.register("ener") +@flax_module class EnergyModel(EnergyModelDP): def __setattr__(self, name: str, value: Any) -> None: if name == "atomic_model": From 14fab30afb055ed2b5b7bf5c9247d0dfaaaf0e7b Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Mon, 21 Oct 2024 17:43:04 -0400 Subject: [PATCH 06/17] checkpoint Signed-off-by: Jinzhe Zeng --- deepmd/backend/jax.py | 2 +- deepmd/jax/utils/serialization.py | 5 +++++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/deepmd/backend/jax.py b/deepmd/backend/jax.py index 07bacdb937..6ffea0aacd 100644 --- a/deepmd/backend/jax.py +++ b/deepmd/backend/jax.py @@ -38,7 +38,7 @@ class JAXBackend(Backend): # | Backend.Feature.NEIGHBOR_STAT ) """The features of the backend.""" - suffixes: ClassVar[list[str]] = [".saved_model"] + suffixes: ClassVar[list[str]] = [".saved_model", ".jax"] """The suffixes of the backend.""" def is_available(self) -> bool: diff --git a/deepmd/jax/utils/serialization.py b/deepmd/jax/utils/serialization.py index cdd1063fb9..b1844f09bb 100644 --- a/deepmd/jax/utils/serialization.py +++ b/deepmd/jax/utils/serialization.py @@ -3,6 +3,7 @@ from deepmd.jax.env import ( jax2tf, + nnx, ) from deepmd.jax.model.model import ( BaseModel, @@ -51,5 +52,9 @@ def deserialize_to_file(model_file: str, data: dict) -> None: model_file, options=tf.saved_model.SaveOptions(experimental_custom_gradients=True), ) + elif model_file.endswith(".jax"): + model = BaseModel.deserialize(data["model"]) + state = nnx.state(model) + nnx.display(state) else: raise ValueError("JAX backend only supports converting 
.pth file") From 390ace5d281ac9a1aed74eef800c6713a728709c Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Mon, 21 Oct 2024 17:52:22 -0400 Subject: [PATCH 07/17] fix atomic model Signed-off-by: Jinzhe Zeng --- deepmd/dpmodel/model/make_model.py | 15 +++++---------- 1 file changed, 5 insertions(+), 10 deletions(-) diff --git a/deepmd/dpmodel/model/make_model.py b/deepmd/dpmodel/model/make_model.py index 8894b3efb9..dc90f10da7 100644 --- a/deepmd/dpmodel/model/make_model.py +++ b/deepmd/dpmodel/model/make_model.py @@ -76,7 +76,8 @@ def __init__( else: self.atomic_model: T_AtomicModel = T_AtomicModel(*args, **kwargs) self.precision_dict = PRECISION_DICT - self.reverse_precision_dict = RESERVED_PRECISON_DICT + # not supported by flax + # self.reverse_precision_dict = RESERVED_PRECISON_DICT self.global_np_float_precision = GLOBAL_NP_FLOAT_PRECISION self.global_ener_float_precision = GLOBAL_ENER_FLOAT_PRECISION @@ -254,9 +255,7 @@ def input_type_cast( str, ]: """Cast the input data to global float type.""" - input_prec = self.reverse_precision_dict[ - self.precision_dict[coord.dtype.name] - ] + input_prec = RESERVED_PRECISON_DICT[self.precision_dict[coord.dtype.name]] ### ### type checking would not pass jit, convert to coord prec anyway ### @@ -265,10 +264,7 @@ def input_type_cast( for vv in [box, fparam, aparam] ] box, fparam, aparam = _lst - if ( - input_prec - == self.reverse_precision_dict[self.global_np_float_precision] - ): + if input_prec == RESERVED_PRECISON_DICT[self.global_np_float_precision]: return coord, box, fparam, aparam, input_prec else: pp = self.global_np_float_precision @@ -287,8 +283,7 @@ def output_type_cast( ) -> dict[str, np.ndarray]: """Convert the model output to the input prec.""" do_cast = ( - input_prec - != self.reverse_precision_dict[self.global_np_float_precision] + input_prec != RESERVED_PRECISON_DICT[self.global_np_float_precision] ) pp = self.precision_dict[input_prec] odef = self.model_output_def() From a4cd6277d81b5ffdc8508934499292b10808c4ed Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Mon, 21 Oct 2024 17:52:35 -0400 Subject: [PATCH 08/17] apply flax_module Signed-off-by: Jinzhe Zeng --- deepmd/jax/atomic_model/dp_atomic_model.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/deepmd/jax/atomic_model/dp_atomic_model.py b/deepmd/jax/atomic_model/dp_atomic_model.py index 14b6ee2e7b..077209e29a 100644 --- a/deepmd/jax/atomic_model/dp_atomic_model.py +++ b/deepmd/jax/atomic_model/dp_atomic_model.py @@ -7,6 +7,9 @@ from deepmd.jax.atomic_model.base_atomic_model import ( base_atomic_model_set_attr, ) +from deepmd.jax.common import ( + flax_module, +) from deepmd.jax.descriptor.base_descriptor import ( BaseDescriptor, ) @@ -15,6 +18,7 @@ ) +@flax_module class DPAtomicModel(DPAtomicModelDP): base_descriptor_cls = BaseDescriptor """The base descriptor class.""" From 71a4b55138c50da3d5848f54eddf33a0f8ba2b32 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Mon, 21 Oct 2024 17:53:08 -0400 Subject: [PATCH 09/17] checkpoint Signed-off-by: Jinzhe Zeng --- deepmd/jax/utils/serialization.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/deepmd/jax/utils/serialization.py b/deepmd/jax/utils/serialization.py index b1844f09bb..ca915d61e3 100644 --- a/deepmd/jax/utils/serialization.py +++ b/deepmd/jax/utils/serialization.py @@ -8,6 +8,9 @@ from deepmd.jax.model.model import ( BaseModel, ) +from deepmd.jax.utils.network import ( + ArrayAPIParam, +) def deserialize_to_file(model_file: str, data: dict) -> None: @@ -54,7 +57,7 @@ def 
deserialize_to_file(model_file: str, data: dict) -> None: ) elif model_file.endswith(".jax"): model = BaseModel.deserialize(data["model"]) - state = nnx.state(model) + state = nnx.state(model, ArrayAPIParam) nnx.display(state) else: raise ValueError("JAX backend only supports converting .pth file") From 5024f70fdc94d7bd72fe233ca3dbc8e741e15315 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Mon, 21 Oct 2024 18:35:03 -0400 Subject: [PATCH 10/17] done Signed-off-by: Jinzhe Zeng --- deepmd/backend/jax.py | 8 ++- deepmd/jax/utils/serialization.py | 97 +++++++++++++++++---------- source/tests/consistent/io/test_io.py | 8 ++- 3 files changed, 72 insertions(+), 41 deletions(-) diff --git a/deepmd/backend/jax.py b/deepmd/backend/jax.py index 6ffea0aacd..bb2fba5a7c 100644 --- a/deepmd/backend/jax.py +++ b/deepmd/backend/jax.py @@ -38,7 +38,7 @@ class JAXBackend(Backend): # | Backend.Feature.NEIGHBOR_STAT ) """The features of the backend.""" - suffixes: ClassVar[list[str]] = [".saved_model", ".jax"] + suffixes: ClassVar[list[str]] = [".jax"] """The suffixes of the backend.""" def is_available(self) -> bool: @@ -93,7 +93,11 @@ def serialize_hook(self) -> Callable[[str], dict]: Callable[[str], dict] The serialize hook of the backend. """ - raise NotImplementedError + from deepmd.jax.utils.serialization import ( + serialize_from_file, + ) + + return serialize_from_file @property def deserialize_hook(self) -> Callable[[str, dict], None]: diff --git a/deepmd/jax/utils/serialization.py b/deepmd/jax/utils/serialization.py index ca915d61e3..aa41e35f69 100644 --- a/deepmd/jax/utils/serialization.py +++ b/deepmd/jax/utils/serialization.py @@ -1,12 +1,17 @@ # SPDX-License-Identifier: LGPL-3.0-or-later -import tensorflow as tf +from pathlib import ( + Path, +) + +import orbax.checkpoint as ocp from deepmd.jax.env import ( - jax2tf, + jax, nnx, ) from deepmd.jax.model.model import ( BaseModel, + get_model, ) from deepmd.jax.utils.network import ( ArrayAPIParam, @@ -23,41 +28,59 @@ def deserialize_to_file(model_file: str, data: dict) -> None: data : dict The dictionary to be deserialized. 
""" - if model_file.endswith(".saved_model"): - model = BaseModel.deserialize(data["model"]) - model_def_script = data.get("model_def_script", "{}") - my_model = tf.Module() - my_model.f = tf.function( - jax2tf.convert( - model, - polymorphic_shapes=[ - "(b, n, 3)", - "(b, n)", - "(b, 3, 3)", - "(b, f)", - "(b, a)", - "()", - ], - ), - autograph=False, - input_signature=[ - tf.TensorSpec([None, None, 3], tf.float64), - tf.TensorSpec([None, None], tf.int64), - tf.TensorSpec([None, 3, 3], tf.float64), - tf.TensorSpec([None, None], tf.float64), - tf.TensorSpec([None, None], tf.float64), - tf.TensorSpec([], tf.bool), - ], - ) - my_model.model_def_script = model_def_script - tf.saved_model.save( - my_model, - model_file, - options=tf.saved_model.SaveOptions(experimental_custom_gradients=True), - ) - elif model_file.endswith(".jax"): + if model_file.endswith(".jax"): model = BaseModel.deserialize(data["model"]) + model_def_script = data["model_def_script"] state = nnx.state(model, ArrayAPIParam) - nnx.display(state) + with ocp.Checkpointer( + ocp.CompositeCheckpointHandler("state", "model_def_script") + ) as checkpointer: + checkpointer.save( + Path(model_file).absolute(), + ocp.args.Composite( + state=ocp.args.StandardSave(state), + model_def_script=ocp.args.JsonSave(model_def_script), + ), + ) + else: + raise ValueError("JAX backend only supports converting .jax directory") + + +def serialize_from_file(model_file: str) -> dict: + """Serialize the model file to a dictionary. + + Parameters + ---------- + model_file : str + The model file to be serialized. + + Returns + ------- + dict + The serialized model data. + """ + if model_file.endswith(".jax"): + with ocp.Checkpointer( + ocp.CompositeCheckpointHandler("state", "model_def_script") + ) as checkpointer: + data = checkpointer.restore( + Path(model_file).absolute(), + ocp.args.Composite( + state=ocp.args.StandardRestore(), + model_def_script=ocp.args.JsonRestore(), + ), + ) + state = data.state + model_def_script = data.model_def_script + model = get_model(model_def_script) + model_dict = model.serialize() + data = { + "backend": "JAX", + "jax_version": jax.__version__, + "model": model_dict, + "model_def_script": model_def_script, + "@variables": {}, + } + return data else: - raise ValueError("JAX backend only supports converting .pth file") + raise ValueError("JAX backend only supports converting .jax directory") diff --git a/source/tests/consistent/io/test_io.py b/source/tests/consistent/io/test_io.py index 71e4002128..0aaa0788ea 100644 --- a/source/tests/consistent/io/test_io.py +++ b/source/tests/consistent/io/test_io.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: LGPL-3.0-or-later import copy +import shutil import unittest from pathlib import ( Path, @@ -60,12 +61,14 @@ def save_data_to_model(self, model_file: str, data: dict) -> None: def tearDown(self): prefix = "test_consistent_io_" + self.__class__.__name__.lower() for ii in Path(".").glob(prefix + ".*"): - if Path(ii).exists(): + if Path(ii).is_file(): Path(ii).unlink() + elif Path(ii).is_dir(): + shutil.rmtree(ii) def test_data_equal(self): prefix = "test_consistent_io_" + self.__class__.__name__.lower() - for backend_name in ("tensorflow", "pytorch", "dpmodel"): + for backend_name in ("tensorflow", "pytorch", "dpmodel", "jax"): with self.subTest(backend_name=backend_name): backend = Backend.get_backend(backend_name)() if not backend.is_available: @@ -80,6 +83,7 @@ def test_data_equal(self): "backend", "tf_version", "pt_version", + "jax_version", "@variables", # dpmodel only 
"software", From f0bc8b814f697fe220f871640ef57d6390f7c347 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Mon, 21 Oct 2024 18:37:45 -0400 Subject: [PATCH 11/17] add dependencies Signed-off-by: Jinzhe Zeng --- pyproject.toml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index 4a99cfaf8b..26f8e5ae6e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -138,6 +138,8 @@ cu12 = [ jax = [ 'jax>=0.4.33;python_version>="3.10"', 'flax>=0.8.0;python_version>="3.10"', + 'orbax-checkpoint;python_version>="3.10"', + 'jax-ai-stack;python_version>="3.10"', ] [tool.deepmd_build_backend.scripts] From 8c25e8758d47b6eeff466db5c4fba9ee01b61279 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 23 Oct 2024 10:03:44 +0000 Subject: [PATCH 12/17] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- deepmd/dpmodel/atomic_model/base_atomic_model.py | 1 - 1 file changed, 1 deletion(-) diff --git a/deepmd/dpmodel/atomic_model/base_atomic_model.py b/deepmd/dpmodel/atomic_model/base_atomic_model.py index 658ae1793d..6307b19f41 100644 --- a/deepmd/dpmodel/atomic_model/base_atomic_model.py +++ b/deepmd/dpmodel/atomic_model/base_atomic_model.py @@ -1,5 +1,4 @@ # SPDX-License-Identifier: LGPL-3.0-or-later -import copy import math from typing import ( Optional, From 7e003b25da09d19ddb49624cf2510a4eaac5a044 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Wed, 23 Oct 2024 06:09:38 -0400 Subject: [PATCH 13/17] Remove jax2tf --- deepmd/jax/env.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/deepmd/jax/env.py b/deepmd/jax/env.py index 2d1c1454ed..5a5a7f6bf0 100644 --- a/deepmd/jax/env.py +++ b/deepmd/jax/env.py @@ -8,9 +8,6 @@ from flax import ( nnx, ) -from jax.experimental import ( - jax2tf, -) jax.config.update("jax_enable_x64", True) @@ -18,5 +15,4 @@ "jax", "jnp", "nnx", - "jax2tf", ] From 536bbcdd5fead8621f5d8964806dd1a1c40392ea Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Wed, 23 Oct 2024 15:38:28 -0400 Subject: [PATCH 14/17] Update pyproject.toml Signed-off-by: Jinzhe Zeng --- pyproject.toml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 26f8e5ae6e..2e95d70614 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -139,7 +139,8 @@ jax = [ 'jax>=0.4.33;python_version>="3.10"', 'flax>=0.8.0;python_version>="3.10"', 'orbax-checkpoint;python_version>="3.10"', - 'jax-ai-stack;python_version>="3.10"', + # The pinning of ml_dtypes may conflict with TF + # 'jax-ai-stack;python_version>="3.10"', ] [tool.deepmd_build_backend.scripts] From f90fc5275961b54d116afcc9da76488c6a132339 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Wed, 23 Oct 2024 15:42:26 -0400 Subject: [PATCH 15/17] update the model with the state Signed-off-by: Jinzhe Zeng --- deepmd/jax/utils/serialization.py | 1 + 1 file changed, 1 insertion(+) diff --git a/deepmd/jax/utils/serialization.py b/deepmd/jax/utils/serialization.py index aa41e35f69..a07dc5e2df 100644 --- a/deepmd/jax/utils/serialization.py +++ b/deepmd/jax/utils/serialization.py @@ -73,6 +73,7 @@ def serialize_from_file(model_file: str) -> dict: state = data.state model_def_script = data.model_def_script model = get_model(model_def_script) + nnx.update(model, state) model_dict = model.serialize() data = { "backend": "JAX", From 9b571d133fb861630c5996a2f9a5e70b484a7a6e Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Wed, 23 Oct 2024 16:47:00 -0400 Subject: [PATCH 16/17] fix 
is_available Signed-off-by: Jinzhe Zeng --- source/tests/consistent/io/test_io.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/source/tests/consistent/io/test_io.py b/source/tests/consistent/io/test_io.py index 0aaa0788ea..feafde234d 100644 --- a/source/tests/consistent/io/test_io.py +++ b/source/tests/consistent/io/test_io.py @@ -71,7 +71,7 @@ def test_data_equal(self): for backend_name in ("tensorflow", "pytorch", "dpmodel", "jax"): with self.subTest(backend_name=backend_name): backend = Backend.get_backend(backend_name)() - if not backend.is_available: + if not backend.is_available(): continue reference_data = copy.deepcopy(self.data) self.save_data_to_model(prefix + backend.suffixes[0], reference_data) @@ -127,7 +127,7 @@ def test_deep_eval(self): rets = [] for backend_name in ("tensorflow", "pytorch", "dpmodel"): backend = Backend.get_backend(backend_name)() - if not backend.is_available: + if not backend.is_available(): continue reference_data = copy.deepcopy(self.data) self.save_data_to_model(prefix + backend.suffixes[0], reference_data) From fb3df8bcd7883a76652750e61d0f0c049f080f5e Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Wed, 23 Oct 2024 18:44:55 -0400 Subject: [PATCH 17/17] bugfix Signed-off-by: Jinzhe Zeng --- deepmd/jax/atomic_model/base_atomic_model.py | 3 +++ deepmd/jax/common.py | 14 ++++++++++++ deepmd/jax/descriptor/dpa1.py | 3 +++ deepmd/jax/descriptor/se_e2_a.py | 3 +++ deepmd/jax/fitting/fitting.py | 3 +++ deepmd/jax/utils/exclude_mask.py | 5 ++++ deepmd/jax/utils/serialization.py | 24 ++++++++++++++------ deepmd/jax/utils/type_embed.py | 3 +++ pyproject.toml | 2 +- 9 files changed, 52 insertions(+), 8 deletions(-) diff --git a/deepmd/jax/atomic_model/base_atomic_model.py b/deepmd/jax/atomic_model/base_atomic_model.py index 90920879c2..ffd58daf5e 100644 --- a/deepmd/jax/atomic_model/base_atomic_model.py +++ b/deepmd/jax/atomic_model/base_atomic_model.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: LGPL-3.0-or-later from deepmd.jax.common import ( + ArrayAPIVariable, to_jax_array, ) from deepmd.jax.utils.exclude_mask import ( @@ -11,6 +12,8 @@ def base_atomic_model_set_attr(name, value): if name in {"out_bias", "out_std"}: value = to_jax_array(value) + if value is not None: + value = ArrayAPIVariable(value) elif name == "pair_excl" and value is not None: value = PairExcludeMask(value.ntypes, value.exclude_types) elif name == "atom_excl" and value is not None: diff --git a/deepmd/jax/common.py b/deepmd/jax/common.py index 9c144a41d1..f372e97eb5 100644 --- a/deepmd/jax/common.py +++ b/deepmd/jax/common.py @@ -81,3 +81,17 @@ def __setattr__(self, name: str, value: Any) -> None: return super().__setattr__(name, value) return FlaxModule + + +class ArrayAPIVariable(nnx.Variable): + def __array__(self, *args, **kwargs): + return self.value.__array__(*args, **kwargs) + + def __array_namespace__(self, *args, **kwargs): + return self.value.__array_namespace__(*args, **kwargs) + + def __dlpack__(self, *args, **kwargs): + return self.value.__dlpack__(*args, **kwargs) + + def __dlpack_device__(self, *args, **kwargs): + return self.value.__dlpack_device__(*args, **kwargs) diff --git a/deepmd/jax/descriptor/dpa1.py b/deepmd/jax/descriptor/dpa1.py index 0528e4bb93..fef9bd5448 100644 --- a/deepmd/jax/descriptor/dpa1.py +++ b/deepmd/jax/descriptor/dpa1.py @@ -13,6 +13,7 @@ NeighborGatedAttentionLayer as NeighborGatedAttentionLayerDP, ) from deepmd.jax.common import ( + ArrayAPIVariable, flax_module, to_jax_array, ) @@ -65,6 +66,8 @@ class 
DescrptBlockSeAtten(DescrptBlockSeAttenDP): def __setattr__(self, name: str, value: Any) -> None: if name in {"mean", "stddev"}: value = to_jax_array(value) + if value is not None: + value = ArrayAPIVariable(value) elif name in {"embeddings", "embeddings_strip"}: if value is not None: value = NetworkCollection.deserialize(value.serialize()) diff --git a/deepmd/jax/descriptor/se_e2_a.py b/deepmd/jax/descriptor/se_e2_a.py index d1a6e9a8d9..31c147ad9d 100644 --- a/deepmd/jax/descriptor/se_e2_a.py +++ b/deepmd/jax/descriptor/se_e2_a.py @@ -5,6 +5,7 @@ from deepmd.dpmodel.descriptor.se_e2_a import DescrptSeAArrayAPI as DescrptSeADP from deepmd.jax.common import ( + ArrayAPIVariable, flax_module, to_jax_array, ) @@ -26,6 +27,8 @@ class DescrptSeA(DescrptSeADP): def __setattr__(self, name: str, value: Any) -> None: if name in {"dstd", "davg"}: value = to_jax_array(value) + if value is not None: + value = ArrayAPIVariable(value) elif name in {"embeddings"}: if value is not None: value = NetworkCollection.deserialize(value.serialize()) diff --git a/deepmd/jax/fitting/fitting.py b/deepmd/jax/fitting/fitting.py index f979db4d41..cef1f667b3 100644 --- a/deepmd/jax/fitting/fitting.py +++ b/deepmd/jax/fitting/fitting.py @@ -6,6 +6,7 @@ from deepmd.dpmodel.fitting.dos_fitting import DOSFittingNet as DOSFittingNetDP from deepmd.dpmodel.fitting.ener_fitting import EnergyFittingNet as EnergyFittingNetDP from deepmd.jax.common import ( + ArrayAPIVariable, flax_module, to_jax_array, ) @@ -29,6 +30,8 @@ def setattr_for_general_fitting(name: str, value: Any) -> Any: "aparam_inv_std", }: value = to_jax_array(value) + if value is not None: + value = ArrayAPIVariable(value) elif name == "emask": value = AtomExcludeMask(value.ntypes, value.exclude_types) elif name == "nets": diff --git a/deepmd/jax/utils/exclude_mask.py b/deepmd/jax/utils/exclude_mask.py index a6cf210f94..18d13d9400 100644 --- a/deepmd/jax/utils/exclude_mask.py +++ b/deepmd/jax/utils/exclude_mask.py @@ -6,6 +6,7 @@ from deepmd.dpmodel.utils.exclude_mask import AtomExcludeMask as AtomExcludeMaskDP from deepmd.dpmodel.utils.exclude_mask import PairExcludeMask as PairExcludeMaskDP from deepmd.jax.common import ( + ArrayAPIVariable, flax_module, to_jax_array, ) @@ -16,6 +17,8 @@ class AtomExcludeMask(AtomExcludeMaskDP): def __setattr__(self, name: str, value: Any) -> None: if name in {"type_mask"}: value = to_jax_array(value) + if value is not None: + value = ArrayAPIVariable(value) return super().__setattr__(name, value) @@ -24,4 +27,6 @@ class PairExcludeMask(PairExcludeMaskDP): def __setattr__(self, name: str, value: Any) -> None: if name in {"type_mask"}: value = to_jax_array(value) + if value is not None: + value = ArrayAPIVariable(value) return super().__setattr__(name, value) diff --git a/deepmd/jax/utils/serialization.py b/deepmd/jax/utils/serialization.py index a07dc5e2df..43070f8a07 100644 --- a/deepmd/jax/utils/serialization.py +++ b/deepmd/jax/utils/serialization.py @@ -13,9 +13,6 @@ BaseModel, get_model, ) -from deepmd.jax.utils.network import ( - ArrayAPIParam, -) def deserialize_to_file(model_file: str, data: dict) -> None: @@ -31,14 +28,14 @@ def deserialize_to_file(model_file: str, data: dict) -> None: if model_file.endswith(".jax"): model = BaseModel.deserialize(data["model"]) model_def_script = data["model_def_script"] - state = nnx.state(model, ArrayAPIParam) + _, state = nnx.split(model) with ocp.Checkpointer( ocp.CompositeCheckpointHandler("state", "model_def_script") ) as checkpointer: checkpointer.save( 
Path(model_file).absolute(), ocp.args.Composite( - state=ocp.args.StandardSave(state), + state=ocp.args.StandardSave(state.to_pure_dict()), model_def_script=ocp.args.JsonSave(model_def_script), ), ) @@ -71,9 +68,22 @@ def serialize_from_file(model_file: str) -> dict: ), ) state = data.state + + # convert str "1" to int 1 key + def convert_str_to_int_key(item: dict): + for key, value in item.copy().items(): + if isinstance(value, dict): + convert_str_to_int_key(value) + if key.isdigit(): + item[int(key)] = item.pop(key) + + convert_str_to_int_key(state) + model_def_script = data.model_def_script - model = get_model(model_def_script) - nnx.update(model, state) + abstract_model = get_model(model_def_script) + graphdef, abstract_state = nnx.split(abstract_model) + abstract_state.replace_by_pure_dict(state) + model = nnx.merge(graphdef, abstract_state) model_dict = model.serialize() data = { "backend": "JAX", diff --git a/deepmd/jax/utils/type_embed.py b/deepmd/jax/utils/type_embed.py index 3143460244..30cd9f45a9 100644 --- a/deepmd/jax/utils/type_embed.py +++ b/deepmd/jax/utils/type_embed.py @@ -5,6 +5,7 @@ from deepmd.dpmodel.utils.type_embed import TypeEmbedNet as TypeEmbedNetDP from deepmd.jax.common import ( + ArrayAPIVariable, flax_module, to_jax_array, ) @@ -18,6 +19,8 @@ class TypeEmbedNet(TypeEmbedNetDP): def __setattr__(self, name: str, value: Any) -> None: if name in {"econf_tebd"}: value = to_jax_array(value) + if value is not None: + value = ArrayAPIVariable(value) if name in {"embedding_net"}: value = EmbeddingNet.deserialize(value.serialize()) return super().__setattr__(name, value) diff --git a/pyproject.toml b/pyproject.toml index 2e95d70614..6bc1065ac8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -137,7 +137,7 @@ cu12 = [ ] jax = [ 'jax>=0.4.33;python_version>="3.10"', - 'flax>=0.8.0;python_version>="3.10"', + 'flax>=0.10.0;python_version>="3.10"', 'orbax-checkpoint;python_version>="3.10"', # The pinning of ml_dtypes may conflict with TF # 'jax-ai-stack;python_version>="3.10"',
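
The series above wires a JAX energy model into the common `get_model` factory and the orbax-backed `.jax` checkpoint format. Below is a minimal usage sketch of how the new pieces fit together, assuming the series is applied and JAX/flax are installed; it is not part of the patches, and the hyperparameter values (type_map entries, rcut, sel, neuron) are illustrative placeholders rather than values taken from this series.

    # Sketch only: build a JAX energy model from a standard model-definition dict.
    from deepmd.jax.env import jnp                 # jnp as exported by deepmd/jax/env.py
    from deepmd.jax.model.model import get_model   # factory added in patch 01

    model_def = {
        "type_map": ["O", "H"],
        "descriptor": {
            "type": "se_e2_a",        # registered on BaseDescriptor in patch 01
            "rcut": 6.0,              # illustrative values
            "rcut_smth": 0.5,
            "sel": [46, 92],
            "neuron": [25, 50, 100],
        },
        "fitting_net": {
            "type": "ener",           # registered on BaseFitting in patch 01
            "neuron": [240, 240, 240],
        },
    }
    model = get_model(model_def)

    # One water-like frame: coord (nframes, natoms, 3), atype (nframes, natoms),
    # box (nframes, 3, 3), mirroring source/tests/consistent/model/common.py.
    coord = jnp.asarray([[[0.0, 0.0, 0.0], [0.0, 0.0, 0.96], [0.93, 0.0, -0.24]]])
    atype = jnp.asarray([[0, 1, 1]])
    box = 10.0 * jnp.eye(3)[None, ...]

    ret = model(coord, atype, box=box)
    # ret["energy"] holds per-atom energies and ret["energy_redu"] the per-frame sum,
    # matching the keys checked in source/tests/consistent/model/test_ener.py.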