Skip to content

Commit

Permalink
Add dp impl serialization test
Browse files Browse the repository at this point in the history
  • Loading branch information
iProzd committed Feb 1, 2024
1 parent 603128a commit a96cab0
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 14 deletions.
2 changes: 1 addition & 1 deletion deepmd/model_format/dpa1.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,7 +390,7 @@ def deserialize(cls, data: dict) -> "DescrptDPA1":
variables = data.pop("@variables")
embeddings = data.pop("embeddings")
type_embedding = data.pop("type_embedding")
attention_layers = data.pop("attention_layers")
attention_layers = data.pop("attention_layers", None)
env_mat = data.pop("env_mat")
obj = cls(**data)
obj["davg"] = variables["davg"]
Expand Down
40 changes: 27 additions & 13 deletions source/tests/pt/test_dpa1.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,9 +101,23 @@ def test_consistency(
atol=atol,
err_msg=err_msg,
)
# dp impl serialization
dd3 = DPDescrptDPA1.deserialize(dd2.serialize())
rd3, _, _, _, _ = dd3.call(
self.coord_ext,
self.atype_ext,
self.nlist,
)
np.testing.assert_allclose(
rd0.detach().cpu().numpy(),
rd3,
rtol=rtol,
atol=atol,
err_msg=err_msg,
)
# old impl
if idt is False and prec == "float64":
dd3 = DescrptDPA1(
dd4 = DescrptDPA1(
self.rcut,
self.rcut_smth,
self.sel,
Expand All @@ -114,12 +128,12 @@ def test_consistency(
old_impl=True,
).to(env.DEVICE)
dd0_state_dict = dd0.se_atten.state_dict()
dd3_state_dict = dd3.se_atten.state_dict()
dd4_state_dict = dd4.se_atten.state_dict()

dd0_state_dict_attn = dd0.se_atten.dpa1_attention.state_dict()
dd3_state_dict_attn = dd3.se_atten.dpa1_attention.state_dict()
for i in dd3_state_dict:
dd3_state_dict[i] = (
dd4_state_dict_attn = dd4.se_atten.dpa1_attention.state_dict()
for i in dd4_state_dict:
dd4_state_dict[i] = (
dd0_state_dict[
i.replace(".deep_layers.", ".layers.")
.replace("filter_layers_old.", "filter_layers._networks.")
Expand All @@ -131,27 +145,27 @@ def test_consistency(
.clone()
)
if ".bias" in i and "attn_layer_norm" not in i:
dd3_state_dict[i] = dd3_state_dict[i].unsqueeze(0)
dd3.se_atten.load_state_dict(dd3_state_dict)
dd4_state_dict[i] = dd4_state_dict[i].unsqueeze(0)
dd4.se_atten.load_state_dict(dd4_state_dict)

dd0_state_dict_tebd = dd0.type_embedding.state_dict()
dd3_state_dict_tebd = dd3.type_embedding_old.state_dict()
for i in dd3_state_dict_tebd:
dd3_state_dict_tebd[i] = (
dd4_state_dict_tebd = dd4.type_embedding_old.state_dict()
for i in dd4_state_dict_tebd:
dd4_state_dict_tebd[i] = (
dd0_state_dict_tebd[i.replace("embedding.weight", "matrix")]
.detach()
.clone()
)
dd3.type_embedding_old.load_state_dict(dd3_state_dict_tebd)
dd4.type_embedding_old.load_state_dict(dd4_state_dict_tebd)

rd3, _, _, _, _ = dd3(
rd4, _, _, _, _ = dd4(
torch.tensor(self.coord_ext, dtype=dtype, device=env.DEVICE),
torch.tensor(self.atype_ext, dtype=int, device=env.DEVICE),
torch.tensor(self.nlist, dtype=int, device=env.DEVICE),
)
np.testing.assert_allclose(
rd0.detach().cpu().numpy(),
rd3.detach().cpu().numpy(),
rd4.detach().cpu().numpy(),
rtol=rtol,
atol=atol,
err_msg=err_msg,
Expand Down

0 comments on commit a96cab0

Please sign in to comment.