From a96cab01d60a782683a59c55ffacefdb24622fa9 Mon Sep 17 00:00:00 2001 From: Duo <50307526+iProzd@users.noreply.github.com> Date: Fri, 2 Feb 2024 00:31:40 +0800 Subject: [PATCH] Add dp impl serialization test --- deepmd/model_format/dpa1.py | 2 +- source/tests/pt/test_dpa1.py | 40 ++++++++++++++++++++++++------------ 2 files changed, 28 insertions(+), 14 deletions(-) diff --git a/deepmd/model_format/dpa1.py b/deepmd/model_format/dpa1.py index e0cf193ce6..3ca28a4fae 100644 --- a/deepmd/model_format/dpa1.py +++ b/deepmd/model_format/dpa1.py @@ -390,7 +390,7 @@ def deserialize(cls, data: dict) -> "DescrptDPA1": variables = data.pop("@variables") embeddings = data.pop("embeddings") type_embedding = data.pop("type_embedding") - attention_layers = data.pop("attention_layers") + attention_layers = data.pop("attention_layers", None) env_mat = data.pop("env_mat") obj = cls(**data) obj["davg"] = variables["davg"] diff --git a/source/tests/pt/test_dpa1.py b/source/tests/pt/test_dpa1.py index ffd5a25bd5..ae34305064 100644 --- a/source/tests/pt/test_dpa1.py +++ b/source/tests/pt/test_dpa1.py @@ -101,9 +101,23 @@ def test_consistency( atol=atol, err_msg=err_msg, ) + # dp impl serialization + dd3 = DPDescrptDPA1.deserialize(dd2.serialize()) + rd3, _, _, _, _ = dd3.call( + self.coord_ext, + self.atype_ext, + self.nlist, + ) + np.testing.assert_allclose( + rd0.detach().cpu().numpy(), + rd3, + rtol=rtol, + atol=atol, + err_msg=err_msg, + ) # old impl if idt is False and prec == "float64": - dd3 = DescrptDPA1( + dd4 = DescrptDPA1( self.rcut, self.rcut_smth, self.sel, @@ -114,12 +128,12 @@ def test_consistency( old_impl=True, ).to(env.DEVICE) dd0_state_dict = dd0.se_atten.state_dict() - dd3_state_dict = dd3.se_atten.state_dict() + dd4_state_dict = dd4.se_atten.state_dict() dd0_state_dict_attn = dd0.se_atten.dpa1_attention.state_dict() - dd3_state_dict_attn = dd3.se_atten.dpa1_attention.state_dict() - for i in dd3_state_dict: - dd3_state_dict[i] = ( + dd4_state_dict_attn = dd4.se_atten.dpa1_attention.state_dict() + for i in dd4_state_dict: + dd4_state_dict[i] = ( dd0_state_dict[ i.replace(".deep_layers.", ".layers.") .replace("filter_layers_old.", "filter_layers._networks.") @@ -131,27 +145,27 @@ def test_consistency( .clone() ) if ".bias" in i and "attn_layer_norm" not in i: - dd3_state_dict[i] = dd3_state_dict[i].unsqueeze(0) - dd3.se_atten.load_state_dict(dd3_state_dict) + dd4_state_dict[i] = dd4_state_dict[i].unsqueeze(0) + dd4.se_atten.load_state_dict(dd4_state_dict) dd0_state_dict_tebd = dd0.type_embedding.state_dict() - dd3_state_dict_tebd = dd3.type_embedding_old.state_dict() - for i in dd3_state_dict_tebd: - dd3_state_dict_tebd[i] = ( + dd4_state_dict_tebd = dd4.type_embedding_old.state_dict() + for i in dd4_state_dict_tebd: + dd4_state_dict_tebd[i] = ( dd0_state_dict_tebd[i.replace("embedding.weight", "matrix")] .detach() .clone() ) - dd3.type_embedding_old.load_state_dict(dd3_state_dict_tebd) + dd4.type_embedding_old.load_state_dict(dd4_state_dict_tebd) - rd3, _, _, _, _ = dd3( + rd4, _, _, _, _ = dd4( torch.tensor(self.coord_ext, dtype=dtype, device=env.DEVICE), torch.tensor(self.atype_ext, dtype=int, device=env.DEVICE), torch.tensor(self.nlist, dtype=int, device=env.DEVICE), ) np.testing.assert_allclose( rd0.detach().cpu().numpy(), - rd3.detach().cpu().numpy(), + rd4.detach().cpu().numpy(), rtol=rtol, atol=atol, err_msg=err_msg,