Skip to content

Commit

Permalink
add compress a
Browse files Browse the repository at this point in the history
  • Loading branch information
iProzd committed Dec 19, 2024
1 parent 2aa6ca2 commit a5ec73f
Show file tree
Hide file tree
Showing 5 changed files with 46 additions and 5 deletions.
2 changes: 2 additions & 0 deletions deepmd/dpmodel/descriptor/dpa2.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,7 @@ def __init__(
update_g2_has_ar: bool = False,
update_g1_has_ar: bool = False,
update_g2_has_arra: bool = False,
compress_a: bool = False,
) -> None:
r"""The constructor for the RepformerArgs class which defines the parameters of the repformer block in DPA2 descriptor.
Expand Down Expand Up @@ -367,6 +368,7 @@ def __init__(
self.update_g1_bidirect = update_g1_bidirect
self.pipeline_update = pipeline_update
self.g1_mess_mulmlp = g1_mess_mulmlp
self.compress_a = compress_a

def __getitem__(self, key):
if hasattr(self, key):
Expand Down
1 change: 1 addition & 0 deletions deepmd/pt/model/descriptor/dpa2.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,7 @@ def init_subclass_params(sub_data, sub_class):
update_g2_has_ar=self.repformer_args.update_g2_has_ar,
update_g1_has_ar=self.repformer_args.update_g1_has_ar,
update_g2_has_arra=self.repformer_args.update_g2_has_arra,
compress_a=self.repformer_args.compress_a,
seed=child_seed(seed, 1),
)
self.no_repinit = self.repformer_args.no_repinit
Expand Down
39 changes: 34 additions & 5 deletions deepmd/pt/model/descriptor/repformer_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -487,6 +487,7 @@ def __init__(
update_g2_has_ar: bool = False,
update_g1_has_ar: bool = False,
update_g2_has_arra: bool = False,
compress_a: bool = False,
seed: Optional[Union[int, list[int]]] = None,
) -> None:
super().__init__()
Expand Down Expand Up @@ -550,6 +551,7 @@ def __init__(
self.update_g2_has_ar = update_g2_has_ar
self.update_g1_has_ar = update_g1_has_ar
self.update_g2_has_arra = update_g2_has_arra
self.compress_a = compress_a
self.prec = PRECISION_DICT[precision]
self.g1_layernorm = None
self.g2_layernorm = None
Expand Down Expand Up @@ -810,10 +812,30 @@ def __init__(
)
angle_seed = 20
self.angle_dim = self.a_dim
self.angle_dim += self.g1_dim if self.update_a_has_g1 else 0
self.angle_dim += 2 * self.g2_dim if self.update_a_has_g2 else 0
self.g2_angle_dim = 2 * self.g2_dim + self.g1_dim
self.g2_angle_dim += self.a_dim if self.update_g2_has_a else 0
if not self.compress_a:
self.angle_dim += self.g1_dim if self.update_a_has_g1 else 0
self.angle_dim += 2 * self.g2_dim if self.update_a_has_g2 else 0
self.compress_n_linear = None
self.compress_e_linear = None
else:
self.angle_dim += self.a_dim if self.update_a_has_g1 else 0
self.angle_dim += self.a_dim if self.update_a_has_g2 else 0
self.compress_n_linear = MLPLayer(
self.g1_dim,
self.a_dim,
precision=precision,
bias=False,
seed=child_seed(seed, angle_seed + 3),
)
self.compress_e_linear = MLPLayer(
self.g2_dim * 2,
self.a_dim,
precision=precision,
bias=False,
seed=child_seed(seed, angle_seed + 2),
)

self.g2_angle_dim = self.angle_dim
self.angle_linear = MLPLayer(
self.angle_dim,
self.a_dim,
Expand Down Expand Up @@ -1368,10 +1390,17 @@ def forward(
g2_angle_j = torch.tile(g2_angle.unsqueeze(3), (1, 1, 1, self.a_sel, 1))
# nb x nloc x a_nnei x a_nnei x (g2 + g2)
g2_angle_embed = torch.cat([g2_angle_i, g2_angle_j], dim=-1)
if self.compress_a:
assert self.compress_n_linear is not None
assert self.compress_e_linear is not None
# nb x nloc x a_nnei x a_nnei x a_dim
g1_angle_embed = self.compress_n_linear(g1_angle_embed)
# nb x nloc x a_nnei x a_nnei x a_dim
g2_angle_embed = self.compress_e_linear(g2_angle_embed)

# angle for g2:
updated_g2_angle_list = [angle_embed] if self.update_g2_has_a else []
# nb x nloc x a_nnei x a_nnei x (a + g1 + g2*2)
# nb x nloc x a_nnei x a_nnei x (a + g1 + g2*2) or (a + a + a)
updated_g2_angle_list += [g1_angle_embed, g2_angle_embed]
updated_g2_angle = torch.cat(updated_g2_angle_list, dim=-1)
# nb x nloc x a_nnei x a_nnei x g2
Expand Down
3 changes: 3 additions & 0 deletions deepmd/pt/model/descriptor/repformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,7 @@ def __init__(
update_g2_has_ar: bool = False,
update_g1_has_ar: bool = False,
update_g2_has_arra: bool = False,
compress_a: bool = False,
) -> None:
r"""
The repformer descriptor block.
Expand Down Expand Up @@ -287,6 +288,7 @@ def __init__(
self.update_g2_has_ar = update_g2_has_ar
self.update_g1_has_ar = update_g1_has_ar
self.update_g2_has_arra = update_g2_has_arra
self.compress_a = compress_a
if num_a % 2 != 1:
raise ValueError(f"{num_a=} must be an odd integer")
circular_harmonics_order = (num_a - 1) // 2
Expand Down Expand Up @@ -393,6 +395,7 @@ def __init__(
update_g2_has_ar=self.update_g2_has_ar,
update_g1_has_ar=self.update_g1_has_ar,
update_g2_has_arra=self.update_g2_has_arra,
compress_a=self.compress_a,
seed=child_seed(child_seed(seed, 1), ii),
)
)
Expand Down
6 changes: 6 additions & 0 deletions deepmd/utils/argcheck.py
Original file line number Diff line number Diff line change
Expand Up @@ -1240,6 +1240,12 @@ def dpa2_repformer_args():
optional=True,
default=True,
),
Argument(
"compress_a",
bool,
optional=True,
default=False,
),
Argument(
"update_a_has_g2",
bool,
Expand Down

0 comments on commit a5ec73f

Please sign in to comment.