Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(pt): fix lammps nlist sort with large sel #3993

Merged
merged 13 commits into from
Jul 24, 2024
4 changes: 4 additions & 0 deletions deepmd/dpmodel/atomic_model/dp_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,10 @@ def has_message_passing(self) -> bool:
"""Returns whether the atomic model has message passing."""
return self.descriptor.has_message_passing()

def need_sorted_nlist_for_lower(self) -> bool:
"""Returns whether the atomic model needs sorted nlist when using `forward_lower`."""
return self.descriptor.need_sorted_nlist_for_lower()

def forward_atomic(
self,
extended_coord: np.ndarray,
Expand Down
4 changes: 4 additions & 0 deletions deepmd/dpmodel/atomic_model/linear_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,10 @@
"""Returns whether the atomic model has message passing."""
return any(model.has_message_passing() for model in self.models)

def need_sorted_nlist_for_lower(self) -> bool:
"""Returns whether the atomic model needs sorted nlist when using `forward_lower`."""
return any(model.need_sorted_nlist_for_lower() for model in self.models)

Check warning on line 101 in deepmd/dpmodel/atomic_model/linear_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/atomic_model/linear_atomic_model.py#L101

Added line #L101 was not covered by tests

def get_rcut(self) -> float:
"""Get the cut-off radius."""
return max(self.get_model_rcuts())
Expand Down
4 changes: 4 additions & 0 deletions deepmd/dpmodel/atomic_model/make_base_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,10 @@ def mixed_types(self) -> bool:
def has_message_passing(self) -> bool:
"""Returns whether the descriptor has message passing."""

@abstractmethod
def need_sorted_nlist_for_lower(self) -> bool:
"""Returns whether the descriptor needs sorted nlist when using `forward_lower`."""

@abstractmethod
def fwd(
self,
Expand Down
4 changes: 4 additions & 0 deletions deepmd/dpmodel/atomic_model/pairtab_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,10 @@
"""Returns whether the atomic model has message passing."""
return False

def need_sorted_nlist_for_lower(self) -> bool:
"""Returns whether the atomic model needs sorted nlist when using `forward_lower`."""
return False

Check warning on line 140 in deepmd/dpmodel/atomic_model/pairtab_atomic_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/atomic_model/pairtab_atomic_model.py#L140

Added line #L140 was not covered by tests

def change_type_map(
self, type_map: List[str], model_with_new_type_stat=None
) -> None:
Expand Down
4 changes: 4 additions & 0 deletions deepmd/dpmodel/descriptor/descriptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,10 @@ def call(
def has_message_passing(self) -> bool:
"""Returns whether the descriptor block has message passing."""

@abstractmethod
def need_sorted_nlist_for_lower(self) -> bool:
"""Returns whether the descriptor block needs sorted nlist when using `forward_lower`."""


def extend_descrpt_stat(des, type_map, des_with_stat=None):
r"""
Expand Down
8 changes: 8 additions & 0 deletions deepmd/dpmodel/descriptor/dpa1.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,6 +371,10 @@ def has_message_passing(self) -> bool:
"""Returns whether the descriptor has message passing."""
return self.se_atten.has_message_passing()

def need_sorted_nlist_for_lower(self) -> bool:
"""Returns whether the descriptor needs sorted nlist when using `forward_lower`."""
return self.se_atten.need_sorted_nlist_for_lower()

def get_env_protection(self) -> float:
"""Returns the protection of building environment matrix."""
return self.se_atten.get_env_protection()
Expand Down Expand Up @@ -952,6 +956,10 @@ def has_message_passing(self) -> bool:
"""Returns whether the descriptor block has message passing."""
return False

def need_sorted_nlist_for_lower(self) -> bool:
"""Returns whether the descriptor block needs sorted nlist when using `forward_lower`."""
return False


class NeighborGatedAttention(NativeOP):
def __init__(
Expand Down
9 changes: 9 additions & 0 deletions deepmd/dpmodel/descriptor/dpa2.py
Original file line number Diff line number Diff line change
Expand Up @@ -553,6 +553,15 @@ def has_message_passing(self) -> bool:
[self.repinit.has_message_passing(), self.repformers.has_message_passing()]
)

def need_sorted_nlist_for_lower(self) -> bool:
"""Returns whether the descriptor needs sorted nlist when using `forward_lower`."""
return any(
[
self.repinit.need_sorted_nlist_for_lower(),
self.repformers.need_sorted_nlist_for_lower(),
]
)
iProzd marked this conversation as resolved.
Show resolved Hide resolved

def get_env_protection(self) -> float:
"""Returns the protection of building environment matrix."""
return self.env_protection
Expand Down
6 changes: 6 additions & 0 deletions deepmd/dpmodel/descriptor/hybrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,12 @@ def has_message_passing(self) -> bool:
"""Returns whether the descriptor has message passing."""
return any(descrpt.has_message_passing() for descrpt in self.descrpt_list)

def need_sorted_nlist_for_lower(self) -> bool:
"""Returns whether the descriptor needs sorted nlist when using `forward_lower`."""
return any(
descrpt.need_sorted_nlist_for_lower() for descrpt in self.descrpt_list
)

def get_env_protection(self) -> float:
"""Returns the protection of building environment matrix. All descriptors should be the same."""
all_protection = [descrpt.get_env_protection() for descrpt in self.descrpt_list]
Expand Down
4 changes: 4 additions & 0 deletions deepmd/dpmodel/descriptor/make_base_descriptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,10 @@ def mixed_types(self) -> bool:
def has_message_passing(self) -> bool:
"""Returns whether the descriptor has message passing."""

@abstractmethod
def need_sorted_nlist_for_lower(self) -> bool:
"""Returns whether the descriptor needs sorted nlist when using `forward_lower`."""

@abstractmethod
def get_env_protection(self) -> float:
"""Returns the protection of building environment matrix."""
Expand Down
4 changes: 4 additions & 0 deletions deepmd/dpmodel/descriptor/repformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -401,6 +401,10 @@ def has_message_passing(self) -> bool:
"""Returns whether the descriptor block has message passing."""
return True

def need_sorted_nlist_for_lower(self) -> bool:
"""Returns whether the descriptor block needs sorted nlist when using `forward_lower`."""
return True
iProzd marked this conversation as resolved.
Show resolved Hide resolved


# translated by GPT and modified
def get_residual(
Expand Down
4 changes: 4 additions & 0 deletions deepmd/dpmodel/descriptor/se_e2_a.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,10 @@ def has_message_passing(self) -> bool:
"""Returns whether the descriptor has message passing."""
return False

def need_sorted_nlist_for_lower(self) -> bool:
"""Returns whether the descriptor needs sorted nlist when using `forward_lower`."""
return False

def get_env_protection(self) -> float:
"""Returns the protection of building environment matrix."""
return self.env_protection
Expand Down
4 changes: 4 additions & 0 deletions deepmd/dpmodel/descriptor/se_r.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,10 @@ def has_message_passing(self) -> bool:
"""Returns whether the descriptor has message passing."""
return False

def need_sorted_nlist_for_lower(self) -> bool:
"""Returns whether the descriptor needs sorted nlist when using `forward_lower`."""
return False

def get_env_protection(self) -> float:
"""Returns the protection of building environment matrix."""
return self.env_protection
Expand Down
4 changes: 4 additions & 0 deletions deepmd/dpmodel/descriptor/se_t.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,10 @@ def has_message_passing(self) -> bool:
"""Returns whether the descriptor has message passing."""
return False

def need_sorted_nlist_for_lower(self) -> bool:
"""Returns whether the descriptor needs sorted nlist when using `forward_lower`."""
return False

def get_env_protection(self) -> float:
"""Returns the protection of building environment matrix."""
return self.env_protection
Expand Down
32 changes: 27 additions & 5 deletions deepmd/dpmodel/model/make_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,12 @@ def call_lower(
"""
nframes, nall = extended_atype.shape[:2]
extended_coord = extended_coord.reshape(nframes, -1, 3)
nlist = self.format_nlist(extended_coord, extended_atype, nlist)
nlist = self.format_nlist(
extended_coord,
extended_atype,
nlist,
extra_nlist_sort=self.need_sorted_nlist_for_lower(),
)
cc_ext, _, fp, ap, input_prec = self.input_type_cast(
extended_coord, fparam=fparam, aparam=aparam
)
Expand Down Expand Up @@ -311,6 +316,7 @@ def format_nlist(
extended_coord: np.ndarray,
extended_atype: np.ndarray,
nlist: np.ndarray,
extra_nlist_sort: bool = False,
):
"""Format the neighbor list.

Expand All @@ -336,6 +342,8 @@ def format_nlist(
atomic type in extended region. nf x nall
nlist
neighbor list. nf x nloc x nsel
extra_nlist_sort
whether to forcibly sort the nlist.

Returns
-------
Expand All @@ -345,7 +353,12 @@ def format_nlist(
"""
n_nf, n_nloc, n_nnei = nlist.shape
mixed_types = self.mixed_types()
ret = self._format_nlist(extended_coord, nlist, sum(self.get_sel()))
ret = self._format_nlist(
extended_coord,
nlist,
sum(self.get_sel()),
extra_nlist_sort=extra_nlist_sort,
)
if not mixed_types:
ret = nlist_distinguish_types(ret, extended_atype, self.get_sel())
return ret
Expand All @@ -355,6 +368,7 @@ def _format_nlist(
extended_coord: np.ndarray,
nlist: np.ndarray,
nnei: int,
extra_nlist_sort: bool = False,
):
n_nf, n_nloc, n_nnei = nlist.shape
extended_coord = extended_coord.reshape([n_nf, -1, 3])
Expand All @@ -370,7 +384,9 @@ def _format_nlist(
],
axis=-1,
)
elif n_nnei > nnei:

if n_nnei > nnei or extra_nlist_sort:
n_nf, n_nloc, n_nnei = nlist.shape
# make a copy before revise
m_real_nei = nlist >= 0
ret = np.where(m_real_nei, nlist, 0)
Expand All @@ -384,9 +400,11 @@ def _format_nlist(
ret = np.take_along_axis(ret, ret_mapping, axis=2)
ret = np.where(rr > rcut, -1, ret)
ret = ret[..., :nnei]
else: # n_nnei == nnei:
# copy anyway...
# not extra_nlist_sort and n_nnei <= nnei:
elif n_nnei == nnei:
ret = nlist
else:
pass
assert ret.shape[-1] == nnei
return ret

Expand Down Expand Up @@ -483,6 +501,10 @@ def has_message_passing(self) -> bool:
"""Returns whether the model has message passing."""
return self.atomic_model.has_message_passing()

def need_sorted_nlist_for_lower(self) -> bool:
"""Returns whether the model needs sorted nlist when using `forward_lower`."""
return self.atomic_model.need_sorted_nlist_for_lower()

def atomic_output_def(self) -> FittingOutputDef:
"""Get the output def of the atomic model."""
return self.atomic_model.atomic_output_def()
Expand Down
4 changes: 4 additions & 0 deletions deepmd/pt/model/atomic_model/dp_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,10 @@ def has_message_passing(self) -> bool:
"""Returns whether the atomic model has message passing."""
return self.descriptor.has_message_passing()

def need_sorted_nlist_for_lower(self) -> bool:
"""Returns whether the atomic model needs sorted nlist when using `forward_lower`."""
return self.descriptor.need_sorted_nlist_for_lower()

def serialize(self) -> dict:
dd = BaseAtomicModel.serialize(self)
dd.update(
Expand Down
4 changes: 4 additions & 0 deletions deepmd/pt/model/atomic_model/linear_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,10 @@ def has_message_passing(self) -> bool:
"""Returns whether the atomic model has message passing."""
return any(model.has_message_passing() for model in self.models)

def need_sorted_nlist_for_lower(self) -> bool:
"""Returns whether the atomic model needs sorted nlist when using `forward_lower`."""
return any(model.need_sorted_nlist_for_lower() for model in self.models)
iProzd marked this conversation as resolved.
Show resolved Hide resolved

def get_out_bias(self) -> torch.Tensor:
return self.out_bias

Expand Down
4 changes: 4 additions & 0 deletions deepmd/pt/model/atomic_model/pairtab_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,10 @@ def has_message_passing(self) -> bool:
"""Returns whether the atomic model has message passing."""
return False

def need_sorted_nlist_for_lower(self) -> bool:
"""Returns whether the atomic model needs sorted nlist when using `forward_lower`."""
return False

def change_type_map(
self, type_map: List[str], model_with_new_type_stat=None
) -> None:
Expand Down
4 changes: 4 additions & 0 deletions deepmd/pt/model/descriptor/descriptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,10 @@ def forward(
def has_message_passing(self) -> bool:
"""Returns whether the descriptor block has message passing."""

@abstractmethod
def need_sorted_nlist_for_lower(self) -> bool:
"""Returns whether the descriptor block needs sorted nlist when using `forward_lower`."""


def make_default_type_embedding(
ntypes,
Expand Down
4 changes: 4 additions & 0 deletions deepmd/pt/model/descriptor/dpa1.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,6 +364,10 @@ def has_message_passing(self) -> bool:
"""Returns whether the descriptor has message passing."""
return self.se_atten.has_message_passing()

def need_sorted_nlist_for_lower(self) -> bool:
"""Returns whether the descriptor needs sorted nlist when using `forward_lower`."""
return self.se_atten.need_sorted_nlist_for_lower()

def get_env_protection(self) -> float:
"""Returns the protection of building environment matrix."""
return self.se_atten.get_env_protection()
Expand Down
9 changes: 9 additions & 0 deletions deepmd/pt/model/descriptor/dpa2.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,6 +318,15 @@ def has_message_passing(self) -> bool:
[self.repinit.has_message_passing(), self.repformers.has_message_passing()]
)

def need_sorted_nlist_for_lower(self) -> bool:
"""Returns whether the descriptor needs sorted nlist when using `forward_lower`."""
return any(
[
self.repinit.need_sorted_nlist_for_lower(),
self.repformers.need_sorted_nlist_for_lower(),
]
)

def get_env_protection(self) -> float:
"""Returns the protection of building environment matrix."""
# the env_protection of repinit is the same as that of the repformer
Expand Down
6 changes: 6 additions & 0 deletions deepmd/pt/model/descriptor/hybrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,12 @@ def has_message_passing(self) -> bool:
"""Returns whether the descriptor has message passing."""
return any(descrpt.has_message_passing() for descrpt in self.descrpt_list)

def need_sorted_nlist_for_lower(self) -> bool:
"""Returns whether the descriptor needs sorted nlist when using `forward_lower`."""
return any(
descrpt.need_sorted_nlist_for_lower() for descrpt in self.descrpt_list
)

def get_env_protection(self) -> float:
"""Returns the protection of building environment matrix. All descriptors should be the same."""
all_protection = [descrpt.get_env_protection() for descrpt in self.descrpt_list]
Expand Down
4 changes: 4 additions & 0 deletions deepmd/pt/model/descriptor/repformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -557,3 +557,7 @@ def get_stats(self) -> Dict[str, StatItem]:
def has_message_passing(self) -> bool:
"""Returns whether the descriptor block has message passing."""
return True

def need_sorted_nlist_for_lower(self) -> bool:
"""Returns whether the descriptor block needs sorted nlist when using `forward_lower`."""
return True
8 changes: 8 additions & 0 deletions deepmd/pt/model/descriptor/se_a.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,10 @@ def has_message_passing(self) -> bool:
"""Returns whether the descriptor has message passing."""
return self.sea.has_message_passing()

def need_sorted_nlist_for_lower(self) -> bool:
"""Returns whether the descriptor needs sorted nlist when using `forward_lower`."""
return self.sea.need_sorted_nlist_for_lower()

def get_env_protection(self) -> float:
"""Returns the protection of building environment matrix."""
return self.sea.get_env_protection()
Expand Down Expand Up @@ -711,3 +715,7 @@ def forward(
def has_message_passing(self) -> bool:
"""Returns whether the descriptor block has message passing."""
return False

def need_sorted_nlist_for_lower(self) -> bool:
"""Returns whether the descriptor block needs sorted nlist when using `forward_lower`."""
return False
4 changes: 4 additions & 0 deletions deepmd/pt/model/descriptor/se_atten.py
Original file line number Diff line number Diff line change
Expand Up @@ -591,6 +591,10 @@ def has_message_passing(self) -> bool:
"""Returns whether the descriptor block has message passing."""
return False

def need_sorted_nlist_for_lower(self) -> bool:
"""Returns whether the descriptor block needs sorted nlist when using `forward_lower`."""
return False


class NeighborGatedAttention(nn.Module):
def __init__(
Expand Down
4 changes: 4 additions & 0 deletions deepmd/pt/model/descriptor/se_r.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,10 @@ def has_message_passing(self) -> bool:
"""Returns whether the descriptor has message passing."""
return False

def need_sorted_nlist_for_lower(self) -> bool:
"""Returns whether the descriptor needs sorted nlist when using `forward_lower`."""
return False

def get_env_protection(self) -> float:
"""Returns the protection of building environment matrix."""
return self.env_protection
Expand Down
Loading