Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat: Add consistency test for ZBL between dp and pt #4292

Merged
merged 23 commits into from
Nov 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions deepmd/dpmodel/atomic_model/linear_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
)


@BaseAtomicModel.register("linear")
class LinearEnergyAtomicModel(BaseAtomicModel):
"""Linear model make linear combinations of several existing models.
Expand Down Expand Up @@ -324,6 +325,7 @@ def is_aparam_nall(self) -> bool:
return False


@BaseAtomicModel.register("zbl")
class DPZBLLinearEnergyAtomicModel(LinearEnergyAtomicModel):
"""Model linearly combine a list of AtomicModels.
Expand Down
66 changes: 66 additions & 0 deletions deepmd/dpmodel/model/dp_zbl_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from typing import (
Optional,
)

from deepmd.dpmodel.atomic_model.linear_atomic_model import (
DPZBLLinearEnergyAtomicModel,
)
from deepmd.dpmodel.model.base_model import (
BaseModel,
)
from deepmd.dpmodel.model.dp_model import (
DPModelCommon,
)
from deepmd.utils.data_system import (
DeepmdDataSystem,
)

from .make_model import (
make_model,
)

DPZBLModel_ = make_model(DPZBLLinearEnergyAtomicModel)


@BaseModel.register("zbl")
class DPZBLModel(DPZBLModel_):
model_type = "zbl"

def __init__(
self,
*args,
**kwargs,
):
super().__init__(*args, **kwargs)

@classmethod
def update_sel(
cls,
train_data: DeepmdDataSystem,
type_map: Optional[list[str]],
local_jdata: dict,
) -> tuple[dict, Optional[float]]:
"""Update the selection and perform neighbor statistics.
Parameters
----------
train_data : DeepmdDataSystem
data used to do neighbor statistics
type_map : list[str], optional
The name of each type of atoms
local_jdata : dict
The local data refer to the current class
Returns
-------
dict
The updated local data
float
The minimum distance between two atoms
"""
local_jdata_cpy = local_jdata.copy()
local_jdata_cpy["dpmodel"], min_nbor_dist = DPModelCommon.update_sel(

Check warning on line 63 in deepmd/dpmodel/model/dp_zbl_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/model/dp_zbl_model.py#L62-L63

Added lines #L62 - L63 were not covered by tests
train_data, type_map, local_jdata["dpmodel"]
)
return local_jdata_cpy, min_nbor_dist

Check warning on line 66 in deepmd/dpmodel/model/dp_zbl_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/model/dp_zbl_model.py#L66

Added line #L66 was not covered by tests
53 changes: 53 additions & 0 deletions deepmd/dpmodel/model/model.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,13 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from deepmd.dpmodel.atomic_model.dp_atomic_model import (
DPAtomicModel,
)
from deepmd.dpmodel.atomic_model.pairtab_atomic_model import (
PairTabAtomicModel,
)
from deepmd.dpmodel.descriptor.base_descriptor import (
BaseDescriptor,
)
from deepmd.dpmodel.descriptor.se_e2_a import (
DescrptSeA,
)
Expand All @@ -8,6 +17,9 @@
from deepmd.dpmodel.model.base_model import (
BaseModel,
)
from deepmd.dpmodel.model.dp_zbl_model import (
DPZBLModel,
)
from deepmd.dpmodel.model.ener_model import (
EnergyModel,
)
Expand Down Expand Up @@ -55,6 +67,45 @@
)


def get_zbl_model(data: dict) -> DPZBLModel:
data["descriptor"]["ntypes"] = len(data["type_map"])
descriptor = BaseDescriptor(**data["descriptor"])
fitting_type = data["fitting_net"].pop("type")
if fitting_type == "ener":
fitting = EnergyFittingNet(
ntypes=descriptor.get_ntypes(),
dim_descrpt=descriptor.get_dim_out(),
mixed_types=descriptor.mixed_types(),
**data["fitting_net"],
)
else:
raise ValueError(f"Unknown fitting type {fitting_type}")

Check warning on line 82 in deepmd/dpmodel/model/model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/model/model.py#L82

Added line #L82 was not covered by tests

