Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat: Add consistency test for ZBL between dp and pt #4292

Merged
merged 23 commits into from
Nov 1, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 65 additions & 0 deletions deepmd/dpmodel/model/dp_zbl_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from typing import (
Optional,
)

from deepmd.dpmodel.atomic_model.linear_atomic_model import (
DPZBLLinearEnergyAtomicModel,
)
from deepmd.dpmodel.model.base_model import (
BaseModel,
)
from deepmd.dpmodel.model.dp_model import (
DPModelCommon,
)
from deepmd.utils.data_system import (
DeepmdDataSystem,
)

from .make_model import (
make_model,
)

DPEnergyModel_ = make_model(DPZBLLinearEnergyAtomicModel)


@BaseModel.register("zbl")
class DPZBLModel(DPEnergyModel_):
def __init__(
self,
*args,
**kwargs,
):
DPEnergyModel_.__init__(self, *args, **kwargs)


@classmethod
def update_sel(
cls,
train_data: DeepmdDataSystem,
type_map: Optional[list[str]],
local_jdata: dict,
) -> tuple[dict, Optional[float]]:
"""Update the selection and perform neighbor statistics.

Parameters
----------
train_data : DeepmdDataSystem
data used to do neighbor statistics
type_map : list[str], optional
The name of each type of atoms
local_jdata : dict
The local data refer to the current class

Returns
-------
dict
The updated local data
float
The minimum distance between two atoms
"""
local_jdata_cpy = local_jdata.copy()
local_jdata_cpy["dpmodel"], min_nbor_dist = DPModelCommon.update_sel(
train_data, type_map, local_jdata["dpmodel"]
)
return local_jdata_cpy, min_nbor_dist
49 changes: 49 additions & 0 deletions deepmd/dpmodel/model/model.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,10 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from deepmd.dpmodel.atomic_model.dp_atomic_model import (
DPAtomicModel,
)
from deepmd.dpmodel.atomic_model.pairtab_atomic_model import (
PairTabAtomicModel,
)
from deepmd.dpmodel.descriptor.se_e2_a import (
DescrptSeA,
)
Expand All @@ -8,6 +14,9 @@
from deepmd.dpmodel.model.base_model import (
BaseModel,
)
from deepmd.dpmodel.model.dp_zbl_model import (
DPZBLModel,
)
from deepmd.dpmodel.model.ener_model import (
EnergyModel,
)
Expand Down Expand Up @@ -55,6 +64,44 @@ def get_standard_model(data: dict) -> EnergyModel:
)


def get_zbl_model(data: dict):
descriptor = DescrptSeA(**data["descriptor"])
fitting_type = data["fitting_net"].pop("type")
if fitting_type == "ener":
fitting = EnergyFittingNet(
ntypes=descriptor.get_ntypes(),
dim_descrpt=descriptor.get_dim_out(),
mixed_types=descriptor.mixed_types(),
**data["fitting_net"],
)
else:
raise ValueError(f"Unknown fitting type {fitting_type}")

dp_model = DPAtomicModel(descriptor, fitting, type_map=data["type_map"])
# pairtab
filepath = data["use_srtab"]
pt_model = PairTabAtomicModel(
filepath,
data["descriptor"]["rcut"],
data["descriptor"]["sel"],
type_map=data["type_map"],
)
anyangml marked this conversation as resolved.
Show resolved Hide resolved
anyangml marked this conversation as resolved.
Show resolved Hide resolved

rmin = data["sw_rmin"]
rmax = data["sw_rmax"]
anyangml marked this conversation as resolved.
Show resolved Hide resolved
atom_exclude_types = data.get("atom_exclude_types", [])
pair_exclude_types = data.get("pair_exclude_types", [])
return DPZBLModel(
dp_model,
pt_model,
rmin,
rmax,
type_map=data["type_map"],
atom_exclude_types=atom_exclude_types,
pair_exclude_types=pair_exclude_types,
)
anyangml marked this conversation as resolved.
Show resolved Hide resolved


