Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Hybrid descriptor #3365

Merged
merged 15 commits into from
Mar 1, 2024
4 changes: 4 additions & 0 deletions deepmd/dpmodel/descriptor/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from .hybrid import (
DescrptHybrid,
)
from .make_base_descriptor import (
make_base_descriptor,
)
Expand All @@ -12,5 +15,6 @@
__all__ = [
"DescrptSeA",
"DescrptSeR",
"DescrptHybrid",
"make_base_descriptor",
]
186 changes: 186 additions & 0 deletions deepmd/dpmodel/descriptor/hybrid.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,186 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from typing import (
List,
Optional,
)

import numpy as np

from deepmd.dpmodel.common import (
NativeOP,
)
from deepmd.dpmodel.descriptor.base_descriptor import (
BaseDescriptor,
)
from deepmd.utils.path import (
DPPath,
)
from deepmd.utils.version import (
check_version_compatibility,
)


@BaseDescriptor.register("hybrid")
class DescrptHybrid(BaseDescriptor, NativeOP):
"""Concate a list of descriptors to form a new descriptor.

Parameters
----------
list : list
Build a descriptor from the concatenation of the list of descriptors.
wanghan-iapcm marked this conversation as resolved.
Show resolved Hide resolved
"""

def __init__(
self,
list: list,
) -> None:
super().__init__()

Check warning on line 37 in deepmd/dpmodel/descriptor/hybrid.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/descriptor/hybrid.py#L37

Added line #L37 was not covered by tests
# warning: list is conflict with built-in list
descrpt_list = list
if descrpt_list == [] or descrpt_list is None:
raise RuntimeError(

Check warning on line 41 in deepmd/dpmodel/descriptor/hybrid.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/descriptor/hybrid.py#L39-L41

Added lines #L39 - L41 were not covered by tests
"cannot build descriptor from an empty list of descriptors."
)
formatted_descript_list = []
for ii in descrpt_list:
if isinstance(ii, BaseDescriptor):
formatted_descript_list.append(ii)
elif isinstance(ii, dict):
formatted_descript_list.append(BaseDescriptor(**ii))

Check warning on line 49 in deepmd/dpmodel/descriptor/hybrid.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/descriptor/hybrid.py#L44-L49

Added lines #L44 - L49 were not covered by tests
else:
raise NotImplementedError
self.descrpt_list = formatted_descript_list
self.numb_descrpt = len(self.descrpt_list)
for ii in range(1, self.numb_descrpt):
assert (

Check warning on line 55 in deepmd/dpmodel/descriptor/hybrid.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/descriptor/hybrid.py#L51-L55

Added lines #L51 - L55 were not covered by tests
self.descrpt_list[ii].get_ntypes() == self.descrpt_list[0].get_ntypes()
), f"number of atom types in {ii}th descrptor does not match others"
anyangml marked this conversation as resolved.
Show resolved Hide resolved
# if hybrid sel is larger than sub sel, the nlist needs to be cut for each type
hybrid_sel = self.get_sel()

Check warning

Code scanning / CodeQL

Variable defined multiple times Warning

This assignment to 'hybrid_sel' is unnecessary as it is
redefined
before this value is used.
This assignment to 'hybrid_sel' is unnecessary as it is
redefined
before this value is used.
self.nlist_cut_idx: List[np.ndarray] = []
for ii in range(self.numb_descrpt):
sub_sel = self.descrpt_list[ii].get_sel()
start_idx = np.cumsum(np.pad(hybrid_sel, (1, 0), "constant"))[:-1]
end_idx = start_idx + np.array(sub_sel)
cut_idx = np.concatenate(

Check warning on line 65 in deepmd/dpmodel/descriptor/hybrid.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/descriptor/hybrid.py#L59-L65

Added lines #L59 - L65 were not covered by tests
[range(ss, ee) for ss, ee in zip(start_idx, end_idx)]
)
self.nlist_cut_idx.append(cut_idx)

Check warning on line 68 in deepmd/dpmodel/descriptor/hybrid.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/descriptor/hybrid.py#L68

Added line #L68 was not covered by tests

def get_rcut(self) -> float:
"""Returns the cut-off radius."""
return np.max([descrpt.get_rcut() for descrpt in self.descrpt_list]).item()

Check warning on line 72 in deepmd/dpmodel/descriptor/hybrid.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/descriptor/hybrid.py#L72

Added line #L72 was not covered by tests

def get_sel(self) -> List[int]:
"""Returns the number of selected atoms for each type."""
return np.max(

Check warning on line 76 in deepmd/dpmodel/descriptor/hybrid.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/descriptor/hybrid.py#L76

Added line #L76 was not covered by tests
[descrpt.get_sel() for descrpt in self.descrpt_list], axis=0
wanghan-iapcm marked this conversation as resolved.
Show resolved Hide resolved
).tolist()

def get_ntypes(self) -> int:
"""Returns the number of element types."""
return self.descrpt_list[0].get_ntypes()

Check warning on line 82 in deepmd/dpmodel/descriptor/hybrid.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/descriptor/hybrid.py#L82

Added line #L82 was not covered by tests

def get_dim_out(self) -> int:
"""Returns the output dimension."""
return np.sum([descrpt.get_dim_out() for descrpt in self.descrpt_list]).item()

Check warning on line 86 in deepmd/dpmodel/descriptor/hybrid.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/descriptor/hybrid.py#L86

Added line #L86 was not covered by tests

