From c7f533138b3637d190ec165837df896cec0abd2d Mon Sep 17 00:00:00 2001 From: Yinzhanghao Zhou <64253517+floatingCatty@users.noreply.github.com> Date: Mon, 22 Apr 2024 13:18:32 +0800 Subject: [PATCH] Fix(nnsk): NNSK class in nnsk.py to use the get() method when accessing the full orbital (#139) * Update NNSK class in nnsk.py to use the get() method when accessing values in the full_basis_to_basis dictionary. * fix digital error in test_emb_se2 * temp --- dptb/nn/nnsk.py | 10 +++++----- dptb/tests/test_emb_se2.py | 2 +- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/dptb/nn/nnsk.py b/dptb/nn/nnsk.py index 64cf23d2..0117bf63 100644 --- a/dptb/nn/nnsk.py +++ b/dptb/nn/nnsk.py @@ -528,7 +528,7 @@ def from_reference( iasym, jasym = bond.split("-") for ref_forbpair in ref_idp.orbpair_maps.keys(): rfiorb, rfjorb = ref_forbpair.split("-") - riorb, rjorb = ref_idp.full_basis_to_basis[iasym][rfiorb], ref_idp.full_basis_to_basis[jasym][rfjorb] + riorb, rjorb = ref_idp.full_basis_to_basis[iasym].get(rfiorb), ref_idp.full_basis_to_basis[jasym].get(rfjorb) fiorb, fjorb = idp.basis_to_full_basis[iasym].get(riorb), idp.basis_to_full_basis[jasym].get(rjorb) if fiorb is not None and fjorb is not None: sli = idp.orbpair_maps.get(f"{fiorb}-{fjorb}") @@ -547,7 +547,7 @@ def from_reference( iasym, jasym = bond.split("-") for ref_forbpair in ref_idp.orbpair_maps.keys(): rfiorb, rfjorb = ref_forbpair.split("-") - riorb, rjorb = ref_idp.full_basis_to_basis[iasym][rfiorb], ref_idp.full_basis_to_basis[jasym][rfjorb] + riorb, rjorb = ref_idp.full_basis_to_basis[iasym].get(rfiorb), ref_idp.full_basis_to_basis[jasym].get(rfjorb) fiorb, fjorb = idp.basis_to_full_basis[iasym].get(riorb), idp.basis_to_full_basis[jasym].get(rjorb) if fiorb is not None and fjorb is not None: sli = idp.orbpair_maps.get(f"{fiorb}-{fjorb}") @@ -566,7 +566,7 @@ def from_reference( for asym in ref_idp.type_names: if asym in idp.type_names: for ref_forb in ref_idp.skonsite_maps.keys(): - rorb = ref_idp.full_basis_to_basis[asym][ref_forb] + rorb = ref_idp.full_basis_to_basis[asym].get(ref_forb) forb = idp.basis_to_full_basis[asym].get(rorb) if forb is not None: model.onsite_param.data[idp.chemical_symbol_to_type[asym],idp.skonsite_maps[forb]] = \ @@ -579,7 +579,7 @@ def from_reference( for asym in ref_idp.type_names: if asym in idp.type_names: for ref_forb in ref_idp.sksoc_maps.keys(): - rorb = ref_idp.full_basis_to_basis[asym][ref_forb] + rorb = ref_idp.full_basis_to_basis[asym].get(ref_forb) forb = idp.basis_to_full_basis[asym].get(rorb) if forb is not None: model.soc_param.data[idp.chemical_symbol_to_type[asym],idp.sksoc_maps[forb]] = \ @@ -592,7 +592,7 @@ def from_reference( iasym, jasym = bond.split("-") for ref_forbpair in ref_idp.orbpair_maps.keys(): rfiorb, rfjorb = ref_forbpair.split("-") - riorb, rjorb = ref_idp.full_basis_to_basis[iasym][rfiorb], ref_idp.full_basis_to_basis[jasym][rfjorb] + riorb, rjorb = ref_idp.full_basis_to_basis[iasym].get(rfiorb), ref_idp.full_basis_to_basis[jasym].get(rfjorb) fiorb, fjorb = idp.basis_to_full_basis[iasym].get(riorb), idp.basis_to_full_basis[jasym].get(rjorb) if fiorb is not None and fjorb is not None: sli = idp.orbpair_maps.get(f"{fiorb}-{fjorb}") diff --git a/dptb/tests/test_emb_se2.py b/dptb/tests/test_emb_se2.py index c2f80f9f..b5ff140a 100644 --- a/dptb/tests/test_emb_se2.py +++ b/dptb/tests/test_emb_se2.py @@ -97,7 +97,7 @@ def test_embedding(self): assert torch.abs(re1.norm() - 0.06084440) < 1e-6 assert re1.norm() > 1e-6 re1 /= re1.norm() - assert torch.all(out_node[ii] == re1) + assert (out_node[ii] - re1).abs().max() < 1e-6 edge_out = torch.cat([out_node[edge_index[0]] + out_node[edge_index[1]], 1/edge_length.reshape(-1,1)], dim=-1) # [N_edge, D*D]