Skip to content

Commit

Permalink
Merge branch 'devel' into njzjz-patch-36
Browse files Browse the repository at this point in the history
  • Loading branch information
njzjz authored Jun 7, 2024
2 parents 3aab6f9 + 674bad7 commit 325d32c
Show file tree
Hide file tree
Showing 108 changed files with 1,751 additions and 261 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ repos:
exclude: ^source/3rdparty
- repo: https://github.com/astral-sh/ruff-pre-commit
# Ruff version.
rev: v0.4.5
rev: v0.4.7
hooks:
- id: ruff
args: ["--fix"]
Expand Down
4 changes: 4 additions & 0 deletions deepmd/dpmodel/atomic_model/dp_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,10 @@ def mixed_types(self) -> bool:
"""
return self.descriptor.mixed_types()

def has_message_passing(self) -> bool:
"""Returns whether the atomic model has message passing."""
return self.descriptor.has_message_passing()

def forward_atomic(
self,
extended_coord: np.ndarray,
Expand Down
4 changes: 4 additions & 0 deletions deepmd/dpmodel/atomic_model/linear_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,10 @@ def mixed_types(self) -> bool:
"""
return True

def has_message_passing(self) -> bool:
"""Returns whether the atomic model has message passing."""
return any(model.has_message_passing() for model in self.models)

def get_rcut(self) -> float:
"""Get the cut-off radius."""
return max(self.get_model_rcuts())
Expand Down
4 changes: 4 additions & 0 deletions deepmd/dpmodel/atomic_model/make_base_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,10 @@ def mixed_types(self) -> bool:
"""
pass

@abstractmethod
def has_message_passing(self) -> bool:
"""Returns whether the descriptor has message passing."""

@abstractmethod
def fwd(
self,
Expand Down
4 changes: 4 additions & 0 deletions deepmd/dpmodel/atomic_model/pairtab_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,10 @@ def mixed_types(self) -> bool:
# to match DPA1 and DPA2.
return True

def has_message_passing(self) -> bool:
"""Returns whether the atomic model has message passing."""
return False

def serialize(self) -> dict:
dd = BaseAtomicModel.serialize(self)
dd.update(
Expand Down
4 changes: 4 additions & 0 deletions deepmd/dpmodel/descriptor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@
from .make_base_descriptor import (
make_base_descriptor,
)
from .se_atten_v2 import (
DescrptSeAttenV2,
)
from .se_e2_a import (
DescrptSeA,
)
Expand All @@ -26,6 +29,7 @@
"DescrptSeR",
"DescrptSeT",
"DescrptDPA1",
"DescrptSeAttenV2",
"DescrptDPA2",
"DescrptHybrid",
"make_base_descriptor",
Expand Down
4 changes: 4 additions & 0 deletions deepmd/dpmodel/descriptor/descriptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,3 +125,7 @@ def call(
):
"""Calculate DescriptorBlock."""
pass

@abstractmethod
def has_message_passing(self) -> bool:
"""Returns whether the descriptor block has message passing."""
8 changes: 8 additions & 0 deletions deepmd/dpmodel/descriptor/dpa1.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,6 +349,10 @@ def mixed_types(self) -> bool:
"""
return self.se_atten.mixed_types()

def has_message_passing(self) -> bool:
"""Returns whether the descriptor has message passing."""
return self.se_atten.has_message_passing()

def get_env_protection(self) -> float:
"""Returns the protection of building environment matrix."""
return self.se_atten.get_env_protection()
Expand Down Expand Up @@ -886,6 +890,10 @@ def call(
sw,
)

def has_message_passing(self) -> bool:
"""Returns whether the descriptor block has message passing."""
return False


