Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: directional nlist #4052

Merged
merged 7 commits into from
Aug 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
128 changes: 123 additions & 5 deletions deepmd/pt/utils/nlist.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,6 @@
).view(batch_size, nall * 3)
if isinstance(sel, int):
sel = [sel]
nsel = sum(sel)
# nloc x 3
coord0 = coord1[:, : nloc * 3]
# nloc x nall x 3
Expand All @@ -126,8 +125,26 @@
# nloc x (nall-1)
rr = rr[:, :, 1:]
nlist = nlist[:, :, 1:]

return _trim_mask_distinguish_nlist(
is_vir, atype, rr, nlist, rcut, sel, distinguish_types
)


def _trim_mask_distinguish_nlist(
is_vir_cntl: torch.Tensor,
atype_neig: torch.Tensor,
rr: torch.Tensor,
nlist: torch.Tensor,
rcut: float,
sel: List[int],
distinguish_types: bool,
) -> torch.Tensor:
"""Trim the size of nlist, mask if any central atom is virtual, distinguish types if necessary."""
nsel = sum(sel)
# nloc x nsel
nnei = rr.shape[2]
batch_size, nloc, nnei = rr.shape
assert batch_size == is_vir_cntl.shape[0]
if nsel <= nnei:
rr = rr[:, :, :nsel]
nlist = nlist[:, :, :nsel]
Expand All @@ -147,15 +164,116 @@
)
assert list(nlist.shape) == [batch_size, nloc, nsel]
nlist = torch.where(
torch.logical_or((rr > rcut), is_vir[:, :nloc, None]), -1, nlist
torch.logical_or((rr > rcut), is_vir_cntl[:, :nloc, None]), -1, nlist
)

if distinguish_types:
return nlist_distinguish_types(nlist, atype, sel)
return nlist_distinguish_types(nlist, atype_neig, sel)
else:
return nlist


def build_directional_neighbor_list(
coord_cntl: torch.Tensor,
atype_cntl: torch.Tensor,
coord_neig: torch.Tensor,
atype_neig: torch.Tensor,
rcut: float,
sel: Union[int, List[int]],
distinguish_types: bool = True,
) -> torch.Tensor:
"""Build directional neighbor list.

With each central atom, all the neighbor atoms in the cut-off radius will
be recorded in the neighbor list. The maximum neighbors is nsel. If the real
number of neighbors is larger than nsel, the neighbors will be sorted with the
distance and the first nsel neighbors are kept.

Important: the central and neighboring atoms are assume to be different atoms.

Parameters
----------
coord_central : torch.Tensor
coordinates of central atoms. assumed to be local atoms.
shape [batch_size, nloc_central x 3]
atype_central : torch.Tensor
atomic types of central atoms. shape [batch_size, nloc_central]
if type < 0 the atom is treated as virtual atoms.
coord_neighbor : torch.Tensor
extended coordinates of neighbors atoms. shape [batch_size, nall_neighbor x 3]
atype_central : torch.Tensor
extended atomic types of neighbors atoms. shape [batch_size, nall_neighbor]
if type < 0 the atom is treated as virtual atoms.
rcut : float
cut-off radius
sel : int or List[int]
maximal number of neighbors (of each type).
if distinguish_types==True, nsel should be list and
the length of nsel should be equal to number of
types.
distinguish_types : bool
distinguish different types.

Returns
-------
neighbor_list : torch.Tensor
Neighbor list of shape [batch_size, nloc_central, nsel], the neighbors
are stored in an ascending order. If the number of neighbors is less than nsel,
the positions are masked with -1. The neighbor list of an atom looks like
|------ nsel ------|
xx xx xx xx -1 -1 -1
if distinguish_types==True and we have two types
|---- nsel[0] -----| |---- nsel[1] -----|
xx xx xx xx -1 -1 -1 xx xx xx -1 -1 -1 -1
For virtual atoms all neighboring positions are filled with -1.
"""
batch_size = coord_cntl.shape[0]
coord_cntl = coord_cntl.view(batch_size, -1)
nloc_cntl = coord_cntl.shape[1] // 3
coord_neig = coord_neig.view(batch_size, -1)
nall_neig = coord_neig.shape[1] // 3
# fill virtual atoms with large coords so they are not neighbors of any
# real atom.
if coord_neig.numel() > 0:
xmax = torch.max(coord_cntl) + 2.0 * rcut
else:
xmax = (

Check warning on line 239 in deepmd/pt/utils/nlist.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/utils/nlist.py#L239

Added line #L239 was not covered by tests
torch.zeros(1, dtype=coord_neig.dtype, device=coord_neig.device)
+ 2.0 * rcut
)
# nf x nloc
is_vir_cntl = atype_cntl < 0
# nf x nall
is_vir_neig = atype_neig < 0
# nf x nloc x 3
coord_cntl = coord_cntl.view(batch_size, nloc_cntl, 3)
# nf x nall x 3
coord_neig = torch.where(
is_vir_neig[:, :, None], xmax, coord_neig.view(batch_size, nall_neig, 3)
).view(batch_size, nall_neig, 3)
# nsel
if isinstance(sel, int):
sel = [sel]
# nloc x nall x 3
diff = coord_neig[:, None, :, :] - coord_cntl[:, :, None, :]
assert list(diff.shape) == [batch_size, nloc_cntl, nall_neig, 3]
# nloc x nall
rr = torch.linalg.norm(diff, dim=-1)
rr, nlist = torch.sort(rr, dim=-1)

