From 0199ad526ec37d20cc707f2f74818be605764c22 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Thu, 7 Nov 2024 21:38:43 -0500 Subject: [PATCH 1/3] fix: consistent DPA-1 model (#4320) Fix #4022. Note that `smooth_type_embedding==True` is not consistent between TF and others. Also, fix several issues. ## Summary by CodeRabbit ## Release Notes - **New Features** - Enhanced configurability of descriptors with new optional parameters for type mapping and type count. - Introduction of a new class `DescrptSeAttenV2` for advanced attention mechanisms. - Added a new unit test framework for validating energy models across multiple backends. - **Bug Fixes** - Improved error handling in descriptor serialization methods to prevent unsupported operations. - **Documentation** - Updated backend documentation to include JAX support and clarify file extensions for various backends. - **Style** - Enhanced readability of error messages in fitting classes. - **Tests** - Comprehensive unit tests added for energy models across different machine learning frameworks. --------- Signed-off-by: Jinzhe Zeng --- deepmd/dpmodel/descriptor/hybrid.py | 6 + deepmd/dpmodel/model/model.py | 14 +- deepmd/jax/descriptor/__init__.py | 4 + deepmd/jax/model/model.py | 1 + deepmd/pt/model/task/fitting.py | 4 +- deepmd/tf/descriptor/hybrid.py | 6 + deepmd/tf/descriptor/se_atten.py | 123 ++++++++--- deepmd/tf/descriptor/se_atten_v2.py | 16 +- deepmd/tf/fit/ener.py | 6 +- deepmd/tf/model/model.py | 10 +- deepmd/tf/utils/type_embed.py | 5 + doc/backend.md | 4 - source/tests/consistent/model/test_dpa1.py | 236 +++++++++++++++++++++ 13 files changed, 389 insertions(+), 46 deletions(-) create mode 100644 source/tests/consistent/model/test_dpa1.py diff --git a/deepmd/dpmodel/descriptor/hybrid.py b/deepmd/dpmodel/descriptor/hybrid.py index 106fcaf11e..cde4534853 100644 --- a/deepmd/dpmodel/descriptor/hybrid.py +++ b/deepmd/dpmodel/descriptor/hybrid.py @@ -43,6 +43,8 @@ class DescrptHybrid(BaseDescriptor, NativeOP): def __init__( self, list: list[Union[BaseDescriptor, dict[str, Any]]], + type_map: Optional[list[str]] = None, + ntypes: Optional[int] = None, # to be compat with input ) -> None: super().__init__() # warning: list is conflict with built-in list @@ -56,6 +58,10 @@ def __init__( if isinstance(ii, BaseDescriptor): formatted_descript_list.append(ii) elif isinstance(ii, dict): + ii = ii.copy() + # only pass if not already set + ii.setdefault("type_map", type_map) + ii.setdefault("ntypes", ntypes) formatted_descript_list.append(BaseDescriptor(**ii)) else: raise NotImplementedError diff --git a/deepmd/dpmodel/model/model.py b/deepmd/dpmodel/model/model.py index c29240214c..6dea9041fc 100644 --- a/deepmd/dpmodel/model/model.py +++ b/deepmd/dpmodel/model/model.py @@ -8,9 +8,6 @@ from deepmd.dpmodel.descriptor.base_descriptor import ( BaseDescriptor, ) -from deepmd.dpmodel.descriptor.se_e2_a import ( - DescrptSeA, -) from deepmd.dpmodel.fitting.ener_fitting import ( EnergyFittingNet, ) @@ -39,16 +36,13 @@ def get_standard_model(data: dict) -> EnergyModel: data : dict The data to construct the model. 
""" - descriptor_type = data["descriptor"].pop("type") data["descriptor"]["type_map"] = data["type_map"] + data["descriptor"]["ntypes"] = len(data["type_map"]) fitting_type = data["fitting_net"].pop("type") data["fitting_net"]["type_map"] = data["type_map"] - if descriptor_type == "se_e2_a": - descriptor = DescrptSeA( - **data["descriptor"], - ) - else: - raise ValueError(f"Unknown descriptor type {descriptor_type}") + descriptor = BaseDescriptor( + **data["descriptor"], + ) if fitting_type == "ener": fitting = EnergyFittingNet( ntypes=descriptor.get_ntypes(), diff --git a/deepmd/jax/descriptor/__init__.py b/deepmd/jax/descriptor/__init__.py index 91a3032f8b..dc5282dd21 100644 --- a/deepmd/jax/descriptor/__init__.py +++ b/deepmd/jax/descriptor/__init__.py @@ -8,6 +8,9 @@ from deepmd.jax.descriptor.hybrid import ( DescrptHybrid, ) +from deepmd.jax.descriptor.se_atten_v2 import ( + DescrptSeAttenV2, +) from deepmd.jax.descriptor.se_e2_a import ( DescrptSeA, ) @@ -27,6 +30,7 @@ "DescrptSeT", "DescrptSeTTebd", "DescrptDPA1", + "DescrptSeAttenV2", "DescrptDPA2", "DescrptHybrid", ] diff --git a/deepmd/jax/model/model.py b/deepmd/jax/model/model.py index e636eba4c6..983815100c 100644 --- a/deepmd/jax/model/model.py +++ b/deepmd/jax/model/model.py @@ -37,6 +37,7 @@ def get_standard_model(data: dict): data = deepcopy(data) descriptor_type = data["descriptor"].pop("type") data["descriptor"]["type_map"] = data["type_map"] + data["descriptor"]["ntypes"] = len(data["type_map"]) fitting_type = data["fitting_net"].pop("type") data["fitting_net"]["type_map"] = data["type_map"] descriptor = BaseDescriptor.get_class_by_type(descriptor_type)( diff --git a/deepmd/pt/model/task/fitting.py b/deepmd/pt/model/task/fitting.py index bae46c2adb..b983c574e3 100644 --- a/deepmd/pt/model/task/fitting.py +++ b/deepmd/pt/model/task/fitting.py @@ -419,8 +419,8 @@ def _forward_common( if nd != self.dim_descrpt: raise ValueError( - "get an input descriptor of dim {nd}," - "which is not consistent with {self.dim_descrpt}." + f"get an input descriptor of dim {nd}," + f"which is not consistent with {self.dim_descrpt}." 
) # check fparam dim, concate to input descriptor if self.numb_fparam > 0: diff --git a/deepmd/tf/descriptor/hybrid.py b/deepmd/tf/descriptor/hybrid.py index 3f20e7d856..e055efb31f 100644 --- a/deepmd/tf/descriptor/hybrid.py +++ b/deepmd/tf/descriptor/hybrid.py @@ -461,6 +461,8 @@ def update_sel( return local_jdata_cpy, min_nbor_dist def serialize(self, suffix: str = "") -> dict: + if hasattr(self, "type_embedding"): + raise NotImplementedError("hybrid + type embedding is not supported") return { "@class": "Descriptor", "type": "hybrid", @@ -485,4 +487,8 @@ def deserialize(cls, data: dict, suffix: str = "") -> "DescrptHybrid": for idx, ii in enumerate(data["list"]) ], ) + # search for type embedding + for ii in obj.descrpt_list: + if hasattr(ii, "type_embedding"): + raise NotImplementedError("hybrid + type embedding is not supported") return obj diff --git a/deepmd/tf/descriptor/se_atten.py b/deepmd/tf/descriptor/se_atten.py index 7bfb784419..b5896c9510 100644 --- a/deepmd/tf/descriptor/se_atten.py +++ b/deepmd/tf/descriptor/se_atten.py @@ -219,11 +219,11 @@ def __init__( if scaling_factor != 1.0: raise NotImplementedError("scaling_factor is not supported.") if not normalize: - raise NotImplementedError("normalize is not supported.") + raise NotImplementedError("Disabling normalize is not supported.") if temperature is not None: raise NotImplementedError("temperature is not supported.") if not concat_output_tebd: - raise NotImplementedError("concat_output_tebd is not supported.") + raise NotImplementedError("Disabling concat_output_tebd is not supported.") if env_protection != 0.0: raise NotImplementedError("env_protection != 0.0 is not supported.") # to keep consistent with default value in this backends @@ -1866,7 +1866,11 @@ def deserialize(cls, data: dict, suffix: str = ""): if cls is not DescrptSeAtten: raise NotImplementedError(f"Not implemented in class {cls.__name__}") data = data.copy() - check_version_compatibility(data.pop("@version"), 1, 1) + if data["smooth_type_embedding"]: + raise RuntimeError( + "The implementation for smooth_type_embedding is inconsistent with other backends" + ) + check_version_compatibility(data.pop("@version"), 2, 1) data.pop("@class") data.pop("type") embedding_net_variables = cls.deserialize_network( @@ -1878,10 +1882,13 @@ def deserialize(cls, data: dict, suffix: str = ""): data.pop("env_mat") variables = data.pop("@variables") tebd_input_mode = data["tebd_input_mode"] - if tebd_input_mode in ["strip"]: - raise ValueError( - "Deserialization is unsupported for `tebd_input_mode='strip'` in the native model."
- ) + type_embedding = TypeEmbedNet.deserialize( + data.pop("type_embedding"), suffix=suffix + ) + if "use_tebd_bias" not in data: + # v1 compatibility + data["use_tebd_bias"] = True + type_embedding.use_tebd_bias = data.pop("use_tebd_bias") descriptor = cls(**data) descriptor.embedding_net_variables = embedding_net_variables descriptor.attention_layer_variables = attention_layer_variables @@ -1891,6 +1898,17 @@ def deserialize(cls, data: dict, suffix: str = ""): descriptor.dstd = variables["dstd"].reshape( descriptor.ntypes, descriptor.ndescrpt ) + descriptor.type_embedding = type_embedding + if tebd_input_mode in ["strip"]: + type_one_side = data["type_one_side"] + two_side_embeeding_net_variables = cls.deserialize_network_strip( + data.pop("embeddings_strip"), + suffix=suffix, + type_one_side=type_one_side, + ) + descriptor.two_side_embeeding_net_variables = ( + two_side_embeeding_net_variables + ) return descriptor def serialize(self, suffix: str = "") -> dict: @@ -1906,10 +1924,9 @@ def serialize(self, suffix: str = "") -> dict: dict The serialized data """ - if self.stripped_type_embedding and type(self) is DescrptSeAtten: - # only DescrptDPA1Compat and DescrptSeAttenV2 can serialize when tebd_input_mode=='strip' - raise NotImplementedError( - "serialization is unsupported by the native model when tebd_input_mode=='strip'" + if self.smooth: + raise RuntimeError( + "The implementation for smooth_type_embedding is inconsistent with other backends" ) # todo support serialization when tebd_input_mode=='strip' and type_one_side is True if self.stripped_type_embedding and self.type_one_side: @@ -1927,10 +1944,18 @@ def serialize(self, suffix: str = "") -> dict: assert self.davg is not None assert self.dstd is not None + tebd_dim = self.type_embedding.neuron[0] + if self.tebd_input_mode in ["concat"]: + if not self.type_one_side: + embd_input_dim = 1 + tebd_dim * 2 + else: + embd_input_dim = 1 + tebd_dim + else: + embd_input_dim = 1 data = { "@class": "Descriptor", - "type": "se_atten", - "@version": 1, + "type": "dpa1", + "@version": 2, "rcut": self.rcut_r, "rcut_smth": self.rcut_r_smth, "sel": self.sel_a, @@ -1952,9 +1977,7 @@ def serialize(self, suffix: str = "") -> dict: "embeddings": self.serialize_network( ntypes=self.ntypes, ndim=0, - in_dim=1 - if not hasattr(self, "embd_input_dim") - else self.embd_input_dim, + in_dim=embd_input_dim, neuron=self.filter_neuron, activation_function=self.activation_function_name, resnet_dt=self.filter_resnet_dt, @@ -1986,17 +2009,23 @@ def serialize(self, suffix: str = "") -> dict: "type_one_side": self.type_one_side, "spin": self.spin, } + data["type_embedding"] = self.type_embedding.serialize(suffix=suffix) + data["use_tebd_bias"] = self.type_embedding.use_tebd_bias + data["tebd_dim"] = tebd_dim + if len(self.type_embedding.neuron) > 1: + raise NotImplementedError( + "Only a single-layer type embedding network is supported" + ) if self.tebd_input_mode in ["strip"]: - assert ( - type(self) is not DescrptSeAtten - ), "only DescrptDPA1Compat and DescrptSeAttenV2 can serialize when tebd_input_mode=='strip'" + # assert ( + # type(self) is not DescrptSeAtten + # ), "only DescrptDPA1Compat and DescrptSeAttenV2 can serialize when tebd_input_mode=='strip'" data.update( { "embeddings_strip": self.serialize_network_strip( ntypes=self.ntypes, ndim=0, - in_dim=2 - * self.tebd_dim, # only DescrptDPA1Compat has this attribute + in_dim=2 * tebd_dim, neuron=self.filter_neuron, activation_function=self.activation_function_name, resnet_dt=self.filter_resnet_dt, @@ -2006,8
+2035,54 @@ def serialize(self, suffix: str = "") -> dict: ) } ) + # default values + data.update( + { + "scaling_factor": 1.0, + "normalize": True, + "temperature": None, + "concat_output_tebd": True, + "use_econf_tebd": False, + } + ) + data["attention_layers"] = self.update_attention_layers_serialize( + data["attention_layers"] + ) return data + def update_attention_layers_serialize(self, data: dict): + """Update the serialized data to be consistent with other backend references.""" + new_dict = { + "@class": "NeighborGatedAttention", + "@version": 1, + "scaling_factor": 1.0, + "normalize": True, + "temperature": None, + } + new_dict.update(data) + update_info = { + "nnei": self.nnei_a, + "embed_dim": self.filter_neuron[-1], + "hidden_dim": self.att_n, + "dotr": self.attn_dotr, + "do_mask": self.attn_mask, + "scaling_factor": 1.0, + "normalize": True, + "temperature": None, + "precision": self.filter_precision.name, + } + for layer_idx in range(self.attn_layer): + new_dict["attention_layers"][layer_idx].update(update_info) + new_dict["attention_layers"][layer_idx]["attention_layer"].update( + update_info + ) + new_dict["attention_layers"][layer_idx]["attention_layer"].update( + { + "num_heads": 1, + } + ) + return new_dict + class DescrptDPA1Compat(DescrptSeAtten): r"""Consistent version of the model for testing with other backend references. @@ -2433,17 +2508,11 @@ def serialize(self, suffix: str = "") -> dict: { "type": "dpa1", "@version": 2, - "tebd_dim": self.tebd_dim, "scaling_factor": self.scaling_factor, "normalize": self.normalize, "temperature": self.temperature, "concat_output_tebd": self.concat_output_tebd, "use_econf_tebd": self.use_econf_tebd, - "use_tebd_bias": self.use_tebd_bias, - "type_embedding": self.type_embedding.serialize(suffix), } ) - data["attention_layers"] = self.update_attention_layers_serialize( - data["attention_layers"] - ) return data diff --git a/deepmd/tf/descriptor/se_atten_v2.py b/deepmd/tf/descriptor/se_atten_v2.py index dc71f87523..69efe004c4 100644 --- a/deepmd/tf/descriptor/se_atten_v2.py +++ b/deepmd/tf/descriptor/se_atten_v2.py @@ -4,6 +4,9 @@ Optional, ) +from deepmd.tf.utils.type_embed import ( + TypeEmbedNet, +) from deepmd.utils.version import ( check_version_compatibility, ) @@ -127,10 +130,13 @@ def deserialize(cls, data: dict, suffix: str = ""): Model The deserialized model """ + raise RuntimeError( + "The implementation for smooth_type_embedding is inconsistent with other backends" + ) if cls is not DescrptSeAttenV2: raise NotImplementedError(f"Not implemented in class {cls.__name__}") data = data.copy() - check_version_compatibility(data.pop("@version"), 1, 1) + check_version_compatibility(data.pop("@version"), 2, 1) data.pop("@class") data.pop("type") embedding_net_variables = cls.deserialize_network( @@ -147,6 +153,13 @@ def deserialize(cls, data: dict, suffix: str = ""): suffix=suffix, type_one_side=type_one_side, ) + type_embedding = TypeEmbedNet.deserialize( + data.pop("type_embedding"), suffix=suffix + ) + if "use_tebd_bias" not in data: + # v1 compatibility + data["use_tebd_bias"] = True + type_embedding.use_tebd_bias = data.pop("use_tebd_bias") descriptor = cls(**data) descriptor.embedding_net_variables = embedding_net_variables descriptor.attention_layer_variables = attention_layer_variables @@ -157,6 +170,7 @@ def deserialize(cls, data: dict, suffix: str = ""): descriptor.dstd = variables["dstd"].reshape( descriptor.ntypes, descriptor.ndescrpt ) + descriptor.type_embedding = type_embedding return descriptor def serialize(self, 
suffix: str = "") -> dict: diff --git a/deepmd/tf/fit/ener.py b/deepmd/tf/fit/ener.py index 1ba0fe3dfb..040ec47cf7 100644 --- a/deepmd/tf/fit/ener.py +++ b/deepmd/tf/fit/ener.py @@ -242,6 +242,7 @@ def __init__( len(self.layer_name) == len(self.n_neuron) + 1 ), "length of layer_name should be that of n_neuron + 1" self.mixed_types = mixed_types + self.tebd_dim = 0 def get_numb_fparam(self) -> int: """Get the number of frame parameters.""" @@ -754,6 +755,8 @@ def build( outs = tf.reshape(outs, [-1]) tf.summary.histogram("fitting_net_output", outs) + # recover original dim_descrpt, which needs to be serialized + self.dim_descrpt = original_dim_descrpt return tf.reshape(outs, [-1]) def init_variables( @@ -908,7 +911,7 @@ def serialize(self, suffix: str = "") -> dict: "@version": 2, "var_name": "energy", "ntypes": self.ntypes, - "dim_descrpt": self.dim_descrpt, + "dim_descrpt": self.dim_descrpt + self.tebd_dim, "mixed_types": self.mixed_types, "dim_out": 1, "neuron": self.n_neuron, @@ -930,6 +933,7 @@ def serialize(self, suffix: str = "") -> dict: ndim=0 if self.mixed_types else 1, in_dim=( self.dim_descrpt + + self.tebd_dim + self.numb_fparam + (0 if self.use_aparam_as_mask else self.numb_aparam) ), diff --git a/deepmd/tf/model/model.py b/deepmd/tf/model/model.py index 03211d49d5..51c66de65e 100644 --- a/deepmd/tf/model/model.py +++ b/deepmd/tf/model/model.py @@ -808,6 +808,12 @@ def deserialize(cls, data: dict, suffix: str = "") -> "Descriptor": check_version_compatibility(data.pop("@version", 2), 2, 1) descriptor = Descriptor.deserialize(data.pop("descriptor"), suffix=suffix) fitting = Fitting.deserialize(data.pop("fitting"), suffix=suffix) + # pass descriptor type embedding to model + if descriptor.explicit_ntypes: + type_embedding = descriptor.type_embedding + fitting.dim_descrpt -= type_embedding.neuron[-1] + else: + type_embedding = None # BEGINE not supported keys data.pop("atom_exclude_types") data.pop("pair_exclude_types") @@ -818,6 +824,7 @@ def deserialize(cls, data: dict, suffix: str = "") -> "Descriptor": return cls( descriptor=descriptor, fitting_net=fitting, + type_embedding=type_embedding, **data, ) @@ -835,7 +842,8 @@ def serialize(self, suffix: str = "") -> dict: Name suffix to identify this descriptor """ if self.typeebd is not None: - raise NotImplementedError("type embedding is not supported") + self.descrpt.type_embedding = self.typeebd + self.fitting.tebd_dim = self.typeebd.neuron[-1] if self.spin is not None: raise NotImplementedError("spin is not supported") diff --git a/deepmd/tf/utils/type_embed.py b/deepmd/tf/utils/type_embed.py index 13d02a858c..9b7b17528d 100644 --- a/deepmd/tf/utils/type_embed.py +++ b/deepmd/tf/utils/type_embed.py @@ -6,6 +6,8 @@ Union, ) +import numpy as np + from deepmd.dpmodel.utils.network import ( EmbeddingNet, ) @@ -327,6 +329,9 @@ def serialize(self, suffix: str = "") -> dict: layer_idx = int(m[1]) - 1 weight_name = m[0] if weight_name == "idt": + if not isinstance(value, np.ndarray): + # ignore 0.0 set by deserialize + continue value = value.ravel() embedding_net[layer_idx][weight_name] = value diff --git a/doc/backend.md b/doc/backend.md index 3fb70bee90..dd20193d58 100644 --- a/doc/backend.md +++ b/doc/backend.md @@ -65,7 +65,3 @@ For example, when the model filename ends with `.pb` (the ProtoBuf file), DeePMD ## Convert model files between backends If a model is supported by two backends, one can use [`dp convert-backend`](./cli.rst) to convert the model file between these two backends. 
- -:::{warning} -Currently, only the `se_e2_a` model fully supports the backend conversion between TensorFlow {{ tensorflow_icon }} and PyTorch {{ pytorch_icon }}. -::: diff --git a/source/tests/consistent/model/test_dpa1.py b/source/tests/consistent/model/test_dpa1.py new file mode 100644 index 0000000000..32b523e1ba --- /dev/null +++ b/source/tests/consistent/model/test_dpa1.py @@ -0,0 +1,236 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import unittest +from typing import ( + Any, +) + +import numpy as np + +from deepmd.dpmodel.model.ener_model import EnergyModel as EnergyModelDP +from deepmd.dpmodel.model.model import get_model as get_model_dp +from deepmd.env import ( + GLOBAL_NP_FLOAT_PRECISION, +) + +from ..common import ( + INSTALLED_JAX, + INSTALLED_PT, + INSTALLED_TF, + SKIP_FLAG, + CommonTest, + parameterized, +) +from .common import ( + ModelTest, +) + +if INSTALLED_PT: + from deepmd.pt.model.model import get_model as get_model_pt + from deepmd.pt.model.model.ener_model import EnergyModel as EnergyModelPT +else: + EnergyModelPT = None +if INSTALLED_TF: + from deepmd.tf.model.ener import EnerModel as EnergyModelTF +else: + EnergyModelTF = None +from deepmd.utils.argcheck import ( + model_args, +) + +if INSTALLED_JAX: + from deepmd.jax.model.ener_model import EnergyModel as EnergyModelJAX + from deepmd.jax.model.model import get_model as get_model_jax +else: + EnergyModelJAX = None + + +@parameterized( + ("strip", "concat"), # tebd_input_mode + # strip + smooth is inconsistent + (False,), # smooth +) +class TestDPA1Ener(CommonTest, ModelTest, unittest.TestCase): + @property + def data(self) -> dict: + ( + tebd_input_mode, + smooth_type_embedding, + ) = self.param + return { + "type_map": ["O", "H"], + "descriptor": { + "type": "se_atten", + "sel": 40, + "rcut_smth": 0.50, + "rcut": 6.00, + "neuron": [ + 3, + 6, + ], + "resnet_dt": False, + "axis_neuron": 2, + "seed": 1, + "attn": 128, + "attn_layer": 0, + "precision": "float64", + "tebd_input_mode": tebd_input_mode, + "smooth_type_embedding": smooth_type_embedding, + }, + "fitting_net": { + "neuron": [ + 5, + 5, + ], + "resnet_dt": True, + "precision": "float64", + "seed": 1, + }, + } + + tf_class = EnergyModelTF + dp_class = EnergyModelDP + pt_class = EnergyModelPT + jax_class = EnergyModelJAX + args = model_args() + + def get_reference_backend(self): + """Get the reference backend. + + We need a reference backend that can reproduce forces. 
+ """ + if not self.skip_pt: + return self.RefBackend.PT + if not self.skip_tf: + return self.RefBackend.TF + if not self.skip_jax: + return self.RefBackend.JAX + if not self.skip_dp: + return self.RefBackend.DP + raise ValueError("No available reference") + + @property + def skip_jax(self): + return not INSTALLED_JAX + + def pass_data_to_cls(self, cls, data) -> Any: + """Pass data to the class.""" + data = data.copy() + if cls is EnergyModelDP: + return get_model_dp(data) + elif cls is EnergyModelPT: + return get_model_pt(data) + elif cls is EnergyModelJAX: + return get_model_jax(data) + return cls(**data, **self.additional_data) + + def setUp(self): + CommonTest.setUp(self) + + self.ntypes = 2 + self.coords = np.array( + [ + 12.83, + 2.56, + 2.18, + 12.09, + 2.87, + 2.74, + 00.25, + 3.32, + 1.68, + 3.36, + 3.00, + 1.81, + 3.51, + 2.51, + 2.60, + 4.27, + 3.22, + 1.56, + ], + dtype=GLOBAL_NP_FLOAT_PRECISION, + ).reshape(1, -1, 3) + self.atype = np.array([0, 1, 1, 0, 1, 1], dtype=np.int32).reshape(1, -1) + self.box = np.array( + [13.0, 0.0, 0.0, 0.0, 13.0, 0.0, 0.0, 0.0, 13.0], + dtype=GLOBAL_NP_FLOAT_PRECISION, + ).reshape(1, 9) + self.natoms = np.array([6, 6, 2, 4], dtype=np.int32) + + # TF requires the atype to be sort + idx_map = np.argsort(self.atype.ravel()) + self.atype = self.atype[:, idx_map] + self.coords = self.coords[:, idx_map] + + def build_tf(self, obj: Any, suffix: str) -> tuple[list, dict]: + return self.build_tf_model( + obj, + self.natoms, + self.coords, + self.atype, + self.box, + suffix, + ) + + def eval_dp(self, dp_obj: Any) -> Any: + return self.eval_dp_model( + dp_obj, + self.natoms, + self.coords, + self.atype, + self.box, + ) + + def eval_pt(self, pt_obj: Any) -> Any: + return self.eval_pt_model( + pt_obj, + self.natoms, + self.coords, + self.atype, + self.box, + ) + + def eval_jax(self, jax_obj: Any) -> Any: + return self.eval_jax_model( + jax_obj, + self.natoms, + self.coords, + self.atype, + self.box, + ) + + def extract_ret(self, ret: Any, backend) -> tuple[np.ndarray, ...]: + # shape not matched. ravel... + if backend is self.RefBackend.DP: + return ( + ret["energy_redu"].ravel(), + ret["energy"].ravel(), + SKIP_FLAG, + SKIP_FLAG, + SKIP_FLAG, + ) + elif backend is self.RefBackend.PT: + return ( + ret["energy"].ravel(), + ret["atom_energy"].ravel(), + ret["force"].ravel(), + ret["virial"].ravel(), + ret["atom_virial"].ravel(), + ) + elif backend is self.RefBackend.TF: + return ( + ret[0].ravel(), + ret[1].ravel(), + ret[2].ravel(), + ret[3].ravel(), + ret[4].ravel(), + ) + elif backend is self.RefBackend.JAX: + return ( + ret["energy_redu"].ravel(), + ret["energy"].ravel(), + ret["energy_derv_r"].ravel(), + ret["energy_derv_c_redu"].ravel(), + ret["energy_derv_c"].ravel(), + ) + raise ValueError(f"Unknown backend: {backend}") From 15bb00c09cead06430194678300d1e906cc7e43e Mon Sep 17 00:00:00 2001 From: Duo <50307526+iProzd@users.noreply.github.com> Date: Fri, 8 Nov 2024 10:48:44 +0800 Subject: [PATCH 2/3] fix(pt/dp): make dpa2 convertable to `.dp` format (#4324) Fix #4295. BTW, I found that there seems no universal uts for `convert-backend` command. ## Summary by CodeRabbit - **New Features** - Updated `RepformerLayer` class to version 2, enhancing serialization and deserialization processes. - Introduced a new structure for residual variables within the serialized data, improving organization and clarity. - **Bug Fixes** - Adjusted version compatibility checks in the `deserialize` method to align with the new versioning scheme. 
--- deepmd/dpmodel/descriptor/repformers.py | 19 +++++++++++-------- deepmd/pt/model/descriptor/repformer_layer.py | 19 +++++++++++-------- 2 files changed, 22 insertions(+), 16 deletions(-) diff --git a/deepmd/dpmodel/descriptor/repformers.py b/deepmd/dpmodel/descriptor/repformers.py index 5422ff345e..5658a87a9d 100644 --- a/deepmd/dpmodel/descriptor/repformers.py +++ b/deepmd/dpmodel/descriptor/repformers.py @@ -1792,7 +1792,7 @@ def serialize(self) -> dict: """ data = { "@class": "RepformerLayer", - "@version": 1, + "@version": 2, "rcut": self.rcut, "rcut_smth": self.rcut_smth, "sel": self.sel, @@ -1877,9 +1877,11 @@ def serialize(self) -> dict: if self.update_style == "res_residual": data.update( { - "g1_residual": [to_numpy_array(aa) for aa in self.g1_residual], - "g2_residual": [to_numpy_array(aa) for aa in self.g2_residual], - "h2_residual": [to_numpy_array(aa) for aa in self.h2_residual], + "@variables": { + "g1_residual": [to_numpy_array(aa) for aa in self.g1_residual], + "g2_residual": [to_numpy_array(aa) for aa in self.g2_residual], + "h2_residual": [to_numpy_array(aa) for aa in self.h2_residual], + } } ) return data @@ -1894,7 +1896,7 @@ def deserialize(cls, data: dict) -> "RepformerLayer": The dict to deserialize from. """ data = data.copy() - check_version_compatibility(data.pop("@version"), 1, 1) + check_version_compatibility(data.pop("@version"), 2, 1) data.pop("@class") linear1 = data.pop("linear1") update_chnnl_2 = data["update_chnnl_2"] @@ -1915,9 +1917,10 @@ def deserialize(cls, data: dict) -> "RepformerLayer": attn2_ev_apply = data.pop("attn2_ev_apply", None) loc_attn = data.pop("loc_attn", None) g1_self_mlp = data.pop("g1_self_mlp", None) - g1_residual = data.pop("g1_residual", []) - g2_residual = data.pop("g2_residual", []) - h2_residual = data.pop("h2_residual", []) + variables = data.pop("@variables", {}) + g1_residual = variables.get("g1_residual", data.pop("g1_residual", [])) + g2_residual = variables.get("g2_residual", data.pop("g2_residual", [])) + h2_residual = variables.get("h2_residual", data.pop("h2_residual", [])) obj = cls(**data) obj.linear1 = NativeLayer.deserialize(linear1) diff --git a/deepmd/pt/model/descriptor/repformer_layer.py b/deepmd/pt/model/descriptor/repformer_layer.py index 31132f365e..b4fac5fce7 100644 --- a/deepmd/pt/model/descriptor/repformer_layer.py +++ b/deepmd/pt/model/descriptor/repformer_layer.py @@ -1295,7 +1295,7 @@ def serialize(self) -> dict: """ data = { "@class": "RepformerLayer", - "@version": 1, + "@version": 2, "rcut": self.rcut, "rcut_smth": self.rcut_smth, "sel": self.sel, @@ -1380,9 +1380,11 @@ def serialize(self) -> dict: if self.update_style == "res_residual": data.update( { - "g1_residual": [to_numpy_array(t) for t in self.g1_residual], - "g2_residual": [to_numpy_array(t) for t in self.g2_residual], - "h2_residual": [to_numpy_array(t) for t in self.h2_residual], + "@variables": { + "g1_residual": [to_numpy_array(t) for t in self.g1_residual], + "g2_residual": [to_numpy_array(t) for t in self.g2_residual], + "h2_residual": [to_numpy_array(t) for t in self.h2_residual], + } } ) return data @@ -1397,7 +1399,7 @@ def deserialize(cls, data: dict) -> "RepformerLayer": The dict to deserialize from. 
""" data = data.copy() - check_version_compatibility(data.pop("@version"), 1, 1) + check_version_compatibility(data.pop("@version"), 2, 1) data.pop("@class") linear1 = data.pop("linear1") update_chnnl_2 = data["update_chnnl_2"] @@ -1418,9 +1420,10 @@ def deserialize(cls, data: dict) -> "RepformerLayer": attn2_ev_apply = data.pop("attn2_ev_apply", None) loc_attn = data.pop("loc_attn", None) g1_self_mlp = data.pop("g1_self_mlp", None) - g1_residual = data.pop("g1_residual", []) - g2_residual = data.pop("g2_residual", []) - h2_residual = data.pop("h2_residual", []) + variables = data.pop("@variables", {}) + g1_residual = variables.get("g1_residual", data.pop("g1_residual", [])) + g2_residual = variables.get("g2_residual", data.pop("g2_residual", [])) + h2_residual = variables.get("h2_residual", data.pop("h2_residual", [])) obj = cls(**data) obj.linear1 = MLPLayer.deserialize(linear1) From 6c66be9bddeea8803c28ed3f44daf0d1877d257d Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Thu, 7 Nov 2024 22:07:23 -0500 Subject: [PATCH 3/3] ci: pin array-api-strict to <2.1.1 (#4326) See #4325. ## Summary by CodeRabbit - **Chores** - Updated dependency constraints for improved reliability. - Expanded testing commands to enhance the testing suite. - Improved setup process with additional installation commands for specific environments. Signed-off-by: Jinzhe Zeng --- pyproject.toml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 802e920014..7d64d48e80 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -85,7 +85,8 @@ test = [ "pytest-sugar", "pytest-split", "dpgui", - 'array-api-strict>=2;python_version>="3.9"', + # https://github.com/data-apis/array-api-strict/issues/85 + 'array-api-strict>=2,<2.1.1;python_version>="3.9"', ] docs = [ "sphinx>=3.1.1",