class NeighborGatedAttention(NativeOP):
def __init__(
Expand Down
6 changes: 6 additions & 0 deletions deepmd/dpmodel/descriptor/dpa2.py
Original file line number Diff line number Diff line change
Expand Up @@ -524,6 +524,12 @@ def mixed_types(self) -> bool:
"""
return True

def has_message_passing(self) -> bool:
"""Returns whether the descriptor has message passing."""
return any(
[self.repinit.has_message_passing(), self.repformers.has_message_passing()]
)

def get_env_protection(self) -> float:
"""Returns the protection of building environment matrix."""
return self.env_protection
Expand Down
4 changes: 4 additions & 0 deletions deepmd/dpmodel/descriptor/hybrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,10 @@ def mixed_types(self):
"""
return any(descrpt.mixed_types() for descrpt in self.descrpt_list)

def has_message_passing(self) -> bool:
"""Returns whether the descriptor has message passing."""
return any(descrpt.has_message_passing() for descrpt in self.descrpt_list)

def get_env_protection(self) -> float:
"""Returns the protection of building environment matrix. All descriptors should be the same."""
all_protection = [descrpt.get_env_protection() for descrpt in self.descrpt_list]
Expand Down
4 changes: 4 additions & 0 deletions deepmd/dpmodel/descriptor/make_base_descriptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,10 @@ def mixed_types(self) -> bool:
"""
pass

@abstractmethod
def has_message_passing(self) -> bool:
"""Returns whether the descriptor has message passing."""

@abstractmethod
def get_env_protection(self) -> float:
"""Returns the protection of building environment matrix."""
Expand Down
4 changes: 4 additions & 0 deletions deepmd/dpmodel/descriptor/repformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -387,6 +387,10 @@ def call(
rot_mat = np.transpose(h2g2, (0, 1, 3, 2))
return g1, g2, h2, rot_mat.reshape(-1, nloc, self.dim_emb, 3), sw

def has_message_passing(self) -> bool:
"""Returns whether the descriptor block has message passing."""
return True


# translated by GPT and modified
def get_residual(
Expand Down
180 changes: 180 additions & 0 deletions deepmd/dpmodel/descriptor/se_atten_v2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,180 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from typing import (
Any,
List,
Optional,
Tuple,
Union,
)

import numpy as np

from deepmd.dpmodel import (
DEFAULT_PRECISION,
PRECISION_DICT,
)
from deepmd.dpmodel.utils import (
NetworkCollection,
)
from deepmd.dpmodel.utils.type_embed import (
TypeEmbedNet,
)
from deepmd.utils.version import (
check_version_compatibility,
)

from .base_descriptor import (
BaseDescriptor,
)
from .dpa1 import (
DescrptDPA1,
NeighborGatedAttention,
)


@BaseDescriptor.register("se_atten_v2")
class DescrptSeAttenV2(DescrptDPA1):
def __init__(
self,
rcut: float,
rcut_smth: float,
sel: Union[List[int], int],
ntypes: int,
neuron: List[int] = [25, 50, 100],
axis_neuron: int = 8,
tebd_dim: int = 8,
resnet_dt: bool = False,
trainable: bool = True,
type_one_side: bool = False,
attn: int = 128,
attn_layer: int = 2,
attn_dotr: bool = True,
attn_mask: bool = False,
exclude_types: List[Tuple[int, int]] = [],
env_protection: float = 0.0,
set_davg_zero: bool = False,
activation_function: str = "tanh",
precision: str = DEFAULT_PRECISION,
scaling_factor=1.0,
normalize: bool = True,
temperature: Optional[float] = None,
trainable_ln: bool = True,
ln_eps: Optional[float] = 1e-5,
concat_output_tebd: bool = True,
spin: Optional[Any] = None,
stripped_type_embedding: Optional[bool] = None,
use_econf_tebd: bool = False,
type_map: Optional[List[str]] = None,
# consistent with argcheck, not used though
seed: Optional[int] = None,
) -> None:
DescrptDPA1.__init__(
self,
rcut,
rcut_smth,
sel,
ntypes,
neuron=neuron,
axis_neuron=axis_neuron,
tebd_dim=tebd_dim,
tebd_input_mode="strip",
resnet_dt=resnet_dt,
trainable=trainable,
type_one_side=type_one_side,
attn=attn,
attn_layer=attn_layer,
attn_dotr=attn_dotr,
attn_mask=attn_mask,
exclude_types=exclude_types,
env_protection=env_protection,
set_davg_zero=set_davg_zero,
activation_function=activation_function,
precision=precision,
scaling_factor=scaling_factor,
normalize=normalize,
temperature=temperature,
trainable_ln=trainable_ln,
ln_eps=ln_eps,
smooth_type_embedding=True,
concat_output_tebd=concat_output_tebd,
spin=spin,
stripped_type_embedding=stripped_type_embedding,
use_econf_tebd=use_econf_tebd,
type_map=type_map,
# consistent with argcheck, not used though
seed=seed,
)

def serialize(self) -> dict:
"""Serialize the descriptor to dict."""
obj = self.se_atten
data = {
"@class": "Descriptor",
"type": "se_atten_v2",
"@version": 1,
"rcut": obj.rcut,
"rcut_smth": obj.rcut_smth,
"sel": obj.sel,
"ntypes": obj.ntypes,
"neuron": obj.neuron,
"axis_neuron": obj.axis_neuron,
"tebd_dim": obj.tebd_dim,
"set_davg_zero": obj.set_davg_zero,
"attn": obj.attn,
"attn_layer": obj.attn_layer,
"attn_dotr": obj.attn_dotr,
"attn_mask": False,
"activation_function": obj.activation_function,
"resnet_dt": obj.resnet_dt,
"scaling_factor": obj.scaling_factor,
"normalize": obj.normalize,
"temperature": obj.temperature,
"trainable_ln": obj.trainable_ln,
"ln_eps": obj.ln_eps,
"type_one_side": obj.type_one_side,
"concat_output_tebd": self.concat_output_tebd,
"use_econf_tebd": self.use_econf_tebd,
"type_map": self.type_map,
# make deterministic
"precision": np.dtype(PRECISION_DICT[obj.precision]).name,
"embeddings": obj.embeddings.serialize(),
"embeddings_strip": obj.embeddings_strip.serialize(),
"attention_layers": obj.dpa1_attention.serialize(),
"env_mat": obj.env_mat.serialize(),
"type_embedding": self.type_embedding.serialize(),
"exclude_types": obj.exclude_types,
"env_protection": obj.env_protection,
"@variables": {
"davg": obj["davg"],
"dstd": obj["dstd"],
},
## to be updated when the options are supported.
"trainable": self.trainable,
"spin": None,
}
return data

@classmethod
def deserialize(cls, data: dict) -> "DescrptSeAttenV2":
"""Deserialize from dict."""
data = data.copy()
check_version_compatibility(data.pop("@version"), 1, 1)
data.pop("@class")
data.pop("type")
variables = data.pop("@variables")
embeddings = data.pop("embeddings")
type_embedding = data.pop("type_embedding")
attention_layers = data.pop("attention_layers")
data.pop("env_mat")
embeddings_strip = data.pop("embeddings_strip")
obj = cls(**data)

obj.se_atten["davg"] = variables["davg"]
obj.se_atten["dstd"] = variables["dstd"]
obj.se_atten.embeddings = NetworkCollection.deserialize(embeddings)
obj.se_atten.embeddings_strip = NetworkCollection.deserialize(embeddings_strip)
obj.type_embedding = TypeEmbedNet.deserialize(type_embedding)
obj.se_atten.dpa1_attention = NeighborGatedAttention.deserialize(
attention_layers
)
return obj
4 changes: 4 additions & 0 deletions deepmd/dpmodel/descriptor/se_e2_a.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,10 @@ def mixed_types(self):
"""
return False

def has_message_passing(self) -> bool:
"""Returns whether the descriptor has message passing."""
return False

def get_env_protection(self) -> float:
"""Returns the protection of building environment matrix."""
return self.env_protection
Expand Down
4 changes: 4 additions & 0 deletions deepmd/dpmodel/descriptor/se_r.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,10 @@ def mixed_types(self):
"""
return False

def has_message_passing(self) -> bool:
"""Returns whether the descriptor has message passing."""
return False

def get_env_protection(self) -> float:
"""Returns the protection of building environment matrix."""
return self.env_protection
Expand Down
4 changes: 4 additions & 0 deletions deepmd/dpmodel/descriptor/se_t.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,10 @@ def mixed_types(self):
"""
return False

def has_message_passing(self) -> bool:
"""Returns whether the descriptor has message passing."""
return False

def get_env_protection(self) -> float:
"""Returns the protection of building environment matrix."""
return self.env_protection
Expand Down
4 changes: 4 additions & 0 deletions deepmd/dpmodel/model/make_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -471,6 +471,10 @@ def mixed_types(self) -> bool:
"""
return self.atomic_model.mixed_types()

def has_message_passing(self) -> bool:
"""Returns whether the model has message passing."""
return self.atomic_model.has_message_passing()

def atomic_output_def(self) -> FittingOutputDef:
"""Get the output def of the atomic model."""
return self.atomic_model.atomic_output_def()
Expand Down
5 changes: 1 addition & 4 deletions deepmd/pt/entrypoints/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,10 +303,7 @@ def train(FLAGS):

def freeze(FLAGS):
model = torch.jit.script(inference.Tester(FLAGS.model, head=FLAGS.head).model)
if '"type": "dpa2"' in model.get_model_def_script():
extra_files = {"type": "dpa2"}
else:
extra_files = {"type": "else"}
extra_files = {}
torch.jit.save(
model,
FLAGS.output,
Expand Down
4 changes: 4 additions & 0 deletions deepmd/pt/model/atomic_model/dp_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,10 @@ def mixed_types(self) -> bool:
"""
return self.descriptor.mixed_types()

def has_message_passing(self) -> bool:
"""Returns whether the atomic model has message passing."""
return self.descriptor.has_message_passing()

def serialize(self) -> dict:
dd = BaseAtomicModel.serialize(self)
dd.update(
Expand Down
Loading

0 comments on commit 325d32c

Please sign in to comment.