Skip to content

Commit

Permalink
Feat: Add consistency test for ZBL between dp and pt (#4292)
Browse files Browse the repository at this point in the history
<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

## Release Notes

- **New Features**
  - Introduced `DPZBLModel`, enhancing energy modeling capabilities.
- Added `get_zbl_model` function for creating `DPZBLModel` from input
data.
- New `DPZBLLinearEnergyAtomicModel` class allows for complex
interactions between atomic models.

- **Bug Fixes**
- Corrected typographical errors in multiple test classes to improve
code clarity and consistency in method names.
- Updated model type attributes for `DPZBLModel` and `LinearEnergyModel`
to reflect accurate classifications.

- **Tests**
- Added comprehensive unit tests for energy models to ensure
functionality across various backends.
- Enhanced existing test classes with corrected method names for
improved accuracy.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
anyangml and pre-commit-ci[bot] authored Nov 1, 2024
1 parent a468819 commit 5c32147
Show file tree
Hide file tree
Showing 14 changed files with 356 additions and 11 deletions.
2 changes: 2 additions & 0 deletions deepmd/dpmodel/atomic_model/linear_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
)


@BaseAtomicModel.register("linear")
class LinearEnergyAtomicModel(BaseAtomicModel):
"""Linear model make linear combinations of several existing models.
Expand Down Expand Up @@ -324,6 +325,7 @@ def is_aparam_nall(self) -> bool:
return False


@BaseAtomicModel.register("zbl")
class DPZBLLinearEnergyAtomicModel(LinearEnergyAtomicModel):
"""Model linearly combine a list of AtomicModels.
Expand Down
66 changes: 66 additions & 0 deletions deepmd/dpmodel/model/dp_zbl_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from typing import (
Optional,
)

from deepmd.dpmodel.atomic_model.linear_atomic_model import (
DPZBLLinearEnergyAtomicModel,
)
from deepmd.dpmodel.model.base_model import (
BaseModel,
)
from deepmd.dpmodel.model.dp_model import (
DPModelCommon,
)
from deepmd.utils.data_system import (
DeepmdDataSystem,
)

from .make_model import (
make_model,
)

DPZBLModel_ = make_model(DPZBLLinearEnergyAtomicModel)


@BaseModel.register("zbl")
class DPZBLModel(DPZBLModel_):
model_type = "zbl"

def __init__(
self,
*args,
**kwargs,
):
super().__init__(*args, **kwargs)

@classmethod
def update_sel(
cls,
train_data: DeepmdDataSystem,
type_map: Optional[list[str]],
local_jdata: dict,
) -> tuple[dict, Optional[float]]:
"""Update the selection and perform neighbor statistics.
Parameters
----------
train_data : DeepmdDataSystem
data used to do neighbor statistics
type_map : list[str], optional
The name of each type of atoms
local_jdata : dict
The local data refer to the current class
Returns
-------
dict
The updated local data
float
The minimum distance between two atoms
"""
local_jdata_cpy = local_jdata.copy()
local_jdata_cpy["dpmodel"], min_nbor_dist = DPModelCommon.update_sel(
train_data, type_map, local_jdata["dpmodel"]
)
return local_jdata_cpy, min_nbor_dist
53 changes: 53 additions & 0 deletions deepmd/dpmodel/model/model.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,13 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from deepmd.dpmodel.atomic_model.dp_atomic_model import (
DPAtomicModel,
)
from deepmd.dpmodel.atomic_model.pairtab_atomic_model import (
PairTabAtomicModel,
)
from deepmd.dpmodel.descriptor.base_descriptor import (
BaseDescriptor,
)
from deepmd.dpmodel.descriptor.se_e2_a import (
DescrptSeA,
)
Expand All @@ -8,6 +17,9 @@
from deepmd.dpmodel.model.base_model import (
BaseModel,
)
from deepmd.dpmodel.model.dp_zbl_model import (
DPZBLModel,
)
from deepmd.dpmodel.model.ener_model import (
EnergyModel,
)
Expand Down Expand Up @@ -55,6 +67,45 @@ def get_standard_model(data: dict) -> EnergyModel:
)


def get_zbl_model(data: dict) -> DPZBLModel:
data["descriptor"]["ntypes"] = len(data["type_map"])
descriptor = BaseDescriptor(**data["descriptor"])
fitting_type = data["fitting_net"].pop("type")
if fitting_type == "ener":
fitting = EnergyFittingNet(
ntypes=descriptor.get_ntypes(),
dim_descrpt=descriptor.get_dim_out(),
mixed_types=descriptor.mixed_types(),
**data["fitting_net"],
)
else:
raise ValueError(f"Unknown fitting type {fitting_type}")

