-
Notifications
You must be signed in to change notification settings - Fork 524
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: breaking: backend indepdent definition for dp model (#3208)
Features: - abstract base classes for atomic model, fitting and descriptor. - dp model format for atomic models - dp model format for models. - torch support for atomic model format. - torch support `fparam` and `aparam`. This pr also introduces the following updates: - support region and nlist in numpy code. - class decorator like `fitting_check_output` gives human readable class names. - support int types in precision dict. - fix descriptor interfaces. - refactor torch atomic model impl. introduces dirty hacks to be fixed. - provide `format_nlist` that format the nlist in forward_lower method. Known limitations: - torch atomic model has dirty hacks - interfaces for descriptor, fitting and model statistics was not considered, should be fixed in future PRs. Will be fixed - [x] dp model module path is a mess to be refactorized. - [x] nlist consistency should be checked. if not format nlist. - [x] doc strings. - [x] `fparam` and `aparam` support. --------- Co-authored-by: Han Wang <wang_han@iapcm.ac.cn>
- Loading branch information
1 parent
032fa7d
commit eb9b2ef
Showing
65 changed files
with
2,738 additions
and
564 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
# SPDX-License-Identifier: LGPL-3.0-or-later | ||
from .common import ( | ||
DEFAULT_PRECISION, | ||
PRECISION_DICT, | ||
NativeOP, | ||
) | ||
from .model import ( | ||
DPAtomicModel, | ||
DPModel, | ||
) | ||
from .output_def import ( | ||
FittingOutputDef, | ||
ModelOutputDef, | ||
OutputVariableDef, | ||
fitting_check_output, | ||
get_deriv_name, | ||
get_reduce_name, | ||
model_check_output, | ||
) | ||
|
||
__all__ = [ | ||
"DPModel", | ||
"DPAtomicModel", | ||
"PRECISION_DICT", | ||
"DEFAULT_PRECISION", | ||
"NativeOP", | ||
"ModelOutputDef", | ||
"FittingOutputDef", | ||
"OutputVariableDef", | ||
"model_check_output", | ||
"fitting_check_output", | ||
"get_reduce_name", | ||
"get_deriv_name", | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
# SPDX-License-Identifier: LGPL-3.0-or-later | ||
from .make_base_descriptor import ( | ||
make_base_descriptor, | ||
) | ||
from .se_e2_a import ( | ||
DescrptSeA, | ||
) | ||
|
||
__all__ = [ | ||
"DescrptSeA", | ||
"make_base_descriptor", | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
# SPDX-License-Identifier: LGPL-3.0-or-later | ||
import numpy as np | ||
|
||
from .make_base_descriptor import ( | ||
make_base_descriptor, | ||
) | ||
|
||
BaseDescriptor = make_base_descriptor(np.ndarray, "call") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,106 @@ | ||
# SPDX-License-Identifier: LGPL-3.0-or-later | ||
from abc import ( | ||
ABC, | ||
abstractclassmethod, | ||
abstractmethod, | ||
) | ||
from typing import ( | ||
List, | ||
Optional, | ||
) | ||
|
||
|
||
def make_base_descriptor( | ||
t_tensor, | ||
fwd_method_name: str = "forward", | ||
): | ||
"""Make the base class for the descriptor. | ||
Parameters | ||
---------- | ||
t_tensor | ||
The type of the tensor. used in the type hint. | ||
fwd_method_name | ||
Name of the forward method. For dpmodels, it should be "call". | ||
For torch models, it should be "forward". | ||
""" | ||
|
||
class BD(ABC): | ||
"""Base descriptor provides the interfaces of descriptor.""" | ||
|
||
@abstractmethod | ||
def get_rcut(self) -> float: | ||
"""Returns the cut-off radius.""" | ||
pass | ||
|
||
@abstractmethod | ||
def get_sel(self) -> List[int]: | ||
"""Returns the number of selected neighboring atoms for each type.""" | ||
pass | ||
|
||
def get_nsel(self) -> int: | ||
"""Returns the total number of selected neighboring atoms in the cut-off radius.""" | ||
return sum(self.get_sel()) | ||
|
||
def get_nnei(self) -> int: | ||
"""Returns the total number of selected neighboring atoms in the cut-off radius.""" | ||
return self.get_nsel() | ||
|
||
@abstractmethod | ||
def get_ntypes(self) -> int: | ||
"""Returns the number of element types.""" | ||
pass | ||
|
||
@abstractmethod | ||
def get_dim_out(self) -> int: | ||
"""Returns the output descriptor dimension.""" | ||
pass | ||
|
||
@abstractmethod | ||
def get_dim_emb(self) -> int: | ||
"""Returns the embedding dimension of g2.""" | ||
pass | ||
|
||
@abstractmethod | ||
def distinguish_types(self) -> bool: | ||
"""Returns if the descriptor requires a neighbor list that distinguish different | ||
atomic types or not. | ||
""" | ||
pass | ||
|
||
@abstractmethod | ||
def compute_input_stats(self, merged): | ||
"""Update mean and stddev for descriptor elements.""" | ||
pass | ||
|
||
@abstractmethod | ||
def init_desc_stat(self, sumr, suma, sumn, sumr2, suma2): | ||
"""Initialize the model bias by the statistics.""" | ||
pass | ||
|
||
@abstractmethod | ||
def fwd( | ||
self, | ||
extended_coord, | ||
extended_atype, | ||
nlist, | ||
mapping: Optional[t_tensor] = None, | ||
): | ||
"""Calculate descriptor.""" | ||
pass | ||
|
||
@abstractmethod | ||
def serialize(self) -> dict: | ||
"""Serialize the obj to dict.""" | ||
pass | ||
|
||
@abstractclassmethod | ||
def deserialize(cls): | ||
"""Deserialize from a dict.""" | ||
pass | ||
|
||
setattr(BD, fwd_method_name, BD.fwd) | ||
delattr(BD, "fwd") | ||
|
||
return BD |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
# SPDX-License-Identifier: LGPL-3.0-or-later | ||
from .invar_fitting import ( | ||
InvarFitting, | ||
) | ||
from .make_base_fitting import ( | ||
make_base_fitting, | ||
) | ||
|
||
__all__ = [ | ||
"InvarFitting", | ||
"make_base_fitting", | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
# SPDX-License-Identifier: LGPL-3.0-or-later | ||
import numpy as np | ||
|
||
from .make_base_fitting import ( | ||
make_base_fitting, | ||
) | ||
|
||
BaseFitting = make_base_fitting(np.ndarray, fwd_method_name="call") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.