Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/devel' into avoid-deepcopy
Browse files Browse the repository at this point in the history
  • Loading branch information
njzjz committed Nov 8, 2024
2 parents ba498f9 + 6c66be9 commit 5d885c7
Show file tree
Hide file tree
Showing 16 changed files with 413 additions and 63 deletions.
6 changes: 6 additions & 0 deletions deepmd/dpmodel/descriptor/hybrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ class DescrptHybrid(BaseDescriptor, NativeOP):
def __init__(
self,
list: list[Union[BaseDescriptor, dict[str, Any]]],
type_map: Optional[list[str]] = None,
ntypes: Optional[int] = None, # to be compat with input
) -> None:
super().__init__()
# warning: list is conflict with built-in list
Expand All @@ -56,6 +58,10 @@ def __init__(
if isinstance(ii, BaseDescriptor):
formatted_descript_list.append(ii)
elif isinstance(ii, dict):
ii = ii.copy()
# only pass if not already set
ii.setdefault("type_map", type_map)
ii.setdefault("ntypes", ntypes)
formatted_descript_list.append(BaseDescriptor(**ii))
else:
raise NotImplementedError
Expand Down
19 changes: 11 additions & 8 deletions deepmd/dpmodel/descriptor/repformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -1792,7 +1792,7 @@ def serialize(self) -> dict:
"""
data = {
"@class": "RepformerLayer",
"@version": 1,
"@version": 2,
"rcut": self.rcut,
"rcut_smth": self.rcut_smth,
"sel": self.sel,
Expand Down Expand Up @@ -1877,9 +1877,11 @@ def serialize(self) -> dict:
if self.update_style == "res_residual":
data.update(
{
"g1_residual": [to_numpy_array(aa) for aa in self.g1_residual],
"g2_residual": [to_numpy_array(aa) for aa in self.g2_residual],
"h2_residual": [to_numpy_array(aa) for aa in self.h2_residual],
"@variables": {
"g1_residual": [to_numpy_array(aa) for aa in self.g1_residual],
"g2_residual": [to_numpy_array(aa) for aa in self.g2_residual],
"h2_residual": [to_numpy_array(aa) for aa in self.h2_residual],
}
}
)
return data
Expand All @@ -1894,7 +1896,7 @@ def deserialize(cls, data: dict) -> "RepformerLayer":
The dict to deserialize from.
"""
data = data.copy()
check_version_compatibility(data.pop("@version"), 1, 1)
check_version_compatibility(data.pop("@version"), 2, 1)
data.pop("@class")
linear1 = data.pop("linear1")
update_chnnl_2 = data["update_chnnl_2"]
Expand All @@ -1915,9 +1917,10 @@ def deserialize(cls, data: dict) -> "RepformerLayer":
attn2_ev_apply = data.pop("attn2_ev_apply", None)
loc_attn = data.pop("loc_attn", None)
g1_self_mlp = data.pop("g1_self_mlp", None)
g1_residual = data.pop("g1_residual", [])
g2_residual = data.pop("g2_residual", [])
h2_residual = data.pop("h2_residual", [])
variables = data.pop("@variables", {})
g1_residual = variables.get("g1_residual", data.pop("g1_residual", []))
g2_residual = variables.get("g2_residual", data.pop("g2_residual", []))
h2_residual = variables.get("h2_residual", data.pop("h2_residual", []))

obj = cls(**data)
obj.linear1 = NativeLayer.deserialize(linear1)
Expand Down
14 changes: 4 additions & 10 deletions deepmd/dpmodel/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,6 @@
from deepmd.dpmodel.descriptor.base_descriptor import (
BaseDescriptor,
)
from deepmd.dpmodel.descriptor.se_e2_a import (
DescrptSeA,
)
from deepmd.dpmodel.fitting.ener_fitting import (
EnergyFittingNet,
)
Expand Down Expand Up @@ -39,16 +36,13 @@ def get_standard_model(data: dict) -> EnergyModel:
data : dict
The data to construct the model.
"""
descriptor_type = data["descriptor"].pop("type")
data["descriptor"]["type_map"] = data["type_map"]
data["descriptor"]["ntypes"] = len(data["type_map"])
fitting_type = data["fitting_net"].pop("type")
data["fitting_net"]["type_map"] = data["type_map"]
if descriptor_type == "se_e2_a":
descriptor = DescrptSeA(
**data["descriptor"],
)
else:
raise ValueError(f"Unknown descriptor type {descriptor_type}")
descriptor = BaseDescriptor(
**data["descriptor"],
)
if fitting_type == "ener":
fitting = EnergyFittingNet(
ntypes=descriptor.get_ntypes(),
Expand Down
4 changes: 4 additions & 0 deletions deepmd/jax/descriptor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@
from deepmd.jax.descriptor.hybrid import (
DescrptHybrid,
)
from deepmd.jax.descriptor.se_atten_v2 import (
DescrptSeAttenV2,
)
from deepmd.jax.descriptor.se_e2_a import (
DescrptSeA,
)
Expand All @@ -27,6 +30,7 @@
"DescrptSeT",
"DescrptSeTTebd",
"DescrptDPA1",
"DescrptSeAttenV2",
"DescrptDPA2",
"DescrptHybrid",
]
1 change: 1 addition & 0 deletions deepmd/jax/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ def get_standard_model(data: dict):
data = deepcopy(data)
descriptor_type = data["descriptor"].pop("type")
data["descriptor"]["type_map"] = data["type_map"]
data["descriptor"]["ntypes"] = len(data["type_map"])
fitting_type = data["fitting_net"].pop("type")
data["fitting_net"]["type_map"] = data["type_map"]
descriptor = BaseDescriptor.get_class_by_type(descriptor_type)(
Expand Down
19 changes: 11 additions & 8 deletions deepmd/pt/model/descriptor/repformer_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1295,7 +1295,7 @@ def serialize(self) -> dict:
"""
data = {
"@class": "RepformerLayer",
"@version": 1,
"@version": 2,
"rcut": self.rcut,
"rcut_smth": self.rcut_smth,
"sel": self.sel,
Expand Down Expand Up @@ -1380,9 +1380,11 @@ def serialize(self) -> dict:
if self.update_style == "res_residual":
data.update(
{
"g1_residual": [to_numpy_array(t) for t in self.g1_residual],
"g2_residual": [to_numpy_array(t) for t in self.g2_residual],
"h2_residual": [to_numpy_array(t) for t in self.h2_residual],
"@variables": {
"g1_residual": [to_numpy_array(t) for t in self.g1_residual],
"g2_residual": [to_numpy_array(t) for t in self.g2_residual],
"h2_residual": [to_numpy_array(t) for t in self.h2_residual],
}
}
)
return data
Expand All @@ -1397,7 +1399,7 @@ def deserialize(cls, data: dict) -> "RepformerLayer":
The dict to deserialize from.
"""
data = data.copy()
check_version_compatibility(data.pop("@version"), 1, 1)
check_version_compatibility(data.pop("@version"), 2, 1)
data.pop("@class")
linear1 = data.pop("linear1")
update_chnnl_2 = data["update_chnnl_2"]
Expand All @@ -1418,9 +1420,10 @@ def deserialize(cls, data: dict) -> "RepformerLayer":
attn2_ev_apply = data.pop("attn2_ev_apply", None)
loc_attn = data.pop("loc_attn", None)
g1_self_mlp = data.pop("g1_self_mlp", None)
g1_residual = data.pop("g1_residual", [])
g2_residual = data.pop("g2_residual", [])
h2_residual = data.pop("h2_residual", [])
variables = data.pop("@variables", {})
g1_residual = variables.get("g1_residual", data.pop("g1_residual", []))
g2_residual = variables.get("g2_residual", data.pop("g2_residual", []))
h2_residual = variables.get("h2_residual", data.pop("h2_residual", []))

obj = cls(**data)
obj.linear1 = MLPLayer.deserialize(linear1)
Expand Down
4 changes: 2 additions & 2 deletions deepmd/pt/model/task/fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,8 +418,8 @@ def _forward_common(

if nd != self.dim_descrpt:
raise ValueError(
"get an input descriptor of dim {nd},"
"which is not consistent with {self.dim_descrpt}."
f"get an input descriptor of dim {nd},"
f"which is not consistent with {self.dim_descrpt}."
)
# check fparam dim, concate to input descriptor
if self.numb_fparam > 0:
Expand Down
6 changes: 6 additions & 0 deletions deepmd/tf/descriptor/hybrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -461,6 +461,8 @@ def update_sel(
return local_jdata_cpy, min_nbor_dist

def serialize(self, suffix: str = "") -> dict:
if hasattr(self, "type_embedding"):
raise NotImplementedError("hybrid + type embedding is not supported")
return {
"@class": "Descriptor",
"type": "hybrid",
Expand All @@ -485,4 +487,8 @@ def deserialize(cls, data: dict, suffix: str = "") -> "DescrptHybrid":
for idx, ii in enumerate(data["list"])
],
)
# search for type embedding
for ii in obj.descrpt_list:
if hasattr(ii, "type_embedding"):
raise NotImplementedError("hybrid + type embedding is not supported")
return obj
Loading

0 comments on commit 5d885c7

Please sign in to comment.