Skip to content

Commit

Permalink
add three body for repinit
Browse files Browse the repository at this point in the history
  • Loading branch information
iProzd committed Sep 13, 2024
1 parent 498fc24 commit 5979086
Show file tree
Hide file tree
Showing 6 changed files with 362 additions and 9 deletions.
169 changes: 165 additions & 4 deletions deepmd/dpmodel/descriptor/dpa2.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,9 @@
DescrptBlockRepformers,
RepformerLayer,
)
from .se_t_tebd import (
DescrptBlockSeTTebd,
)


class RepinitArgs:
Expand All @@ -75,6 +78,11 @@ def __init__(
activation_function="tanh",
resnet_dt: bool = False,
type_one_side: bool = False,
use_three_body: bool = False,
three_body_neuron: List[int] = [2, 4, 8],
three_body_sel: int = 40,
three_body_rcut: float = 4.0,
three_body_rcut_smth: float = 0.5,
):
r"""The constructor for the RepinitArgs class which defines the parameters of the repinit block in DPA2 descriptor.
Expand Down Expand Up @@ -104,6 +112,19 @@ def __init__(
Whether to use a "Timestep" in the skip connection.
type_one_side : bool, optional
Whether to use one-side type embedding.
use_three_body : bool, optional
Whether to concatenate three-body representation in the output descriptor.
three_body_neuron : list, optional
Number of neurons in each hidden layers of the three-body embedding net.
When two layers are of the same size or one layer is twice as large as the previous layer,
a skip connection is built.
three_body_sel : int, optional
Maximally possible number of selected neighbors in the three-body representation.
three_body_rcut : float, optional
The cut-off radius in the three-body representation.
three_body_rcut_smth : float, optional
Where to start smoothing in the three-body representation.
For example the 1/r term is smoothed from three_body_rcut to three_body_rcut_smth.
"""
self.rcut = rcut
self.rcut_smth = rcut_smth
Expand All @@ -116,6 +137,11 @@ def __init__(
self.activation_function = activation_function
self.resnet_dt = resnet_dt
self.type_one_side = type_one_side
self.use_three_body = use_three_body
self.three_body_neuron = three_body_neuron
self.three_body_sel = three_body_sel
self.three_body_rcut = three_body_rcut
self.three_body_rcut_smth = three_body_rcut_smth

def __getitem__(self, key):
if hasattr(self, key):
Expand All @@ -136,6 +162,11 @@ def serialize(self) -> dict:
"activation_function": self.activation_function,
"resnet_dt": self.resnet_dt,
"type_one_side": self.type_one_side,
"use_three_body": self.use_three_body,
"three_body_neuron": self.three_body_neuron,
"three_body_sel": self.three_body_sel,
"three_body_rcut": self.three_body_rcut,
"three_body_rcut_smth": self.three_body_rcut_smth,
}