dp_model = DPAtomicModel(descriptor, fitting, type_map=data["type_map"])
# pairtab
filepath = data["use_srtab"]
pt_model = PairTabAtomicModel(
filepath,
data["descriptor"]["rcut"],
data["descriptor"]["sel"],
type_map=data["type_map"],
)

rmin = data["sw_rmin"]
rmax = data["sw_rmax"]
atom_exclude_types = data.get("atom_exclude_types", [])
pair_exclude_types = data.get("pair_exclude_types", [])
return DPZBLModel(
dp_model,
pt_model,
rmin,
rmax,
type_map=data["type_map"],
atom_exclude_types=atom_exclude_types,
pair_exclude_types=pair_exclude_types,
)


def get_spin_model(data: dict) -> SpinModel:
"""Get a spin model from a dictionary.
Expand Down Expand Up @@ -100,6 +151,8 @@ def get_model(data: dict):
if model_type == "standard":
if "spin" in data:
return get_spin_model(data)
elif "use_srtab" in data:
return get_zbl_model(data)
else:
return get_standard_model(data)
else:
Expand Down
2 changes: 1 addition & 1 deletion deepmd/pt/model/model/dp_linear_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@

@BaseModel.register("linear_ener")
class LinearEnergyModel(DPLinearModel_):
model_type = "ener"
model_type = "linear_ener"

def __init__(
self,
Expand Down
2 changes: 1 addition & 1 deletion deepmd/pt/model/model/dp_zbl_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@

@BaseModel.register("zbl")
class DPZBLModel(DPZBLModel_):
model_type = "ener"
model_type = "zbl"

def __init__(
self,
Expand Down
4 changes: 2 additions & 2 deletions source/tests/consistent/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@
class CommonTest(ABC):
data: ClassVar[dict]
"""Arguments data."""
addtional_data: ClassVar[dict] = {}
additional_data: ClassVar[dict] = {}
"""Additional data that will not be checked."""
tf_class: ClassVar[Optional[type]]
"""TensorFlow model class."""
Expand Down Expand Up @@ -128,7 +128,7 @@ def init_backend_cls(self, cls) -> Any:

def pass_data_to_cls(self, cls, data) -> Any:
"""Pass data to the class."""
return cls(**data, **self.addtional_data)
return cls(**data, **self.additional_data)

@abstractmethod
def build_tf(self, obj: Any, suffix: str) -> tuple[list, dict]:
Expand Down
2 changes: 1 addition & 1 deletion source/tests/consistent/fitting/test_dipole.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def setUp(self):
self.atype.sort()

@property
def addtional_data(self) -> dict:
def additional_data(self) -> dict:
(
resnet_dt,
precision,
Expand Down
2 changes: 1 addition & 1 deletion source/tests/consistent/fitting/test_dos.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def setUp(self):
).reshape(-1, 1)

@property
def addtional_data(self) -> dict:
def additional_data(self) -> dict:
(
resnet_dt,
precision,
Expand Down
2 changes: 1 addition & 1 deletion source/tests/consistent/fitting/test_ener.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ def setUp(self):
).reshape(-1, 1)

@property
def addtional_data(self) -> dict:
def additional_data(self) -> dict:
(
resnet_dt,
precision,
Expand Down
2 changes: 1 addition & 1 deletion source/tests/consistent/fitting/test_polar.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def setUp(self):
self.atype.sort()

@property
def addtional_data(self) -> dict:
def additional_data(self) -> dict:
(
resnet_dt,
precision,
Expand Down
2 changes: 1 addition & 1 deletion source/tests/consistent/fitting/test_property.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ def setUp(self):
).reshape(-1, 1)

@property
def addtional_data(self) -> dict:
def additional_data(self) -> dict:
(
resnet_dt,
precision,
Expand Down
2 changes: 1 addition & 1 deletion source/tests/consistent/model/test_ener.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ def pass_data_to_cls(self, cls, data) -> Any:
return get_model_pt(data)
elif cls is EnergyModelJAX:
return get_model_jax(data)
return cls(**data, **self.addtional_data)
return cls(**data, **self.additional_data)

def setUp(self):
CommonTest.setUp(self)
Expand Down
Loading

0 comments on commit 5c32147

Please sign in to comment.