Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(2024Q1): fix lammps nlist sort with large sel #3994

Merged
merged 9 commits into from
Aug 1, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions deepmd/pt/model/model/ener_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ def forward_lower(
fparam: Optional[torch.Tensor] = None,
aparam: Optional[torch.Tensor] = None,
do_atomic_virial: bool = False,
extra_nlist_sort: bool = True,
):
model_ret = self.forward_common_lower(
extended_coord,
Expand All @@ -78,6 +79,7 @@ def forward_lower(
fparam=fparam,
aparam=aparam,
do_atomic_virial=do_atomic_virial,
extra_nlist_sort=extra_nlist_sort,
)
if self.get_fitting_net() is not None:
model_predict = {}
Expand Down
24 changes: 20 additions & 4 deletions deepmd/pt/model/model/make_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,7 @@
fparam: Optional[torch.Tensor] = None,
aparam: Optional[torch.Tensor] = None,
do_atomic_virial: bool = False,
extra_nlist_sort: bool = False,
):
"""Return model prediction. Lower interface that takes
extended atomic coordinates and types, nlist, and mapping
Expand All @@ -238,6 +239,8 @@
atomic parameter. nf x nloc x nda
do_atomic_virial
whether calculate atomic virial.
extra_nlist_sort
whether to forcibly sort the nlist.

Returns
-------
Expand All @@ -247,7 +250,9 @@
"""
nframes, nall = extended_atype.shape[:2]
extended_coord = extended_coord.view(nframes, -1, 3)
nlist = self.format_nlist(extended_coord, extended_atype, nlist)
nlist = self.format_nlist(

Check warning on line 253 in deepmd/pt/model/model/make_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/make_model.py#L253

Added line #L253 was not covered by tests
extended_coord, extended_atype, nlist, extra_nlist_sort=extra_nlist_sort
)
cc_ext, _, fp, ap, input_prec = self.input_type_cast(
extended_coord, fparam=fparam, aparam=aparam
)
Expand Down Expand Up @@ -347,6 +352,7 @@
extended_coord: torch.Tensor,
extended_atype: torch.Tensor,
nlist: torch.Tensor,
extra_nlist_sort: bool = False,
):
"""Format the neighbor list.

Expand All @@ -372,6 +378,8 @@
atomic type in extended region. nf x nall
nlist
neighbor list. nf x nloc x nsel
extra_nlist_sort
whether to forcibly sort the nlist.

Returns
-------
Expand All @@ -380,7 +388,12 @@

"""
mixed_types = self.mixed_types()
nlist = self._format_nlist(extended_coord, nlist, sum(self.get_sel()))
nlist = self._format_nlist(

Check warning on line 391 in deepmd/pt/model/model/make_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/make_model.py#L391

Added line #L391 was not covered by tests
extended_coord,
nlist,
sum(self.get_sel()),
extra_nlist_sort=extra_nlist_sort,
)
if not mixed_types:
nlist = nlist_distinguish_types(nlist, extended_atype, self.get_sel())
return nlist
Expand All @@ -390,6 +403,7 @@
extended_coord: torch.Tensor,
nlist: torch.Tensor,
nnei: int,
extra_nlist_sort: bool = False,
):
n_nf, n_nloc, n_nnei = nlist.shape
# nf x nall x 3
Expand All @@ -409,7 +423,9 @@
],
dim=-1,
)
elif n_nnei > nnei:

if n_nnei > nnei or (extra_nlist_sort and n_nnei <= nnei):
Fixed Show fixed Hide fixed
n_nf, n_nloc, n_nnei = nlist.shape

Check warning on line 428 in deepmd/pt/model/model/make_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/make_model.py#L427-L428

Added lines #L427 - L428 were not covered by tests
m_real_nei = nlist >= 0
nlist = torch.where(m_real_nei, nlist, 0)
# nf x nloc x 3
Expand All @@ -426,7 +442,7 @@
nlist = torch.gather(nlist, 2, nlist_mapping)
nlist = torch.where(rr > rcut, -1, nlist)
nlist = nlist[..., :nnei]
else: # n_nnei == nnei:
else: # not extra_nlist_sort and n_nnei <= nnei:
pass # great!
assert nlist.shape[-1] == nnei
return nlist
Expand Down
Loading