Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: consistent DPA-1 model #4320

Merged
merged 6 commits into from
Nov 8, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion deepmd/dpmodel/descriptor/hybrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ class DescrptHybrid(BaseDescriptor, NativeOP):
def __init__(
self,
list: list[Union[BaseDescriptor, dict[str, Any]]],
type_map: Optional[list[str]] = None,
ntypes: Optional[int] = None, # to be compat with input
) -> None:
super().__init__()
# warning: list is conflict with built-in list
Expand All @@ -56,7 +58,9 @@ def __init__(
if isinstance(ii, BaseDescriptor):
formatted_descript_list.append(ii)
elif isinstance(ii, dict):
formatted_descript_list.append(BaseDescriptor(**ii))
formatted_descript_list.append(
BaseDescriptor(**ii, type_map=type_map, ntypes=ntypes)
)
else:
raise NotImplementedError
self.descrpt_list = formatted_descript_list
Expand Down
14 changes: 4 additions & 10 deletions deepmd/dpmodel/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,6 @@
from deepmd.dpmodel.descriptor.base_descriptor import (
BaseDescriptor,
)
from deepmd.dpmodel.descriptor.se_e2_a import (
DescrptSeA,
)
from deepmd.dpmodel.fitting.ener_fitting import (
EnergyFittingNet,
)
Expand Down Expand Up @@ -39,16 +36,13 @@ def get_standard_model(data: dict) -> EnergyModel:
data : dict
The data to construct the model.
"""
descriptor_type = data["descriptor"].pop("type")
data["descriptor"]["type_map"] = data["type_map"]
data["descriptor"]["ntypes"] = len(data["type_map"])
fitting_type = data["fitting_net"].pop("type")
data["fitting_net"]["type_map"] = data["type_map"]
if descriptor_type == "se_e2_a":
descriptor = DescrptSeA(
**data["descriptor"],
)
else:
raise ValueError(f"Unknown descriptor type {descriptor_type}")
descriptor = BaseDescriptor(
**data["descriptor"],
)
if fitting_type == "ener":
fitting = EnergyFittingNet(
ntypes=descriptor.get_ntypes(),
Expand Down
4 changes: 4 additions & 0 deletions deepmd/jax/descriptor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@
from deepmd.jax.descriptor.hybrid import (
DescrptHybrid,
)
from deepmd.jax.descriptor.se_atten_v2 import (
DescrptSeAttenV2,
)
from deepmd.jax.descriptor.se_e2_a import (
DescrptSeA,
)
Expand All @@ -27,6 +30,7 @@
"DescrptSeT",
"DescrptSeTTebd",
"DescrptDPA1",
"DescrptSeAttenV2",
"DescrptDPA2",
"DescrptHybrid",
]
1 change: 1 addition & 0 deletions deepmd/jax/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ def get_standard_model(data: dict):
data = deepcopy(data)
descriptor_type = data["descriptor"].pop("type")
data["descriptor"]["type_map"] = data["type_map"]
data["descriptor"]["ntypes"] = len(data["type_map"])
fitting_type = data["fitting_net"].pop("type")
data["fitting_net"]["type_map"] = data["type_map"]
descriptor = BaseDescriptor.get_class_by_type(descriptor_type)(
Expand Down
4 changes: 2 additions & 2 deletions deepmd/pt/model/task/fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -419,8 +419,8 @@ def _forward_common(

if nd != self.dim_descrpt:
raise ValueError(
"get an input descriptor of dim {nd},"
"which is not consistent with {self.dim_descrpt}."
f"get an input descriptor of dim {nd},"
f"which is not consistent with {self.dim_descrpt}."
)
# check fparam dim, concate to input descriptor
if self.numb_fparam > 0:
Expand Down
6 changes: 6 additions & 0 deletions deepmd/tf/descriptor/hybrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -461,6 +461,8 @@ def update_sel(
return local_jdata_cpy, min_nbor_dist

def serialize(self, suffix: str = "") -> dict:
if hasattr(self, "type_embedding"):
raise NotImplementedError("hybrid + type embedding is not supported")
return {
"@class": "Descriptor",
"type": "hybrid",
Expand All @@ -485,4 +487,8 @@ def deserialize(cls, data: dict, suffix: str = "") -> "DescrptHybrid":
for idx, ii in enumerate(data["list"])
],
)
# search for type embedding
for ii in obj.descrpt_list:
if hasattr(ii, "type_embedding"):
raise NotImplementedError("hybrid + type embedding is not supported")
njzjz marked this conversation as resolved.
Show resolved Hide resolved
return obj
117 changes: 89 additions & 28 deletions deepmd/tf/descriptor/se_atten.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,11 +219,11 @@ def __init__(
if scaling_factor != 1.0:
raise NotImplementedError("scaling_factor is not supported.")
if not normalize:
raise NotImplementedError("normalize is not supported.")
raise NotImplementedError("Disabling normalize is not supported.")
if temperature is not None:
raise NotImplementedError("temperature is not supported.")
if not concat_output_tebd:
raise NotImplementedError("concat_output_tebd is not supported.")
raise NotImplementedError("Disabling concat_output_tebd is not supported.")
njzjz marked this conversation as resolved.
Show resolved Hide resolved
if env_protection != 0.0:
raise NotImplementedError("env_protection != 0.0 is not supported.")
# to keep consistent with default value in this backends
Expand Down Expand Up @@ -1866,7 +1866,7 @@ def deserialize(cls, data: dict, suffix: str = ""):
if cls is not DescrptSeAtten:
raise NotImplementedError(f"Not implemented in class {cls.__name__}")
data = data.copy()
check_version_compatibility(data.pop("@version"), 1, 1)
check_version_compatibility(data.pop("@version"), 2, 1)
data.pop("@class")
data.pop("type")
embedding_net_variables = cls.deserialize_network(
Expand All @@ -1878,10 +1878,13 @@ def deserialize(cls, data: dict, suffix: str = ""):
data.pop("env_mat")
variables = data.pop("@variables")
tebd_input_mode = data["tebd_input_mode"]
if tebd_input_mode in ["strip"]:
raise ValueError(
"Deserialization is unsupported for `tebd_input_mode='strip'` in the native model."
)
type_embedding = TypeEmbedNet.deserialize(
data.pop("type_embedding"), suffix=suffix
)
if "use_tebd_bias" not in data:
# v1 compatibility
data["use_tebd_bias"] = True
type_embedding.use_tebd_bias = data.pop("use_tebd_bias")
descriptor = cls(**data)
descriptor.embedding_net_variables = embedding_net_variables
descriptor.attention_layer_variables = attention_layer_variables
Expand All @@ -1891,6 +1894,17 @@ def deserialize(cls, data: dict, suffix: str = ""):
descriptor.dstd = variables["dstd"].reshape(
descriptor.ntypes, descriptor.ndescrpt
)
descriptor.type_embedding = type_embedding
if tebd_input_mode in ["strip"]:
type_one_side = data["type_one_side"]
two_side_embeeding_net_variables = cls.deserialize_network_strip(
data.pop("embeddings_strip"),
suffix=suffix,
type_one_side=type_one_side,
)
descriptor.two_side_embeeding_net_variables = (
two_side_embeeding_net_variables
)
return descriptor

def serialize(self, suffix: str = "") -> dict:
Expand All @@ -1906,11 +1920,6 @@ def serialize(self, suffix: str = "") -> dict:
dict
The serialized data
"""
if self.stripped_type_embedding and type(self) is DescrptSeAtten:
# only DescrptDPA1Compat and DescrptSeAttenV2 can serialize when tebd_input_mode=='strip'
raise NotImplementedError(
"serialization is unsupported by the native model when tebd_input_mode=='strip'"
)
# todo support serialization when tebd_input_mode=='strip' and type_one_side is True
if self.stripped_type_embedding and self.type_one_side:
raise NotImplementedError(
Expand All @@ -1927,10 +1936,18 @@ def serialize(self, suffix: str = "") -> dict:
assert self.davg is not None
assert self.dstd is not None

tebd_dim = self.type_embedding.neuron[0]
if self.tebd_input_mode in ["concat"]:
if not self.type_one_side:
embd_input_dim = 1 + tebd_dim * 2
else:
embd_input_dim = 1 + tebd_dim
else:
embd_input_dim = 1
data = {
"@class": "Descriptor",
"type": "se_atten",
"@version": 1,
"type": "dpa1",
"@version": 2,
"rcut": self.rcut_r,
"rcut_smth": self.rcut_r_smth,
"sel": self.sel_a,
Expand All @@ -1952,9 +1969,7 @@ def serialize(self, suffix: str = "") -> dict:
"embeddings": self.serialize_network(
ntypes=self.ntypes,
ndim=0,
in_dim=1
if not hasattr(self, "embd_input_dim")
else self.embd_input_dim,
in_dim=embd_input_dim,
neuron=self.filter_neuron,
activation_function=self.activation_function_name,
resnet_dt=self.filter_resnet_dt,
Expand Down Expand Up @@ -1986,17 +2001,23 @@ def serialize(self, suffix: str = "") -> dict:
"type_one_side": self.type_one_side,
"spin": self.spin,
}
data["type_embedding"] = self.type_embedding.serialize(suffix=suffix)
data["use_tebd_bias"] = self.type_embedding.use_tebd_bias
data["tebd_dim"] = tebd_dim
if len(self.type_embedding.neuron) > 1:
raise NotImplementedError(
"Only support single layer type embedding network"
)
if self.tebd_input_mode in ["strip"]:
assert (
type(self) is not DescrptSeAtten
), "only DescrptDPA1Compat and DescrptSeAttenV2 can serialize when tebd_input_mode=='strip'"
# assert (
# type(self) is not DescrptSeAtten
# ), "only DescrptDPA1Compat and DescrptSeAttenV2 can serialize when tebd_input_mode=='strip'"
data.update(
{
"embeddings_strip": self.serialize_network_strip(
ntypes=self.ntypes,
ndim=0,
in_dim=2
* self.tebd_dim, # only DescrptDPA1Compat has this attribute
in_dim=2 * tebd_dim,
neuron=self.filter_neuron,
activation_function=self.activation_function_name,
resnet_dt=self.filter_resnet_dt,
Expand All @@ -2006,8 +2027,54 @@ def serialize(self, suffix: str = "") -> dict:
)
}
)
# default values
data.update(
{
"scaling_factor": 1.0,
"normalize": True,
"temperature": None,
"concat_output_tebd": True,
"use_econf_tebd": False,
}
)
data["attention_layers"] = self.update_attention_layers_serialize(
data["attention_layers"]
)
return data

def update_attention_layers_serialize(self, data: dict) -> dict:
    """Update the serialized data to be consistent with other backend references.

    Wraps the serialized attention-layer data with the metadata keys
    (``@class``/``@version`` and attention hyper-parameters) that other
    backends expect, and patches every per-layer dict in place.

    Parameters
    ----------
    data : dict
        The serialized attention-layers payload. Must contain an
        ``"attention_layers"`` list of per-layer dicts, each holding a
        nested ``"attention_layer"`` dict.

    Returns
    -------
    dict
        A new top-level dict; note its ``"attention_layers"`` entry still
        aliases the list inside ``data``, whose nested dicts are mutated
        in place below.
    """
    # Defaults first; new_dict.update(data) lets any key already present
    # in `data` override them.
    new_dict = {
        "@class": "NeighborGatedAttention",
        "@version": 1,
        "scaling_factor": 1.0,
        "normalize": True,
        "temperature": None,
    }
    new_dict.update(data)
    # Hyper-parameters injected into every layer so the serialized form
    # matches the reference schema of the other backends.
    # NOTE(review): scaling_factor/normalize/temperature are hard-coded to
    # the only values this TF implementation supports — presumably the
    # defaults of the reference backend; confirm against its schema.
    update_info = {
        "nnei": self.nnei_a,
        "embed_dim": self.filter_neuron[-1],
        "hidden_dim": self.att_n,
        "dotr": self.attn_dotr,
        "do_mask": self.attn_mask,
        "scaling_factor": 1.0,
        "normalize": True,
        "temperature": None,
        "precision": self.filter_precision.name,
    }
    for layer_idx in range(self.attn_layer):
        # Patch both the outer per-layer dict and the nested
        # "attention_layer" dict (these are the same objects held by
        # `data`, so the mutation is visible through either reference).
        new_dict["attention_layers"][layer_idx].update(update_info)
        new_dict["attention_layers"][layer_idx]["attention_layer"].update(
            update_info
        )
        # assumes the TF attention implementation is single-headed —
        # TODO confirm against the reference backend.
        new_dict["attention_layers"][layer_idx]["attention_layer"].update(
            {
                "num_heads": 1,
            }
        )
    return new_dict


class DescrptDPA1Compat(DescrptSeAtten):
r"""Consistent version of the model for testing with other backend references.
Expand Down Expand Up @@ -2433,17 +2500,11 @@ def serialize(self, suffix: str = "") -> dict:
{
"type": "dpa1",
"@version": 2,
"tebd_dim": self.tebd_dim,
"scaling_factor": self.scaling_factor,
"normalize": self.normalize,
"temperature": self.temperature,
"concat_output_tebd": self.concat_output_tebd,
"use_econf_tebd": self.use_econf_tebd,
"use_tebd_bias": self.use_tebd_bias,
"type_embedding": self.type_embedding.serialize(suffix),
}
)
data["attention_layers"] = self.update_attention_layers_serialize(
data["attention_layers"]
)
return data
13 changes: 12 additions & 1 deletion deepmd/tf/descriptor/se_atten_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@
Optional,
)

from deepmd.tf.utils.type_embed import (
TypeEmbedNet,
)
from deepmd.utils.version import (
check_version_compatibility,
)
Expand Down Expand Up @@ -130,7 +133,7 @@ def deserialize(cls, data: dict, suffix: str = ""):
if cls is not DescrptSeAttenV2:
raise NotImplementedError(f"Not implemented in class {cls.__name__}")
data = data.copy()
check_version_compatibility(data.pop("@version"), 1, 1)
check_version_compatibility(data.pop("@version"), 2, 1)
data.pop("@class")
data.pop("type")
embedding_net_variables = cls.deserialize_network(
Expand All @@ -147,6 +150,13 @@ def deserialize(cls, data: dict, suffix: str = ""):
suffix=suffix,
type_one_side=type_one_side,
)
type_embedding = TypeEmbedNet.deserialize(
data.pop("type_embedding"), suffix=suffix
)
njzjz marked this conversation as resolved.
Show resolved Hide resolved
if "use_tebd_bias" not in data:
# v1 compatibility
data["use_tebd_bias"] = True
type_embedding.use_tebd_bias = data.pop("use_tebd_bias")
descriptor = cls(**data)
descriptor.embedding_net_variables = embedding_net_variables
descriptor.attention_layer_variables = attention_layer_variables
Expand All @@ -157,6 +167,7 @@ def deserialize(cls, data: dict, suffix: str = ""):
descriptor.dstd = variables["dstd"].reshape(
descriptor.ntypes, descriptor.ndescrpt
)
descriptor.type_embedding = type_embedding
return descriptor

def serialize(self, suffix: str = "") -> dict:
Expand Down
6 changes: 5 additions & 1 deletion deepmd/tf/fit/ener.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,7 @@ def __init__(
len(self.layer_name) == len(self.n_neuron) + 1
), "length of layer_name should be that of n_neuron + 1"
self.mixed_types = mixed_types
self.tebd_dim = 0

def get_numb_fparam(self) -> int:
"""Get the number of frame parameters."""
Expand Down Expand Up @@ -754,6 +755,8 @@ def build(
outs = tf.reshape(outs, [-1])

tf.summary.histogram("fitting_net_output", outs)
# recover original dim_descrpt, which needs to be serialized
self.dim_descrpt = original_dim_descrpt
return tf.reshape(outs, [-1])

def init_variables(
Expand Down Expand Up @@ -908,7 +911,7 @@ def serialize(self, suffix: str = "") -> dict:
"@version": 2,
"var_name": "energy",
"ntypes": self.ntypes,
"dim_descrpt": self.dim_descrpt,
"dim_descrpt": self.dim_descrpt + self.tebd_dim,
"mixed_types": self.mixed_types,
"dim_out": 1,
"neuron": self.n_neuron,
Expand All @@ -930,6 +933,7 @@ def serialize(self, suffix: str = "") -> dict:
ndim=0 if self.mixed_types else 1,
in_dim=(
self.dim_descrpt
+ self.tebd_dim
+ self.numb_fparam
+ (0 if self.use_aparam_as_mask else self.numb_aparam)
),
Expand Down
Loading
Loading