Skip to content

Commit

Permalink
add pre-ln
Browse files: browse the repository at this point in the history
  • Loading branch information
iProzd committed Dec 10, 2024
1 parent be45407 commit 3f98056
Show file tree
Hide file tree
Showing 5 changed files with 55 additions and 13 deletions.
2 changes: 2 additions & 0 deletions deepmd/dpmodel/descriptor/dpa2.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,7 @@ def __init__(
use_undirect_a: bool = False,
update_g1_bidirect: bool = False,
pipeline_update: bool = False,
pre_ln: bool = False,
) -> None:
r"""The constructor for the RepformerArgs class which defines the parameters of the repformer block in DPA2 descriptor.
Expand Down Expand Up @@ -347,6 +348,7 @@ def __init__(
self.scale_dist = scale_dist
self.multiscale_mode = multiscale_mode
self.angle_only_cos = angle_only_cos
self.pre_ln = pre_ln
# to keep consistent with the default value in this backend
if ln_eps is None:
ln_eps = 1e-5
Expand Down
1 change: 1 addition & 0 deletions deepmd/pt/model/descriptor/dpa2.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,7 @@ def init_subclass_params(sub_data, sub_class):
use_undirect_a=self.repformer_args.use_undirect_a,
update_g1_bidirect=self.repformer_args.update_g1_bidirect,
pipeline_update=self.repformer_args.pipeline_update,
pre_ln=self.repformer_args.pre_ln,
seed=child_seed(seed, 1),
)
self.rcsl_list = [
Expand Down
44 changes: 31 additions & 13 deletions deepmd/pt/model/descriptor/repformer_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -482,6 +482,7 @@ def __init__(
use_undirect_a: bool = False,
update_g1_bidirect: bool = False,
pipeline_update: bool = False,
pre_ln: bool = False,
seed: Optional[Union[int, list[int]]] = None,
) -> None:
super().__init__()
Expand Down Expand Up @@ -519,6 +520,7 @@ def __init__(
self.g2_dim = g2_dim
self.trainable_ln = trainable_ln
self.ln_eps = ln_eps
self.pre_ln = pre_ln
self.precision = precision
self.seed = seed
self.use_sqrt_nnei = use_sqrt_nnei
Expand Down Expand Up @@ -550,6 +552,9 @@ def __init__(
"const",
], "'update_residual_init' only support 'norm' or 'const'!"

if self.pre_ln:
assert self.update_style == "res_layer"

if self.update_style == "res_layer":
self.g1_layernorm = nn.LayerNorm(
self.g1_dim,
Expand Down Expand Up @@ -1086,6 +1091,12 @@ def forward(
# angle
a_update: list[torch.Tensor] = [angle_embed]

if self.pre_ln:
assert self.g1_layernorm is not None
assert self.g2_layernorm is not None
g1 = self.g1_layernorm(g1)
g2 = self.g2_layernorm(g2)

# g1 self mlp
g1_self_mlp = self.act(self.g1_self_mlp(g1))
g1_update.append(g1_self_mlp)
Expand Down Expand Up @@ -1227,6 +1238,9 @@ def forward(
g2_update.append(g2_2)

if self.has_angle:
if self.pre_ln:
assert self.angle_layernorm is not None
angle_embed = self.angle_layernorm(angle_embed)
assert self.angle_linear is not None
assert self.g2_angle_linear1 is not None
assert self.g2_angle_linear2 is not None
Expand Down Expand Up @@ -1374,20 +1388,24 @@ def list_update_res_layer(
uu = update_list[0]
for ii in range(1, nitem):
uu = uu + update_list[ii]
if update_name == "g1":
assert self.g1_layernorm is not None
return self.g1_layernorm(uu)
elif update_name == "g2":
assert self.g2_layernorm is not None
return self.g2_layernorm(uu)
elif update_name == "h2":
# not update h2
return uu
elif update_name == "a":
assert self.angle_layernorm is not None
return self.angle_layernorm(uu)
if not self.pre_ln:
if update_name == "g1":
assert self.g1_layernorm is not None
out = self.g1_layernorm(uu)
elif update_name == "g2":
assert self.g2_layernorm is not None
out = self.g2_layernorm(uu)
elif update_name == "h2":
# not update h2
out = uu
elif update_name == "a":
assert self.angle_layernorm is not None
out = self.angle_layernorm(uu)
else:
raise NotImplementedError
else:
raise NotImplementedError
out = uu
return out

@torch.jit.export
def list_update_res_residual(
Expand Down
15 changes: 15 additions & 0 deletions deepmd/pt/model/descriptor/repformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@ def __init__(
use_undirect_a: bool = False,
update_g1_bidirect: bool = False,
pipeline_update: bool = False,
pre_ln: bool = False,
) -> None:
r"""
The repformer descriptor block.
Expand Down Expand Up @@ -277,6 +278,7 @@ def __init__(
self.use_undirect_a = use_undirect_a
self.update_g1_bidirect = update_g1_bidirect
self.pipeline_update = pipeline_update
self.pre_ln = pre_ln
if num_a % 2 != 1:
raise ValueError(f"{num_a=} must be an odd integer")
circular_harmonics_order = (num_a - 1) // 2
Expand All @@ -295,6 +297,14 @@ def __init__(
bias=False,
dtype=self.prec,
)
self.out_ln = None
if self.pre_ln:
self.out_ln = torch.nn.LayerNorm(
self.g1_dim,
device=env.DEVICE,
dtype=self.prec,
elementwise_affine=trainable_ln,
)
# order matters, placed after the assignment of self.ntypes
self.reinit_exclude(exclude_types)
self.env_protection = env_protection
Expand Down Expand Up @@ -370,6 +380,7 @@ def __init__(
use_undirect_a=self.use_undirect_a,
update_g1_bidirect=self.update_g1_bidirect,
pipeline_update=self.pipeline_update,
pre_ln=self.pre_ln,
seed=child_seed(child_seed(seed, 1), ii),
)
)
Expand Down Expand Up @@ -677,6 +688,10 @@ def forward(
# (nb x nloc) x ng2 x 3
rot_mat = torch.permute(h2g2, (0, 1, 3, 2))

if self.pre_ln:
assert self.out_ln is not None
g1 = self.out_ln(g1)

return g1, g2, h2, rot_mat.view(nframes, nloc, self.dim_emb, 3), sw

def compute_input_stats(
Expand Down
6 changes: 6 additions & 0 deletions deepmd/utils/argcheck.py
Original file line number Diff line number Diff line change
Expand Up @@ -1294,6 +1294,12 @@ def dpa2_repformer_args():
optional=True,
default=True,
),
Argument(
"pre_ln",
bool,
optional=True,
default=False,
),
Argument(
"angle_only_cos",
bool,
Expand Down

0 comments on commit 3f98056

Please sign in to comment.