Skip to content

Commit

Permalink
update first three subs
Browse files Browse the repository at this point in the history
  • Loading branch information
iProzd committed Sep 13, 2024
1 parent 1155251 commit 498fc24
Show file tree
Hide file tree
Showing 8 changed files with 358 additions and 40 deletions.
18 changes: 18 additions & 0 deletions deepmd/dpmodel/descriptor/dpa2.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,9 @@ def __init__(
update_residual_init: str = "norm",
set_davg_zero: bool = True,
trainable_ln: bool = True,
use_sqrt_nnei: bool = False,
g1_out_conv: bool = False,
g1_out_mlp: bool = False,
ln_eps: Optional[float] = 1e-5,
):
r"""The constructor for the RepformerArgs class which defines the parameters of the repformer block in DPA2 descriptor.
Expand Down Expand Up @@ -236,6 +239,12 @@ def __init__(
Set the normalization average to zero.
trainable_ln : bool, optional
Whether to use trainable shift and scale weights in layer normalization.
use_sqrt_nnei : bool, optional
Whether to use the square root of the number of neighbors for symmetrization_op normalization instead of using the number of neighbors directly.
g1_out_conv : bool, optional
Whether to put the convolutional update of g1 separately outside the concatenated MLP update.
g1_out_mlp : bool, optional
Whether to put the self MLP update of g1 separately outside the concatenated MLP update.
ln_eps : float, optional
The epsilon value for layer normalization.
"""
Expand Down Expand Up @@ -265,6 +274,9 @@ def __init__(
self.update_residual_init = update_residual_init
self.set_davg_zero = set_davg_zero
self.trainable_ln = trainable_ln
self.use_sqrt_nnei = use_sqrt_nnei
self.g1_out_conv = g1_out_conv
self.g1_out_mlp = g1_out_mlp
# to keep consistent with default value in this backends
if ln_eps is None:
ln_eps = 1e-5
Expand Down Expand Up @@ -304,6 +316,9 @@ def serialize(self) -> dict:
"update_residual_init": self.update_residual_init,
"set_davg_zero": self.set_davg_zero,
"trainable_ln": self.trainable_ln,
"use_sqrt_nnei": self.use_sqrt_nnei,
"g1_out_conv": self.g1_out_conv,
"g1_out_mlp": self.g1_out_mlp,
"ln_eps": self.ln_eps,
}

Expand Down Expand Up @@ -448,6 +463,9 @@ def init_subclass_params(sub_data, sub_class):
env_protection=env_protection,
precision=precision,
trainable_ln=self.repformer_args.trainable_ln,
use_sqrt_nnei=self.repformer_args.use_sqrt_nnei,
g1_out_conv=self.repformer_args.g1_out_conv,
g1_out_mlp=self.repformer_args.g1_out_mlp,
ln_eps=self.repformer_args.ln_eps,
seed=child_seed(seed, 1),
)
Expand Down
Loading

0 comments on commit 498fc24

Please sign in to comment.