Skip to content

Commit

Permalink
feat: Support stripped type embedding in DPA1 of PT/DP (deepmodeling#…
Browse files Browse the repository at this point in the history
…3712)

This PR supports stripped type embedding in DPA1 of PT/DP:

- Remove `stripped_type_embedding` params in all classes and use
`tebd_input_mode` == "strip" instead.
- Add stripped type embedding inplementation for DPA1 of PT/DP.
- Add serialize and deserialize for stripped type embedding.

Note: 
- Old TF inplementation has not consistent behaivior when
`type_one_side`==True and `tebd_input_mode` == "strip", it always uses
two_side type stripped embeddings input, which is also inconsistent with
`DescrptSeAEbdV2` in TF (but the training still works and only raise
`NotImplementedError` when doing serialization now) may need support
from @nahso .
- Old TF inplementation `init_variables` will not init `idt` weights
from graph for `two_side_embeeding_net_variables` (fixed), I'm surprised
that no ut failed before (maybe all tests use `resnet_dt` == False).
- The TF implementation of `DescrptSeAtten` does not support
serialization when `tebd_input_mode` == "strip". This limitation arises
because the shape of `type_embedding` cannot be determined after init,
as it is decided at runtime. While the consistent version
`DescrptDPA1Compat` is compatible with this configuration.


<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

- **New Features**
- Enhanced model flexibility with new type embedding input modes:
`concat` and `strip`.
- **Bug Fixes**
- Improved model compression logic alignment with new type embedding
modes for more efficient operations.
- **Documentation**
- Updated documentation to explain the impact of new type embedding
input modes on model descriptors.
- **Tests**
- Adjusted test cases to reflect changes in type embedding input modes
for robust testing.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Signed-off-by: Duo <50307526+iProzd@users.noreply.github.com>
Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
  • Loading branch information