# We assume that the central and neighbor atoms are diffferent,
# thus we do not need to exclude self-neighbors.
# # if central atom has two zero distances, sorting sometimes can not exclude itself
# rr -= torch.eye(nloc_cntl, nall_neig, dtype=rr.dtype, device=rr.device).unsqueeze(0)
# rr, nlist = torch.sort(rr, dim=-1)
# # nloc x (nall-1)
# rr = rr[:, :, 1:]
# nlist = nlist[:, :, 1:]

return _trim_mask_distinguish_nlist(
is_vir_cntl, atype_neig, rr, nlist, rcut, sel, distinguish_types
)


def nlist_distinguish_types(
nlist: torch.Tensor,
atype: torch.Tensor,
Expand Down
70 changes: 69 additions & 1 deletion source/tests/pt/model/test_nlist.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
env,
)
from deepmd.pt.utils.nlist import (
build_directional_neighbor_list,
build_multiple_neighbor_list,
build_neighbor_list,
extend_coord_with_ghosts,
Expand Down Expand Up @@ -62,6 +63,7 @@ def test_build_notype(self):
ecoord, eatype, mapping = extend_coord_with_ghosts(
self.coord, self.atype, self.cell, self.rcut
)
# test normal sel
nlist = build_neighbor_list(
ecoord,
eatype,
Expand All @@ -70,14 +72,29 @@ def test_build_notype(self):
sum(self.nsel),
distinguish_types=False,
)
torch.testing.assert_close(nlist[0], nlist[1])
nlist_mask = nlist[0] == -1
nlist_loc = mapping[0][nlist[0]]
nlist_loc[nlist_mask] = -1
torch.testing.assert_close(
torch.sort(nlist_loc, dim=-1)[0],
torch.sort(self.ref_nlist, dim=-1)[0],
)
# test a very large sel
nlist = build_neighbor_list(
ecoord,
eatype,
self.nloc,
self.rcut,
sum(self.nsel) + 300, # +300, real nnei==224
distinguish_types=False,
)
nlist_mask = nlist[0] == -1
nlist_loc = mapping[0][nlist[0]]
nlist_loc[nlist_mask] = -1
torch.testing.assert_close(
torch.sort(nlist_loc, descending=True, dim=-1)[0][:, : sum(self.nsel)],
torch.sort(self.ref_nlist, descending=True, dim=-1)[0],
)

def test_build_type(self):
ecoord, eatype, mapping = extend_coord_with_ghosts(
Expand Down Expand Up @@ -218,3 +235,54 @@ def test_extend_coord(self):
rtol=self.prec,
atol=self.prec,
)

def test_build_directional_nlist(self):
"""Directional nlist is tested against the standard nlist implementation."""
ecoord, eatype, mapping = extend_coord_with_ghosts(
self.coord, self.atype, self.cell, self.rcut
)
for distinguish_types, mysel in zip([True, False], [sum(self.nsel), 300]):
# full neighbor list
nlist_full = build_neighbor_list(
ecoord,
eatype,
self.nloc,
self.rcut,
sum(self.nsel),
distinguish_types=distinguish_types,
)
# central as part of the system
nlist = build_directional_neighbor_list(
ecoord[:, 3:6],
eatype[:, 1:2],
torch.concat(
[
ecoord[:, 0:3],
torch.zeros(
[self.nf, 3], dtype=dtype, device=env.DEVICE
), # placeholder
ecoord[:, 6:],
],
dim=1,
),
torch.concat(
[
eatype[:, 0:1],
-1
* torch.ones(
[self.nf, 1], dtype=int, device=env.DEVICE
), # placeholder
eatype[:, 2:],
],
dim=1,
),
self.rcut,
mysel,
distinguish_types=distinguish_types,
)
torch.testing.assert_close(nlist[0], nlist[1])
torch.testing.assert_close(nlist[0], nlist[2])
torch.testing.assert_close(
torch.sort(nlist[0], descending=True, dim=-1)[0][:, : sum(self.nsel)],
torch.sort(nlist_full[0][1:2], descending=True, dim=-1)[0],
)