def get_dim_emb(self) -> int:
"""Returns the output dimension."""
return np.sum([descrpt.get_dim_emb() for descrpt in self.descrpt_list]).item()

Check warning on line 90 in deepmd/dpmodel/descriptor/hybrid.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/descriptor/hybrid.py#L90

Added line #L90 was not covered by tests

def mixed_types(self):
"""Returns if the descriptor requires a neighbor list that distinguish different
atomic types or not.
"""
return all(descrpt.mixed_types() for descrpt in self.descrpt_list)

Check warning on line 96 in deepmd/dpmodel/descriptor/hybrid.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/descriptor/hybrid.py#L96

Added line #L96 was not covered by tests
wanghan-iapcm marked this conversation as resolved.
Show resolved Hide resolved

def compute_input_stats(self, merged: List[dict], path: Optional[DPPath] = None):
"""Update mean and stddev for descriptor elements."""
for descrpt in self.descrpt_list:
descrpt.compute_input_stats(merged, path)

Check warning on line 101 in deepmd/dpmodel/descriptor/hybrid.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/descriptor/hybrid.py#L100-L101

Added lines #L100 - L101 were not covered by tests

def call(
self,
coord_ext,
atype_ext,
nlist,
mapping: Optional[np.ndarray] = None,
):
"""Compute the descriptor.

Parameters
----------
coord_ext
The extended coordinates of atoms. shape: nf x (nallx3)
atype_ext
The extended aotm types. shape: nf x nall
nlist
The neighbor list. shape: nf x nloc x nnei
mapping
The index mapping, not required by this descriptor.

Returns
-------
descriptor
The descriptor. shape: nf x nloc x (ng x axis_neuron)
gr
The rotationally equivariant and permutationally invariant single particle
representation. shape: nf x nloc x ng x 3. This descriptor returns None
g2
The rotationally invariant pair-partical representation.
this descriptor returns None
h2
The rotationally equivariant pair-partical representation.
this descriptor returns None
sw
The smooth switch function. this descriptor returns None
"""
out_descriptor = []
for descrpt, nci in zip(self.descrpt_list, self.nlist_cut_idx):

Check warning on line 140 in deepmd/dpmodel/descriptor/hybrid.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/descriptor/hybrid.py#L139-L140

Added lines #L139 - L140 were not covered by tests
# cut the nlist to the correct length
odescriptor, _, _, _, _ = descrpt(

Check warning on line 142 in deepmd/dpmodel/descriptor/hybrid.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/descriptor/hybrid.py#L142

Added line #L142 was not covered by tests
coord_ext, atype_ext, nlist[:, :, nci], mapping
)
out_descriptor.append(odescriptor)
out_descriptor = np.concatenate(out_descriptor, axis=-1)
return out_descriptor, None, None, None, None

Check warning on line 147 in deepmd/dpmodel/descriptor/hybrid.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/descriptor/hybrid.py#L145-L147

Added lines #L145 - L147 were not covered by tests
wanghan-iapcm marked this conversation as resolved.
Show resolved Hide resolved

@classmethod
def update_sel(cls, global_jdata: dict, local_jdata: dict) -> dict:
"""Update the selection and perform neighbor statistics.

Parameters
----------
global_jdata : dict
The global data, containing the training section
local_jdata : dict
The local data refer to the current class
"""
local_jdata_cpy = local_jdata.copy()
local_jdata_cpy["list"] = [

Check warning on line 161 in deepmd/dpmodel/descriptor/hybrid.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/descriptor/hybrid.py#L160-L161

Added lines #L160 - L161 were not covered by tests
BaseDescriptor.update_sel(global_jdata, sub_jdata)
for sub_jdata in local_jdata["list"]
]
return local_jdata_cpy

Check warning on line 165 in deepmd/dpmodel/descriptor/hybrid.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/descriptor/hybrid.py#L165

Added line #L165 was not covered by tests

def serialize(self) -> dict:
return {

Check warning on line 168 in deepmd/dpmodel/descriptor/hybrid.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/descriptor/hybrid.py#L168

Added line #L168 was not covered by tests
"@class": "Descriptor",
"type": "hybrid",
"@version": 1,
"list": [descrpt.serialize() for descrpt in self.descrpt_list],
}

@classmethod
def deserialize(cls, data: dict) -> "DescrptHybrid":
data = data.copy()
class_name = data.pop("@class")
assert class_name == "Descriptor"
class_type = data.pop("type")
assert class_type == "hybrid"
check_version_compatibility(data.pop("@version"), 1, 1)
obj = cls(

Check warning on line 183 in deepmd/dpmodel/descriptor/hybrid.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/descriptor/hybrid.py#L177-L183

Added lines #L177 - L183 were not covered by tests
list=[BaseDescriptor.deserialize(ii) for ii in data["list"]],
)
return obj

Check warning on line 186 in deepmd/dpmodel/descriptor/hybrid.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/descriptor/hybrid.py#L186

Added line #L186 was not covered by tests
2 changes: 2 additions & 0 deletions deepmd/pt/model/descriptor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
)
from .hybrid import (
DescrptBlockHybrid,
DescrptHybrid,
)
from .repformers import (
DescrptBlockRepformers,
Expand All @@ -39,6 +40,7 @@
"DescrptSeR",
"DescrptDPA1",
"DescrptDPA2",
"DescrptHybrid",
"prod_env_mat",
"DescrptGaussianLcc",
"DescrptBlockHybrid",
Expand Down
Loading
Loading