2 people authored and Mathieu Taillefumier committed Sep 18, 2024
1 parent 4181a4b commit f734015
Show file tree
Hide file tree
Showing 21 changed files with 462 additions and 135 deletions.
92 changes: 75 additions & 17 deletions deepmd/dpmodel/descriptor/dpa1.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,8 +135,9 @@ class DescrptDPA1(NativeOP, BaseDescriptor):
tebd_dim: int
Dimension of the type embedding
tebd_input_mode: str
The way to mix the type embeddings. Supported options are `concat`.
(TODO need to support stripped_type_embedding option)
The input mode of the type embedding. Supported modes are ["concat", "strip"].
- "concat": Concatenate the type embedding with the smoothed radial information as the union input for the embedding network.
- "strip": Use a separated embedding network for the type embedding and combine the output with the radial embedding network output.
resnet_dt: bool
Time-step `dt` in the resnet construction:
y = x + dt * \phi (Wx + b)
Expand Down Expand Up @@ -182,16 +183,19 @@ class DescrptDPA1(NativeOP, BaseDescriptor):
Whether to use smooth process in attention weights calculation.
concat_output_tebd: bool
Whether to concat type embedding at the output of the descriptor.
stripped_type_embedding: bool, Optional
(Deprecated, kept only for compatibility.)
Whether to strip the type embedding into a separate embedding network.
Setting this parameter to `True` is equivalent to setting `tebd_input_mode` to 'strip'.
Setting it to `False` is equivalent to setting `tebd_input_mode` to 'concat'.
The default value is `None`, which means the `tebd_input_mode` setting will be used instead.
spin
(Only support None to keep consistent with other backend references.)
(Not used in this version. Not-none option is not implemented.)
The old implementation of deepspin.
Limitations
-----------
The currently implementation does not support the following features
1. tebd_input_mode != 'concat'
The currently implementation will not support the following deprecated features
1. spin is not None
2. attn_mask == True
Expand Down Expand Up @@ -233,19 +237,21 @@ def __init__(
smooth_type_embedding: bool = True,
concat_output_tebd: bool = True,
spin: Optional[Any] = None,
stripped_type_embedding: Optional[bool] = None,
# consistent with argcheck, not used though
seed: Optional[int] = None,
) -> None:
## seed, uniform_seed, multi_task, not included.
# Ensure compatibility with the deprecated stripped_type_embedding option.
if stripped_type_embedding is not None:
# Use the user-set stripped_type_embedding parameter first
tebd_input_mode = "strip" if stripped_type_embedding else "concat"
if spin is not None:
raise NotImplementedError("old implementation of spin is not supported.")
if attn_mask:
raise NotImplementedError(
"old implementation of attn_mask is not supported."
)
# TODO
if tebd_input_mode != "concat":
raise NotImplementedError("tebd_input_mode != 'concat' not implemented")
# to keep consistent with default value in this backends
if ln_eps is None:
ln_eps = 1e-5
Expand Down Expand Up @@ -290,25 +296,38 @@ def __init__(
activation_function="Linear",
precision=precision,
)
self.tebd_dim_input = self.tebd_dim if self.type_one_side else self.tebd_dim * 2
if self.tebd_input_mode in ["concat"]:
if not self.type_one_side:
in_dim = 1 + self.tebd_dim * 2
else:
in_dim = 1 + self.tebd_dim
self.embd_input_dim = 1 + self.tebd_dim_input
else:
in_dim = 1
self.embd_input_dim = 1
self.embeddings = NetworkCollection(
ndim=0,
ntypes=self.ntypes,
network_type="embedding_network",
)
self.embeddings[0] = EmbeddingNet(
in_dim,
self.embd_input_dim,
self.neuron,
self.activation_function,
self.resnet_dt,
self.precision,
)
if self.tebd_input_mode in ["strip"]:
self.embeddings_strip = NetworkCollection(
ndim=0,
ntypes=self.ntypes,
network_type="embedding_network",
)
self.embeddings_strip[0] = EmbeddingNet(
self.tebd_dim_input,
self.neuron,
self.activation_function,
self.resnet_dt,
self.precision,
)
else:
self.embeddings_strip = None
self.dpa1_attention = NeighborGatedAttention(
self.attn_layer,
self.nnei,
Expand Down Expand Up @@ -410,6 +429,18 @@ def cal_g(
gg = self.embeddings[embedding_idx].call(ss)
return gg

def cal_g_strip(
self,
ss,
embedding_idx,
):
assert self.embeddings_strip is not None
nfnl, nnei = ss.shape[0:2]
ss = ss.reshape(nfnl, nnei, -1)
# nfnl x nnei x ng
gg = self.embeddings_strip[embedding_idx].call(ss)
return gg

def reinit_exclude(
self,
exclude_types: List[Tuple[int, int]] = [],
Expand Down Expand Up @@ -500,11 +531,28 @@ def call(
else:
# nfnl x nnei x (1 + tebd_dim)
ss = np.concatenate([ss, atype_embd_nlist], axis=-1)
# calculate gg
# nfnl x nnei x ng
gg = self.cal_g(ss, 0)
elif self.tebd_input_mode in ["strip"]:
# nfnl x nnei x ng
gg_s = self.cal_g(ss, 0)
assert self.embeddings_strip is not None
if not self.type_one_side:
# nfnl x nnei x (tebd_dim * 2)
tt = np.concatenate([atype_embd_nlist, atype_embd_nnei], axis=-1)
else:
# nfnl x nnei x tebd_dim
tt = atype_embd_nlist
# nfnl x nnei x ng
gg_t = self.cal_g_strip(tt, 0)
if self.smooth:
gg_t = gg_t * sw.reshape(-1, self.nnei, 1)
# nfnl x nnei x ng
gg = gg_s * gg_t + gg_s
else:
raise NotImplementedError

# calculate gg
gg = self.cal_g(ss, 0)
input_r = dmatrix.reshape(-1, nnei, 4)[:, :, 1:4] / np.maximum(
np.linalg.norm(
dmatrix.reshape(-1, nnei, 4)[:, :, 1:4], axis=-1, keepdims=True
Expand Down Expand Up @@ -532,7 +580,7 @@ def call(

def serialize(self) -> dict:
"""Serialize the descriptor to dict."""
return {
data = {
"@class": "Descriptor",
"type": "dpa1",
"@version": 1,
Expand Down Expand Up @@ -575,6 +623,9 @@ def serialize(self) -> dict:
"trainable": True,
"spin": None,
}
if self.tebd_input_mode in ["strip"]:
data.update({"embeddings_strip": self.embeddings_strip.serialize()})
return data

@classmethod
def deserialize(cls, data: dict) -> "DescrptDPA1":
Expand All @@ -588,11 +639,18 @@ def deserialize(cls, data: dict) -> "DescrptDPA1":
type_embedding = data.pop("type_embedding")
attention_layers = data.pop("attention_layers")
env_mat = data.pop("env_mat")
tebd_input_mode = data["tebd_input_mode"]
if tebd_input_mode in ["strip"]:
embeddings_strip = data.pop("embeddings_strip")
else:
embeddings_strip = None
obj = cls(**data)

obj["davg"] = variables["davg"]
obj["dstd"] = variables["dstd"]
obj.embeddings = NetworkCollection.deserialize(embeddings)
if tebd_input_mode in ["strip"]:
obj.embeddings_strip = NetworkCollection.deserialize(embeddings_strip)
obj.type_embedding = TypeEmbedNet.deserialize(type_embedding)
obj.dpa1_attention = NeighborGatedAttention.deserialize(attention_layers)
return obj
Expand Down
42 changes: 28 additions & 14 deletions deepmd/pt/model/descriptor/dpa1.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,8 +118,9 @@ class DescrptDPA1(BaseDescriptor, torch.nn.Module):
tebd_dim: int
Dimension of the type embedding
tebd_input_mode: str
The way to mix the type embeddings. Supported options are `concat`.
(TODO need to support stripped_type_embedding option)
The input mode of the type embedding. Supported modes are ["concat", "strip"].
- "concat": Concatenate the type embedding with the smoothed radial information as the union input for the embedding network.
- "strip": Use a separated embedding network for the type embedding and combine the output with the radial embedding network output.
resnet_dt: bool
Time-step `dt` in the resnet construction:
y = x + dt * \phi (Wx + b)
Expand Down Expand Up @@ -165,16 +166,19 @@ class DescrptDPA1(BaseDescriptor, torch.nn.Module):
Whether to use smooth process in attention weights calculation.
concat_output_tebd: bool
Whether to concat type embedding at the output of the descriptor.
stripped_type_embedding: bool, Optional
(Deprecated, kept only for compatibility.)
Whether to strip the type embedding into a separate embedding network.
Setting this parameter to `True` is equivalent to setting `tebd_input_mode` to 'strip'.
Setting it to `False` is equivalent to setting `tebd_input_mode` to 'concat'.
The default value is `None`, which means the `tebd_input_mode` setting will be used instead.
spin
(Only support None to keep consistent with other backend references.)
(Not used in this version. Not-none option is not implemented.)
The old implementation of deepspin.
Limitations
-----------
The currently implementation does not support the following features
1. tebd_input_mode != 'concat'
The currently implementation will not support the following deprecated features
1. spin is not None
2. attn_mask == True
Expand All @@ -196,8 +200,7 @@ def __init__(
axis_neuron: int = 16,
tebd_dim: int = 8,
tebd_input_mode: str = "concat",
# set_davg_zero: bool = False,
set_davg_zero: bool = True, # TODO
set_davg_zero: bool = True,
attn: int = 128,
attn_layer: int = 2,
attn_dotr: bool = True,
Expand All @@ -216,25 +219,24 @@ def __init__(
ln_eps: Optional[float] = 1e-5,
smooth_type_embedding: bool = True,
type_one_side: bool = False,
stripped_type_embedding: Optional[bool] = None,
# not implemented
stripped_type_embedding: bool = False,
spin=None,
type: Optional[str] = None,
seed: Optional[int] = None,
old_impl: bool = False,
):
super().__init__()
if stripped_type_embedding:
raise NotImplementedError("stripped_type_embedding is not supported.")
# Ensure compatibility with the deprecated stripped_type_embedding option.
if stripped_type_embedding is not None:
# Use the user-set stripped_type_embedding parameter first
tebd_input_mode = "strip" if stripped_type_embedding else "concat"
if spin is not None:
raise NotImplementedError("old implementation of spin is not supported.")
if attn_mask:
raise NotImplementedError(
"old implementation of attn_mask is not supported."
)
# TODO
if tebd_input_mode != "concat":
raise NotImplementedError("tebd_input_mode != 'concat' not implemented")
# to keep consistent with default value in this backends
if ln_eps is None:
ln_eps = 1e-5
Expand Down Expand Up @@ -377,7 +379,7 @@ def set_stat_mean_and_stddev(

def serialize(self) -> dict:
obj = self.se_atten
return {
data = {
"@class": "Descriptor",
"type": "dpa1",
"@version": 1,
Expand Down Expand Up @@ -420,6 +422,9 @@ def serialize(self) -> dict:
"trainable": True,
"spin": None,
}
if obj.tebd_input_mode in ["strip"]:
data.update({"embeddings_strip": obj.filter_layers_strip.serialize()})
return data

@classmethod
def deserialize(cls, data: dict) -> "DescrptDPA1":
Expand All @@ -432,6 +437,11 @@ def deserialize(cls, data: dict) -> "DescrptDPA1":
type_embedding = data.pop("type_embedding")
attention_layers = data.pop("attention_layers")
env_mat = data.pop("env_mat")
tebd_input_mode = data["tebd_input_mode"]
if tebd_input_mode in ["strip"]:
embeddings_strip = data.pop("embeddings_strip")
else:
embeddings_strip = None
obj = cls(**data)

def t_cvt(xx):
Expand All @@ -443,6 +453,10 @@ def t_cvt(xx):
obj.se_atten["davg"] = t_cvt(variables["davg"])
obj.se_atten["dstd"] = t_cvt(variables["dstd"])
obj.se_atten.filter_layers = NetworkCollection.deserialize(embeddings)
if tebd_input_mode in ["strip"]:
obj.se_atten.filter_layers_strip = NetworkCollection.deserialize(
embeddings_strip
)
obj.se_atten.dpa1_attention = NeighborGatedAttention.deserialize(
attention_layers
)
Expand Down
Loading

0 comments on commit f734015

Please sign in to comment.