@classmethod
Expand Down Expand Up @@ -431,6 +462,27 @@ def init_subclass_params(sub_data, sub_class):
type_one_side=self.repinit_args.type_one_side,
seed=child_seed(seed, 0),
)
self.use_three_body = self.repinit_args.use_three_body
if self.repinit_args.use_three_body:
self.repinit_three_body = DescrptBlockSeTTebd(
self.repinit_args.three_body_rcut,
self.repinit_args.three_body_rcut_smth,
self.repinit_args.three_body_sel,
ntypes,
neuron=self.repinit_args.three_body_neuron,
tebd_dim=self.repinit_args.tebd_dim,
tebd_input_mode=self.repinit_args.tebd_input_mode,
set_davg_zero=self.repinit_args.set_davg_zero,
exclude_types=exclude_types,
env_protection=env_protection,
activation_function=self.repinit_args.activation_function,
precision=precision,
resnet_dt=self.repinit_args.resnet_dt,
smooth=smooth,
seed=child_seed(seed, 5),
)
else:
self.repinit_three_body = None
self.repformers = DescrptBlockRepformers(
self.repformer_args.rcut,
self.repformer_args.rcut_smth,
Expand Down Expand Up @@ -469,6 +521,37 @@ def init_subclass_params(sub_data, sub_class):
ln_eps=self.repformer_args.ln_eps,
seed=child_seed(seed, 1),
)
if not self.use_three_body:
self.rcut_list = [self.repformers.get_rcut(), self.repinit.get_rcut()]
self.nsel_list = [self.repformers.get_nsel(), self.repinit.get_nsel()]
else:
if (
self.repinit_three_body.get_rcut() >= self.repformers.get_rcut()
and self.repinit_three_body.get_nsel() >= self.repformers.get_nsel()
):
self.rcut_list = [
self.repformers.get_rcut(),
self.repinit_three_body.get_rcut(),
self.repinit.get_rcut(),
]
self.nsel_list = [
self.repformers.get_nsel(),
self.repinit_three_body.get_nsel(),
self.repinit.get_nsel(),
]
else:
self.rcut_list = [
self.repinit_three_body.get_rcut(),
self.repformers.get_rcut(),
self.repinit.get_rcut(),
]
self.nsel_list = [
self.repinit_three_body.get_nsel(),
self.repformers.get_nsel(),
self.repinit.get_nsel(),
]
self.rcut_list = sorted(self.rcut_list)
self.nsel_list = sorted(self.nsel_list)
self.use_econf_tebd = use_econf_tebd
self.use_tebd_bias = use_tebd_bias
self.type_map = type_map
Expand All @@ -491,11 +574,16 @@ def init_subclass_params(sub_data, sub_class):
self.trainable = trainable
self.add_tebd_to_repinit_out = add_tebd_to_repinit_out

if self.repinit.dim_out == self.repformers.dim_in:
self.repinit_out_dim = self.repinit.dim_out
if self.repinit_args.use_three_body:
assert self.repinit_three_body is not None
self.repinit_out_dim += self.repinit_three_body.dim_out