dp_model = DPAtomicModel(descriptor, fitting, type_map=data["type_map"])
# pairtab
filepath = data["use_srtab"]
pt_model = PairTabAtomicModel(
filepath,
data["descriptor"]["rcut"],
data["descriptor"]["sel"],
type_map=data["type_map"],
)
anyangml marked this conversation as resolved.
Show resolved Hide resolved
anyangml marked this conversation as resolved.
Show resolved Hide resolved

rmin = data["sw_rmin"]
rmax = data["sw_rmax"]
anyangml marked this conversation as resolved.
Show resolved Hide resolved
atom_exclude_types = data.get("atom_exclude_types", [])
pair_exclude_types = data.get("pair_exclude_types", [])
return DPZBLModel(
dp_model,
pt_model,
rmin,
rmax,
type_map=data["type_map"],
atom_exclude_types=atom_exclude_types,
pair_exclude_types=pair_exclude_types,
)


def get_spin_model(data: dict) -> SpinModel:
"""Get a spin model from a dictionary.

Expand Down Expand Up @@ -100,6 +151,8 @@
if model_type == "standard":
if "spin" in data:
return get_spin_model(data)
elif "use_srtab" in data:
return get_zbl_model(data)
else:
return get_standard_model(data)
else:
Expand Down
2 changes: 1 addition & 1 deletion deepmd/pt/model/model/dp_linear_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@

@BaseModel.register("linear_ener")
class LinearEnergyModel(DPLinearModel_):
model_type = "ener"
model_type = "linear_ener"

def __init__(
self,
Expand Down
2 changes: 1 addition & 1 deletion deepmd/pt/model/model/dp_zbl_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@

@BaseModel.register("zbl")
class DPZBLModel(DPZBLModel_):
model_type = "ener"
model_type = "zbl"

def __init__(
self,
Expand Down
4 changes: 2 additions & 2 deletions source/tests/consistent/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@
class CommonTest(ABC):
data: ClassVar[dict]
"""Arguments data."""
addtional_data: ClassVar[dict] = {}
additional_data: ClassVar[dict] = {}
"""Additional data that will not be checked."""
tf_class: ClassVar[Optional[type]]
"""TensorFlow model class."""
Expand Down Expand Up @@ -128,7 +128,7 @@ def init_backend_cls(self, cls) -> Any:

def pass_data_to_cls(self, cls, data) -> Any:
"""Pass data to the class."""
return cls(**data, **self.addtional_data)
return cls(**data, **self.additional_data)

@abstractmethod
def build_tf(self, obj: Any, suffix: str) -> tuple[list, dict]:
Expand Down
2 changes: 1 addition & 1 deletion source/tests/consistent/fitting/test_dipole.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def setUp(self):
self.atype.sort()

@property
def addtional_data(self) -> dict:
def additional_data(self) -> dict:
(
resnet_dt,
precision,
Expand Down
2 changes: 1 addition & 1 deletion source/tests/consistent/fitting/test_dos.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def setUp(self):
).reshape(-1, 1)

@property
def addtional_data(self) -> dict:
def additional_data(self) -> dict:
(
resnet_dt,
precision,
Expand Down
2 changes: 1 addition & 1 deletion source/tests/consistent/fitting/test_ener.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ def setUp(self):
).reshape(-1, 1)

@property
def addtional_data(self) -> dict:
def additional_data(self) -> dict:
(
resnet_dt,
precision,
Expand Down
2 changes: 1 addition & 1 deletion source/tests/consistent/fitting/test_polar.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def setUp(self):
self.atype.sort()

@property
def addtional_data(self) -> dict:
def additional_data(self) -> dict:
(
resnet_dt,
precision,
Expand Down
2 changes: 1 addition & 1 deletion source/tests/consistent/fitting/test_property.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def setUp(self):
).reshape(-1, 1)

@property
def addtional_data(self) -> dict:
def additional_data(self) -> dict:
(
resnet_dt,
precision,
Expand Down
2 changes: 1 addition & 1 deletion source/tests/consistent/model/test_ener.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ def pass_data_to_cls(self, cls, data) -> Any:
return get_model_pt(data)
elif cls is EnergyModelJAX:
return get_model_jax(data)
return cls(**data, **self.addtional_data)
return cls(**data, **self.additional_data)

def setUp(self):
CommonTest.setUp(self)
Expand Down
Loading