Skip to content

Commit

Permalink
Fix(nnsk): NNSK class in nnsk.py to use the get() method when accessi…
Browse files Browse the repository at this point in the history
…ng the full orbital (#139)

* Update NNSK class in nnsk.py to use the get() method when accessing values in the full_basis_to_basis dictionary.

* fix digital error in test_emb_se2

* temp
  • Loading branch information
floatingCatty authored Apr 22, 2024
1 parent 8a888f3 commit c7f5331
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 6 deletions.
10 changes: 5 additions & 5 deletions dptb/nn/nnsk.py
Original file line number Diff line number Diff line change
Expand Up @@ -528,7 +528,7 @@ def from_reference(
iasym, jasym = bond.split("-")
for ref_forbpair in ref_idp.orbpair_maps.keys():
rfiorb, rfjorb = ref_forbpair.split("-")
riorb, rjorb = ref_idp.full_basis_to_basis[iasym][rfiorb], ref_idp.full_basis_to_basis[jasym][rfjorb]
riorb, rjorb = ref_idp.full_basis_to_basis[iasym].get(rfiorb), ref_idp.full_basis_to_basis[jasym].get(rfjorb)
fiorb, fjorb = idp.basis_to_full_basis[iasym].get(riorb), idp.basis_to_full_basis[jasym].get(rjorb)
if fiorb is not None and fjorb is not None:
sli = idp.orbpair_maps.get(f"{fiorb}-{fjorb}")
Expand All @@ -547,7 +547,7 @@ def from_reference(
iasym, jasym = bond.split("-")
for ref_forbpair in ref_idp.orbpair_maps.keys():
rfiorb, rfjorb = ref_forbpair.split("-")
riorb, rjorb = ref_idp.full_basis_to_basis[iasym][rfiorb], ref_idp.full_basis_to_basis[jasym][rfjorb]
riorb, rjorb = ref_idp.full_basis_to_basis[iasym].get(rfiorb), ref_idp.full_basis_to_basis[jasym].get(rfjorb)
fiorb, fjorb = idp.basis_to_full_basis[iasym].get(riorb), idp.basis_to_full_basis[jasym].get(rjorb)
if fiorb is not None and fjorb is not None:
sli = idp.orbpair_maps.get(f"{fiorb}-{fjorb}")
Expand All @@ -566,7 +566,7 @@ def from_reference(
for asym in ref_idp.type_names:
if asym in idp.type_names:
for ref_forb in ref_idp.skonsite_maps.keys():
rorb = ref_idp.full_basis_to_basis[asym][ref_forb]
rorb = ref_idp.full_basis_to_basis[asym].get(ref_forb)
forb = idp.basis_to_full_basis[asym].get(rorb)
if forb is not None:
model.onsite_param.data[idp.chemical_symbol_to_type[asym],idp.skonsite_maps[forb]] = \
Expand All @@ -579,7 +579,7 @@ def from_reference(
for asym in ref_idp.type_names:
if asym in idp.type_names:
for ref_forb in ref_idp.sksoc_maps.keys():
rorb = ref_idp.full_basis_to_basis[asym][ref_forb]
rorb = ref_idp.full_basis_to_basis[asym].get(ref_forb)
forb = idp.basis_to_full_basis[asym].get(rorb)
if forb is not None:
model.soc_param.data[idp.chemical_symbol_to_type[asym],idp.sksoc_maps[forb]] = \
Expand All @@ -592,7 +592,7 @@ def from_reference(
iasym, jasym = bond.split("-")
for ref_forbpair in ref_idp.orbpair_maps.keys():
rfiorb, rfjorb = ref_forbpair.split("-")
riorb, rjorb = ref_idp.full_basis_to_basis[iasym][rfiorb], ref_idp.full_basis_to_basis[jasym][rfjorb]
riorb, rjorb = ref_idp.full_basis_to_basis[iasym].get(rfiorb), ref_idp.full_basis_to_basis[jasym].get(rfjorb)
fiorb, fjorb = idp.basis_to_full_basis[iasym].get(riorb), idp.basis_to_full_basis[jasym].get(rjorb)
if fiorb is not None and fjorb is not None:
sli = idp.orbpair_maps.get(f"{fiorb}-{fjorb}")
Expand Down
2 changes: 1 addition & 1 deletion dptb/tests/test_emb_se2.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def test_embedding(self):
assert torch.abs(re1.norm() - 0.06084440) < 1e-6
assert re1.norm() > 1e-6
re1 /= re1.norm()
assert torch.all(out_node[ii] == re1)
assert (out_node[ii] - re1).abs().max() < 1e-6

edge_out = torch.cat([out_node[edge_index[0]] + out_node[edge_index[1]], 1/edge_length.reshape(-1,1)], dim=-1) # [N_edge, D*D]

Expand Down

0 comments on commit c7f5331

Please sign in to comment.