if self.repinit_out_dim == self.repformers.dim_in:
self.g1_shape_tranform = Identity()
else:
self.g1_shape_tranform = NativeLayer(
self.repinit.dim_out,
self.repinit_out_dim,
self.repformers.dim_in,
bias=False,
precision=precision,
Expand Down Expand Up @@ -603,6 +691,7 @@ def change_type_map(
self.ntypes = len(type_map)
repinit = self.repinit
repformers = self.repformers
repinit_three_body = self.repinit_three_body
if has_new_type:
# the avg and std of new types need to be updated
extend_descrpt_stat(
Expand All @@ -619,6 +708,14 @@ def change_type_map(
if model_with_new_type_stat is not None
else None,
)
if self.use_three_body:
extend_descrpt_stat(
repinit_three_body,
type_map,
des_with_stat=model_with_new_type_stat.repinit_three_body
if model_with_new_type_stat is not None
else None,
)
repinit.ntypes = self.ntypes
repformers.ntypes = self.ntypes
repinit.reinit_exclude(self.exclude_types)
Expand All @@ -627,6 +724,11 @@ def change_type_map(
repinit["dstd"] = repinit["dstd"][remap_index]
repformers["davg"] = repformers["davg"][remap_index]
repformers["dstd"] = repformers["dstd"][remap_index]
if self.use_three_body:
repinit_three_body.ntypes = self.ntypes
repinit_three_body.reinit_exclude(self.exclude_types)
repinit_three_body["davg"] = repinit_three_body["davg"][remap_index]
repinit_three_body["dstd"] = repinit_three_body["dstd"][remap_index]

@property
def dim_out(self):
Expand Down Expand Up @@ -695,14 +797,15 @@ def call(
The smooth switch function. shape: nf x nloc x nnei
"""
use_three_body = self.use_three_body
nframes, nloc, nnei = nlist.shape
nall = coord_ext.reshape(nframes, -1).shape[1] // 3
# nlists
nlist_dict = build_multiple_neighbor_list(
coord_ext,
nlist,
[self.repformers.get_rcut(), self.repinit.get_rcut()],
[self.repformers.get_nsel(), self.repinit.get_nsel()],
self.rcut_list,
self.nsel_list,
)
# repinit
g1_ext = self.type_embedding.call()[atype_ext]
Expand All @@ -716,6 +819,21 @@ def call(
g1_ext,
mapping,
)
if use_three_body:
assert self.repinit_three_body is not None
g1_three_body, __, __, __, __ = self.repinit_three_body(
nlist_dict[
get_multiple_nlist_key(
self.repinit_three_body.get_rcut(),
self.repinit_three_body.get_nsel(),
)
],
coord_ext,
atype_ext,
g1_ext,
mapping,
)
g1 = np.concatenate([g1, g1_three_body], axis=-1)
# linear to change shape
g1 = self.g1_shape_tranform(g1)
if self.add_tebd_to_repinit_out:
Expand Down Expand Up @@ -744,6 +862,7 @@ def call(
def serialize(self) -> dict:
repinit = self.repinit
repformers = self.repformers
repinit_three_body = self.repinit_three_body
data = {
"@class": "Descriptor",
"type": "dpa2",
Expand Down Expand Up @@ -797,6 +916,28 @@ def serialize(self) -> dict:
"repformers_variable": repformers_variable,
}
)
if self.use_three_body:
repinit_three_body_variable = {
"embeddings": repinit_three_body.embeddings.serialize(),
"env_mat": EnvMat(
repinit_three_body.rcut, repinit_three_body.rcut_smth
).serialize(),
"@variables": {
"davg": repinit_three_body["davg"],
"dstd": repinit_three_body["dstd"],
},
}
if repinit_three_body.tebd_input_mode in ["strip"]:
repinit_three_body_variable.update(
{
"embeddings_strip": repinit_three_body.embeddings_strip.serialize()
}
)
data.update(
{
"repinit_three_body_variable": repinit_three_body_variable,
}
)
return data

@classmethod
Expand All @@ -807,6 +948,11 @@ def deserialize(cls, data: dict) -> "DescrptDPA2":
data.pop("type")
repinit_variable = data.pop("repinit_variable").copy()
repformers_variable = data.pop("repformers_variable").copy()
repinit_three_body_variable = (
data.pop("repinit_three_body_variable").copy()
if "repinit_three_body_variable" in data
else None
)
type_embedding = data.pop("type_embedding")
g1_shape_tranform = data.pop("g1_shape_tranform")
tebd_transform = data.pop("tebd_transform", None)
Expand Down Expand Up @@ -838,6 +984,21 @@ def deserialize(cls, data: dict) -> "DescrptDPA2":
obj.repinit["davg"] = statistic_repinit["davg"]
obj.repinit["dstd"] = statistic_repinit["dstd"]

if data["repinit"].use_three_body:
# deserialize repinit_three_body
statistic_repinit_three_body = repinit_three_body_variable.pop("@variables")
env_mat = repinit_three_body_variable.pop("env_mat")
tebd_input_mode = data["repinit"].tebd_input_mode
obj.repinit_three_body.embeddings = NetworkCollection.deserialize(
repinit_three_body_variable.pop("embeddings")
)
if tebd_input_mode in ["strip"]:
obj.repinit_three_body.embeddings_strip = NetworkCollection.deserialize(
repinit_three_body_variable.pop("embeddings_strip")
)
obj.repinit_three_body["davg"] = statistic_repinit_three_body["davg"]
obj.repinit_three_body["dstd"] = statistic_repinit_three_body["dstd"]

# deserialize repformers
statistic_repformers = repformers_variable.pop("@variables")
env_mat = repformers_variable.pop("env_mat")
Expand Down
Loading

0 comments on commit 5979086

Please sign in to comment.