def get_spin_model(data: dict) -> SpinModel:
"""Get a spin model from a dictionary.

Expand Down Expand Up @@ -100,6 +147,8 @@ def get_model(data: dict):
if model_type == "standard":
if "spin" in data:
return get_spin_model(data)
elif "use_srtab" in data:
return get_zbl_model(data)
else:
return get_standard_model(data)
else:
Expand Down
236 changes: 236 additions & 0 deletions source/tests/consistent/model/test_zbl_ener.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,236 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import unittest
from typing import (
Any,
)

import numpy as np

from deepmd.dpmodel.model.ener_model import EnergyModel as EnergyModelDP
from deepmd.dpmodel.model.model import get_model as get_model_dp
from deepmd.env import (
GLOBAL_NP_FLOAT_PRECISION,
)

from ..common import (
INSTALLED_JAX,
INSTALLED_PT,
INSTALLED_TF,
SKIP_FLAG,
CommonTest,
parameterized,
)
from .common import (
ModelTest,
)

if INSTALLED_PT:
from deepmd.pt.model.model import get_model as get_model_pt
from deepmd.pt.model.model.ener_model import EnergyModel as EnergyModelPT

else:
EnergyModelPT = None
if INSTALLED_TF:
from deepmd.tf.model.ener import EnerModel as EnergyModelTF
else:
EnergyModelTF = None
from deepmd.utils.argcheck import (
model_args,
)

if INSTALLED_JAX:
from deepmd.jax.model.ener_model import EnergyModel as EnergyModelJAX
from deepmd.jax.model.model import get_model as get_model_jax
else:
EnergyModelJAX = None


@parameterized(
(
[],
[[0, 1]],
),
(
[],
[1],
),
)
class TestEner(CommonTest, ModelTest, unittest.TestCase):
@property
def data(self) -> dict:
pair_exclude_types, atom_exclude_types = self.param
return {
"type_map": ["O", "H"],
"pair_exclude_types": pair_exclude_types,
"atom_exclude_types": atom_exclude_types,
"descriptor": {
"type": "se_e2_a",
"sel": [20, 20],
"rcut_smth": 0.50,
"rcut": 6.00,
"neuron": [
3,
6,
],
"resnet_dt": False,
"axis_neuron": 2,
"precision": "float64",
"type_one_side": True,
"seed": 1,
},
"fitting_net": {
"neuron": [
5,
5,
],
"resnet_dt": True,
"precision": "float64",
"seed": 1,
},
}

tf_class = EnergyModelTF
dp_class = EnergyModelDP
pt_class = EnergyModelPT
jax_class = EnergyModelJAX
args = model_args()

def get_reference_backend(self):
"""Get the reference backend.

We need a reference backend that can reproduce forces.
"""
if not self.skip_pt:
return self.RefBackend.PT
if not self.skip_tf:
return self.RefBackend.TF
if not self.skip_jax:
return self.RefBackend.JAX
if not self.skip_dp:
return self.RefBackend.DP
raise ValueError("No available reference")

@property
def skip_tf(self):
return (
self.data["pair_exclude_types"] != []
or self.data["atom_exclude_types"] != []
)

@property
def skip_jax(self):
return not INSTALLED_JAX

def pass_data_to_cls(self, cls, data) -> Any:
"""Pass data to the class."""
data = data.copy()
if cls is EnergyModelDP:
return get_model_dp(data)
elif cls is EnergyModelPT:
return get_model_pt(data)
elif cls is EnergyModelJAX:
return get_model_jax(data)
return cls(**data, **self.addtional_data)
anyangml marked this conversation as resolved.
Show resolved Hide resolved

def setUp(self):
CommonTest.setUp(self)

self.ntypes = 2
self.coords = np.array(
[
12.83,
2.56,
2.18,
12.09,
2.87,
2.74,
00.25,
3.32,
1.68,
3.36,
3.00,
1.81,
3.51,
2.51,
2.60,
4.27,
3.22,
1.56,
],
dtype=GLOBAL_NP_FLOAT_PRECISION,
).reshape(1, -1, 3)
self.atype = np.array([0, 1, 1, 0, 1, 1], dtype=np.int32).reshape(1, -1)
self.box = np.array(
[13.0, 0.0, 0.0, 0.0, 13.0, 0.0, 0.0, 0.0, 13.0],
dtype=GLOBAL_NP_FLOAT_PRECISION,
).reshape(1, 9)
self.natoms = np.array([6, 6, 2, 4], dtype=np.int32)

# TF requires the atype to be sort
idx_map = np.argsort(self.atype.ravel())
self.atype = self.atype[:, idx_map]
self.coords = self.coords[:, idx_map]

def build_tf(self, obj: Any, suffix: str) -> tuple[list, dict]:
return self.build_tf_model(
obj,
self.natoms,
self.coords,
self.atype,
self.box,
suffix,
)

def eval_dp(self, dp_obj: Any) -> Any:
return self.eval_dp_model(
dp_obj,
self.natoms,
self.coords,
self.atype,
self.box,
)

def eval_pt(self, pt_obj: Any) -> Any:
return self.eval_pt_model(
pt_obj,
self.natoms,
self.coords,
self.atype,
self.box,
)

def eval_jax(self, jax_obj: Any) -> Any:
return self.eval_jax_model(
jax_obj,
self.natoms,
self.coords,
self.atype,
self.box,
)

def extract_ret(self, ret: Any, backend) -> tuple[np.ndarray, ...]:
# shape not matched. ravel...
if backend is self.RefBackend.DP:
return (
ret["energy_redu"].ravel(),
ret["energy"].ravel(),
SKIP_FLAG,
SKIP_FLAG,
)
elif backend is self.RefBackend.PT:
return (
ret["energy"].ravel(),
ret["atom_energy"].ravel(),
ret["force"].ravel(),
ret["virial"].ravel(),
)
elif backend is self.RefBackend.TF:
return (ret[0].ravel(), ret[1].ravel(), ret[2].ravel(), ret[3].ravel())
elif backend is self.RefBackend.JAX:
return (
ret["energy_redu"].ravel(),
ret["energy"].ravel(),
ret["energy_derv_r"].ravel(),
ret["energy_derv_c_redu"].ravel(),
)
raise ValueError(f"Unknown backend: {backend}")
anyangml marked this conversation as resolved.
Show resolved Hide resolved
Loading