From 498fc246fcc5a178afe5967c9ae389dad0c8e569 Mon Sep 17 00:00:00 2001 From: Duo <50307526+iProzd@users.noreply.github.com> Date: Fri, 13 Sep 2024 12:26:27 +0800 Subject: [PATCH] update first three subs --- deepmd/dpmodel/descriptor/dpa2.py | 18 ++ deepmd/dpmodel/descriptor/repformers.py | 159 ++++++++++++++++-- deepmd/pt/model/descriptor/dpa2.py | 3 + deepmd/pt/model/descriptor/repformer_layer.py | 131 ++++++++++++--- deepmd/pt/model/descriptor/repformers.py | 23 ++- deepmd/utils/argcheck.py | 24 +++ .../tests/consistent/descriptor/test_dpa2.py | 31 +++- .../dpmodel/descriptor/test_descriptor.py | 9 + 8 files changed, 358 insertions(+), 40 deletions(-) diff --git a/deepmd/dpmodel/descriptor/dpa2.py b/deepmd/dpmodel/descriptor/dpa2.py index 0de63bce4a..478a44f420 100644 --- a/deepmd/dpmodel/descriptor/dpa2.py +++ b/deepmd/dpmodel/descriptor/dpa2.py @@ -172,6 +172,9 @@ def __init__( update_residual_init: str = "norm", set_davg_zero: bool = True, trainable_ln: bool = True, + use_sqrt_nnei: bool = False, + g1_out_conv: bool = False, + g1_out_mlp: bool = False, ln_eps: Optional[float] = 1e-5, ): r"""The constructor for the RepformerArgs class which defines the parameters of the repformer block in DPA2 descriptor. @@ -236,6 +239,12 @@ def __init__( Set the normalization average to zero. trainable_ln : bool, optional Whether to use trainable shift and scale weights in layer normalization. + use_sqrt_nnei : bool, optional + Whether to use the square root of the number of neighbors for symmetrization_op normalization instead of using the number of neighbors directly. + g1_out_conv : bool, optional + Whether to put the convolutional update of g1 separately outside the concatenated MLP update. + g1_out_mlp : bool, optional + Whether to put the self MLP update of g1 separately outside the concatenated MLP update. ln_eps : float, optional The epsilon value for layer normalization. """ @@ -265,6 +274,9 @@ def __init__( self.update_residual_init = update_residual_init self.set_davg_zero = set_davg_zero self.trainable_ln = trainable_ln + self.use_sqrt_nnei = use_sqrt_nnei + self.g1_out_conv = g1_out_conv + self.g1_out_mlp = g1_out_mlp # to keep consistent with default value in this backends if ln_eps is None: ln_eps = 1e-5 @@ -304,6 +316,9 @@ def serialize(self) -> dict: "update_residual_init": self.update_residual_init, "set_davg_zero": self.set_davg_zero, "trainable_ln": self.trainable_ln, + "use_sqrt_nnei": self.use_sqrt_nnei, + "g1_out_conv": self.g1_out_conv, + "g1_out_mlp": self.g1_out_mlp, "ln_eps": self.ln_eps, } @@ -448,6 +463,9 @@ def init_subclass_params(sub_data, sub_class): env_protection=env_protection, precision=precision, trainable_ln=self.repformer_args.trainable_ln, + use_sqrt_nnei=self.repformer_args.use_sqrt_nnei, + g1_out_conv=self.repformer_args.g1_out_conv, + g1_out_mlp=self.repformer_args.g1_out_mlp, ln_eps=self.repformer_args.ln_eps, seed=child_seed(seed, 1), ) diff --git a/deepmd/dpmodel/descriptor/repformers.py b/deepmd/dpmodel/descriptor/repformers.py index bb84816d3d..fb18b61c39 100644 --- a/deepmd/dpmodel/descriptor/repformers.py +++ b/deepmd/dpmodel/descriptor/repformers.py @@ -118,6 +118,12 @@ class DescrptBlockRepformers(NativeOP, DescriptorBlock): For example, when using paddings, there may be zero distances of neighbors, which may make division by zero error during environment matrix calculations without protection. trainable_ln : bool, optional Whether to use trainable shift and scale weights in layer normalization. 
+ use_sqrt_nnei : bool, optional + Whether to use the square root of the number of neighbors for symmetrization_op normalization instead of using the number of neighbors directly. + g1_out_conv : bool, optional + Whether to put the convolutional update of g1 separately outside the concatenated MLP update. + g1_out_mlp : bool, optional + Whether to put the self MLP update of g1 separately outside the concatenated MLP update. ln_eps : float, optional The epsilon value for layer normalization. seed : int, optional @@ -157,6 +163,9 @@ def __init__( env_protection: float = 0.0, precision: str = "float64", trainable_ln: bool = True, + use_sqrt_nnei: bool = False, + g1_out_conv: bool = False, + g1_out_mlp: bool = False, ln_eps: Optional[float] = 1e-5, seed: Optional[Union[int, List[int]]] = None, ): @@ -200,6 +209,9 @@ def __init__( self.env_protection = env_protection self.precision = precision self.trainable_ln = trainable_ln + self.use_sqrt_nnei = use_sqrt_nnei + self.g1_out_conv = g1_out_conv + self.g1_out_mlp = g1_out_mlp self.ln_eps = ln_eps self.epsilon = 1e-4 @@ -238,6 +250,9 @@ def __init__( trainable_ln=self.trainable_ln, ln_eps=self.ln_eps, precision=precision, + use_sqrt_nnei=self.use_sqrt_nnei, + g1_out_conv=self.g1_out_conv, + g1_out_mlp=self.g1_out_mlp, seed=child_seed(child_seed(seed, 1), ii), ) ) @@ -392,7 +407,15 @@ def call( ) # nf x nloc x 3 x ng2 - h2g2 = _cal_hg(g2, h2, nlist_mask, sw, smooth=self.smooth, epsilon=self.epsilon) + h2g2 = _cal_hg( + g2, + h2, + nlist_mask, + sw, + smooth=self.smooth, + epsilon=self.epsilon, + use_sqrt_nnei=self.use_sqrt_nnei, + ) # (nf x nloc) x ng2 x 3 rot_mat = np.transpose(h2g2, (0, 1, 3, 2)) return g1, g2, h2, rot_mat.reshape(nf, nloc, self.dim_emb, 3), sw @@ -521,6 +544,7 @@ def _cal_hg( sw: np.ndarray, smooth: bool = True, epsilon: float = 1e-4, + use_sqrt_nnei: bool = False, ) -> np.ndarray: """ Calculate the transposed rotation matrix. @@ -540,6 +564,8 @@ def _cal_hg( Whether to use smoothness in processes such as attention weights calculation. epsilon Protection of 1./nnei. + use_sqrt_nnei : bool, optional + Whether to use the square root of the number of neighbors for symmetrization_op normalization instead of using the number of neighbors directly. Returns ------- @@ -555,12 +581,20 @@ def _cal_hg( g = _apply_nlist_mask(g, nlist_mask) if not smooth: # nf x nloc - invnnei = 1.0 / (epsilon + np.sum(nlist_mask, axis=-1)) + if not use_sqrt_nnei: + invnnei = 1.0 / (epsilon + np.sum(nlist_mask, axis=-1)) + else: + invnnei = 1.0 / (epsilon + np.sqrt(np.sum(nlist_mask, axis=-1))) # nf x nloc x 1 x 1 invnnei = invnnei[:, :, np.newaxis, np.newaxis] else: g = _apply_switch(g, sw) - invnnei = (1.0 / float(nnei)) * np.ones((nf, nloc, 1, 1), dtype=g.dtype) + if not use_sqrt_nnei: + invnnei = (1.0 / float(nnei)) * np.ones((nf, nloc, 1, 1), dtype=g.dtype) + else: + invnnei = (1.0 / (float(nnei) ** 0.5)) * np.ones( + (nf, nloc, 1, 1), dtype=g.dtype + ) # nf x nloc x 3 x ng hg = np.matmul(np.transpose(h, axes=(0, 1, 3, 2)), g) * invnnei return hg @@ -601,6 +635,7 @@ def symmetrization_op( axis_neuron: int, smooth: bool = True, epsilon: float = 1e-4, + use_sqrt_nnei: bool = False, ) -> np.ndarray: """ Symmetrization operator to obtain atomic invariant rep. @@ -622,6 +657,8 @@ def symmetrization_op( Whether to use smoothness in processes such as attention weights calculation. epsilon Protection of 1./nnei. 
+ use_sqrt_nnei : bool, optional + Whether to use the square root of the number of neighbors for symmetrization_op normalization instead of using the number of neighbors directly. Returns ------- @@ -633,7 +670,15 @@ def symmetrization_op( # msk: nf x nloc x nnei nf, nloc, nnei, _ = g.shape # nf x nloc x 3 x ng - hg = _cal_hg(g, h, nlist_mask, sw, smooth=smooth, epsilon=epsilon) + hg = _cal_hg( + g, + h, + nlist_mask, + sw, + smooth=smooth, + epsilon=epsilon, + use_sqrt_nnei=use_sqrt_nnei, + ) # nf x nloc x (axis_neuron x ng) grrg = _cal_grrg(hg, axis_neuron) return grrg @@ -1083,6 +1128,9 @@ def __init__( smooth: bool = True, precision: str = "float64", trainable_ln: bool = True, + use_sqrt_nnei: bool = False, + g1_out_conv: bool = False, + g1_out_mlp: bool = False, ln_eps: Optional[float] = 1e-5, seed: Optional[Union[int, List[int]]] = None, ): @@ -1120,6 +1168,9 @@ def __init__( self.g1_dim = g1_dim self.g2_dim = g2_dim self.trainable_ln = trainable_ln + self.use_sqrt_nnei = use_sqrt_nnei + self.g1_out_conv = g1_out_conv + self.g1_out_mlp = g1_out_mlp self.ln_eps = ln_eps self.precision = precision @@ -1177,14 +1228,52 @@ def __init__( seed=child_seed(seed, 3), ) ) - if self.update_g1_has_conv: - self.proj_g1g2 = NativeLayer( + if self.g1_out_mlp: + self.g1_self_mlp = NativeLayer( + g1_dim, g1_dim, - g2_dim, - bias=False, precision=precision, - seed=child_seed(seed, 4), + seed=child_seed(seed, 15), ) + if self.update_style == "res_residual": + self.g1_residual.append( + get_residual( + g1_dim, + self.update_residual, + self.update_residual_init, + precision=precision, + seed=child_seed(seed, 16), + ) + ) + else: + self.g1_self_mlp = None + if self.update_g1_has_conv: + if not self.g1_out_conv: + self.proj_g1g2 = NativeLayer( + g1_dim, + g2_dim, + bias=False, + precision=precision, + seed=child_seed(seed, 4), + ) + else: + self.proj_g1g2 = NativeLayer( + g2_dim, + g1_dim, + bias=False, + precision=precision, + seed=child_seed(seed, 4), + ) + if self.update_style == "res_residual": + self.g1_residual.append( + get_residual( + g1_dim, + self.update_residual, + self.update_residual_init, + precision=precision, + seed=child_seed(seed, 17), + ) + ) if self.update_g2_has_g1g1: self.proj_g1g1g2 = NativeLayer( g1_dim, @@ -1270,12 +1359,12 @@ def __init__( ) def cal_1_dim(self, g1d: int, g2d: int, ax: int) -> int: - ret = g1d + ret = g1d if not self.g1_out_mlp else 0 if self.update_g1_has_grrg: ret += g2d * ax if self.update_g1_has_drrd: ret += g1d * ax - if self.update_g1_has_conv: + if self.update_g1_has_conv and not self.g1_out_conv: ret += g2d return ret @@ -1325,9 +1414,13 @@ def _update_g1_conv( nf, nloc, nnei, _ = g2.shape ng1 = gg1.shape[-1] ng2 = g2.shape[-1] - # gg1 : nf x nloc x nnei x ng2 - gg1 = self.proj_g1g2(gg1).reshape(nf, nloc, nnei, ng2) - # nf x nloc x nnei x ng2 + if not self.g1_out_conv: + # gg1 : nf x nloc x nnei x ng2 + gg1 = self.proj_g1g2(gg1).reshape(nf, nloc, nnei, ng2) + else: + # gg1 : nf x nloc x nnei x ng1 + gg1 = gg1.reshape(nf, nloc, nnei, ng1) + # nf x nloc x nnei x ng2/ng1 gg1 = _apply_nlist_mask(gg1, nlist_mask) if not self.smooth: # normalized by number of neighbors, not smooth @@ -1338,8 +1431,14 @@ def _update_g1_conv( else: gg1 = _apply_switch(gg1, sw) invnnei = (1.0 / float(nnei)) * np.ones((nf, nloc, 1), dtype=gg1.dtype) - # nf x nloc x ng2 - g1_11 = np.sum(g2 * gg1, axis=2) * invnnei + if not self.g1_out_conv: + # nf x nloc x ng2 + g1_11 = np.sum(g2 * gg1, axis=2) * invnnei + else: + # nf x nloc x ng1 + g2 = self.proj_g1g2(g2).reshape(nf, nloc, nnei, 
ng1) + # nb x nloc x ng1 + g1_11 = np.sum(g2 * gg1, axis=2) * invnnei return g1_11 def _update_g2_g1g1( @@ -1412,7 +1511,11 @@ def call( g2_update: List[np.ndarray] = [g2] h2_update: List[np.ndarray] = [h2] g1_update: List[np.ndarray] = [g1] - g1_mlp: List[np.ndarray] = [g1] + g1_mlp: List[np.ndarray] = [g1] if not self.g1_out_mlp else [] + if self.g1_out_mlp: + assert self.g1_self_mlp is not None + g1_self_mlp = self.act(self.g1_self_mlp(g1)) + g1_update.append(g1_self_mlp) if cal_gg1: gg1 = _make_nei_g1(g1_ext, nlist) @@ -1454,7 +1557,11 @@ def call( if self.update_g1_has_conv: assert gg1 is not None - g1_mlp.append(self._update_g1_conv(gg1, g2, nlist_mask, sw)) + g1_conv = self._update_g1_conv(gg1, g2, nlist_mask, sw) + if not self.g1_out_conv: + g1_mlp.append(g1_conv) + else: + g1_update.append(g1_conv) if self.update_g1_has_grrg: g1_mlp.append( @@ -1466,6 +1573,7 @@ def call( self.axis_neuron, smooth=self.smooth, epsilon=self.epsilon, + use_sqrt_nnei=self.use_sqrt_nnei, ) ) @@ -1480,6 +1588,7 @@ def call( self.axis_neuron, smooth=self.smooth, epsilon=self.epsilon, + use_sqrt_nnei=self.use_sqrt_nnei, ) ) @@ -1586,6 +1695,9 @@ def serialize(self) -> dict: "smooth": self.smooth, "precision": self.precision, "trainable_ln": self.trainable_ln, + "use_sqrt_nnei": self.use_sqrt_nnei, + "g1_out_conv": self.g1_out_conv, + "g1_out_mlp": self.g1_out_mlp, "ln_eps": self.ln_eps, "linear1": self.linear1.serialize(), } @@ -1633,6 +1745,12 @@ def serialize(self) -> dict: "loc_attn": self.loc_attn.serialize(), } ) + if self.g1_out_mlp: + data.update( + { + "g1_self_mlp": self.g1_self_mlp.serialize(), + } + ) if self.update_style == "res_residual": data.update( { @@ -1663,6 +1781,7 @@ def deserialize(cls, data: dict) -> "RepformerLayer": update_h2 = data["update_h2"] update_g1_has_attn = data["update_g1_has_attn"] update_style = data["update_style"] + g1_out_mlp = data["g1_out_mlp"] linear2 = data.pop("linear2", None) proj_g1g2 = data.pop("proj_g1g2", None) @@ -1672,6 +1791,7 @@ def deserialize(cls, data: dict) -> "RepformerLayer": attn2_lm = data.pop("attn2_lm", None) attn2_ev_apply = data.pop("attn2_ev_apply", None) loc_attn = data.pop("loc_attn", None) + g1_self_mlp = data.pop("g1_self_mlp", None) g1_residual = data.pop("g1_residual", []) g2_residual = data.pop("g2_residual", []) h2_residual = data.pop("h2_residual", []) @@ -1701,6 +1821,9 @@ def deserialize(cls, data: dict) -> "RepformerLayer": if update_g1_has_attn: assert isinstance(loc_attn, dict) obj.loc_attn = LocalAtten.deserialize(loc_attn) + if g1_out_mlp: + assert isinstance(g1_self_mlp, dict) + obj.g1_self_mlp = NativeLayer.deserialize(g1_self_mlp) if update_style == "res_residual": obj.g1_residual = g1_residual obj.g2_residual = g2_residual diff --git a/deepmd/pt/model/descriptor/dpa2.py b/deepmd/pt/model/descriptor/dpa2.py index f13c8861ef..be6956dc45 100644 --- a/deepmd/pt/model/descriptor/dpa2.py +++ b/deepmd/pt/model/descriptor/dpa2.py @@ -211,6 +211,9 @@ def init_subclass_params(sub_data, sub_class): precision=precision, trainable_ln=self.repformer_args.trainable_ln, ln_eps=self.repformer_args.ln_eps, + use_sqrt_nnei=self.repformer_args.use_sqrt_nnei, + g1_out_conv=self.repformer_args.g1_out_conv, + g1_out_mlp=self.repformer_args.g1_out_mlp, seed=child_seed(seed, 1), old_impl=old_impl, ) diff --git a/deepmd/pt/model/descriptor/repformer_layer.py b/deepmd/pt/model/descriptor/repformer_layer.py index 85a9800c73..eccd445935 100644 --- a/deepmd/pt/model/descriptor/repformer_layer.py +++ b/deepmd/pt/model/descriptor/repformer_layer.py @@ 
-599,6 +599,9 @@ def __init__( precision: str = "float64", trainable_ln: bool = True, ln_eps: Optional[float] = 1e-5, + use_sqrt_nnei: bool = False, + g1_out_conv: bool = False, + g1_out_mlp: bool = False, seed: Optional[Union[int, List[int]]] = None, ): super().__init__() @@ -638,6 +641,9 @@ def __init__( self.ln_eps = ln_eps self.precision = precision self.seed = seed + self.use_sqrt_nnei = use_sqrt_nnei + self.g1_out_conv = g1_out_conv + self.g1_out_mlp = g1_out_mlp assert update_residual_init in [ "norm", @@ -693,14 +699,52 @@ def __init__( seed=child_seed(seed, 3), ) ) - if self.update_g1_has_conv: - self.proj_g1g2 = MLPLayer( + if self.g1_out_mlp: + self.g1_self_mlp = MLPLayer( + g1_dim, g1_dim, - g2_dim, - bias=False, precision=precision, - seed=child_seed(seed, 4), + seed=child_seed(seed, 15), ) + if self.update_style == "res_residual": + self.g1_residual.append( + get_residual( + g1_dim, + self.update_residual, + self.update_residual_init, + precision=precision, + seed=child_seed(seed, 16), + ) + ) + else: + self.g1_self_mlp = None + if self.update_g1_has_conv: + if not self.g1_out_conv: + self.proj_g1g2 = MLPLayer( + g1_dim, + g2_dim, + bias=False, + precision=precision, + seed=child_seed(seed, 4), + ) + else: + self.proj_g1g2 = MLPLayer( + g2_dim, + g1_dim, + bias=False, + precision=precision, + seed=child_seed(seed, 4), + ) + if self.update_style == "res_residual": + self.g1_residual.append( + get_residual( + g1_dim, + self.update_residual, + self.update_residual_init, + precision=precision, + seed=child_seed(seed, 17), + ) + ) if self.update_g2_has_g1g1: self.proj_g1g1g2 = MLPLayer( g1_dim, @@ -790,12 +834,12 @@ def __init__( self.h2_residual = nn.ParameterList(self.h2_residual) def cal_1_dim(self, g1d: int, g2d: int, ax: int) -> int: - ret = g1d + ret = g1d if not self.g1_out_mlp else 0 if self.update_g1_has_grrg: ret += g2d * ax if self.update_g1_has_drrd: ret += g1d * ax - if self.update_g1_has_conv: + if self.update_g1_has_conv and not self.g1_out_conv: ret += g2d return ret @@ -845,9 +889,12 @@ def _update_g1_conv( nb, nloc, nnei, _ = g2.shape ng1 = gg1.shape[-1] ng2 = g2.shape[-1] - # gg1 : nb x nloc x nnei x ng2 - gg1 = self.proj_g1g2(gg1).view(nb, nloc, nnei, ng2) - # nb x nloc x nnei x ng2 + if not self.g1_out_conv: + # gg1 : nb x nloc x nnei x ng2 + gg1 = self.proj_g1g2(gg1).view(nb, nloc, nnei, ng2) + else: + gg1 = gg1.view(nb, nloc, nnei, ng1) + # nb x nloc x nnei x ng2/ng1 gg1 = _apply_nlist_mask(gg1, nlist_mask) if not self.smooth: # normalized by number of neighbors, not smooth @@ -861,8 +908,13 @@ def _update_g1_conv( invnnei = (1.0 / float(nnei)) * torch.ones( (nb, nloc, 1), dtype=gg1.dtype, device=gg1.device ) - # nb x nloc x ng2 - g1_11 = torch.sum(g2 * gg1, dim=2) * invnnei + if not self.g1_out_conv: + # nb x nloc x ng2 + g1_11 = torch.sum(g2 * gg1, dim=2) * invnnei + else: + g2 = self.proj_g1g2(g2).view(nb, nloc, nnei, ng1) + # nb x nloc x ng1 + g1_11 = torch.sum(g2 * gg1, dim=2) * invnnei return g1_11 @staticmethod @@ -873,6 +925,7 @@ def _cal_hg( sw: torch.Tensor, smooth: bool = True, epsilon: float = 1e-4, + use_sqrt_nnei: bool = False, ) -> torch.Tensor: """ Calculate the transposed rotation matrix. 
@@ -908,14 +961,24 @@ def _cal_hg( if not smooth: # nb x nloc # must use type_as here to convert bool to float, otherwise there will be numerical difference from numpy - invnnei = 1.0 / (epsilon + torch.sum(nlist_mask.type_as(g2), dim=-1)) + if not use_sqrt_nnei: + invnnei = 1.0 / (epsilon + torch.sum(nlist_mask.type_as(g2), dim=-1)) + else: + invnnei = 1.0 / ( + epsilon + torch.sqrt(torch.sum(nlist_mask.type_as(g2), dim=-1)) + ) # nb x nloc x 1 x 1 invnnei = invnnei.unsqueeze(-1).unsqueeze(-1) else: g2 = _apply_switch(g2, sw) - invnnei = (1.0 / float(nnei)) * torch.ones( - (nb, nloc, 1, 1), dtype=g2.dtype, device=g2.device - ) + if not use_sqrt_nnei: + invnnei = (1.0 / float(nnei)) * torch.ones( + (nb, nloc, 1, 1), dtype=g2.dtype, device=g2.device + ) + else: + invnnei = (1.0 / (float(nnei) ** 0.5)) * torch.ones( + (nb, nloc, 1, 1), dtype=g2.dtype, device=g2.device + ) # nb x nloc x 3 x ng2 h2g2 = torch.matmul(torch.transpose(h2, -1, -2), g2) * invnnei return h2g2 @@ -988,7 +1051,15 @@ def symmetrization_op( # msk: nb x nloc x nnei nb, nloc, nnei, _ = g2.shape # nb x nloc x 3 x ng2 - h2g2 = self._cal_hg(g2, h2, nlist_mask, sw, smooth=smooth, epsilon=epsilon) + h2g2 = self._cal_hg( + g2, + h2, + nlist_mask, + sw, + smooth=smooth, + epsilon=epsilon, + use_sqrt_nnei=self.use_sqrt_nnei, + ) # nb x nloc x (axisxng2) g1_13 = self._cal_grrg(h2g2, axis_neuron) return g1_13 @@ -1063,7 +1134,11 @@ def forward( g2_update: List[torch.Tensor] = [g2] h2_update: List[torch.Tensor] = [h2] g1_update: List[torch.Tensor] = [g1] - g1_mlp: List[torch.Tensor] = [g1] + g1_mlp: List[torch.Tensor] = [g1] if not self.g1_out_mlp else [] + if self.g1_out_mlp: + assert self.g1_self_mlp is not None + g1_self_mlp = self.act(self.g1_self_mlp(g1)) + g1_update.append(g1_self_mlp) if cal_gg1: gg1 = _make_nei_g1(g1_ext, nlist) @@ -1105,7 +1180,11 @@ def forward( if self.update_g1_has_conv: assert gg1 is not None - g1_mlp.append(self._update_g1_conv(gg1, g2, nlist_mask, sw)) + g1_conv = self._update_g1_conv(gg1, g2, nlist_mask, sw) + if not self.g1_out_conv: + g1_mlp.append(g1_conv) + else: + g1_update.append(g1_conv) if self.update_g1_has_grrg: g1_mlp.append( @@ -1242,6 +1321,9 @@ def serialize(self) -> dict: "smooth": self.smooth, "precision": self.precision, "trainable_ln": self.trainable_ln, + "use_sqrt_nnei": self.use_sqrt_nnei, + "g1_out_conv": self.g1_out_conv, + "g1_out_mlp": self.g1_out_mlp, "ln_eps": self.ln_eps, "linear1": self.linear1.serialize(), } @@ -1289,6 +1371,12 @@ def serialize(self) -> dict: "loc_attn": self.loc_attn.serialize(), } ) + if self.g1_out_mlp: + data.update( + { + "g1_self_mlp": self.g1_self_mlp.serialize(), + } + ) if self.update_style == "res_residual": data.update( { @@ -1319,6 +1407,7 @@ def deserialize(cls, data: dict) -> "RepformerLayer": update_h2 = data["update_h2"] update_g1_has_attn = data["update_g1_has_attn"] update_style = data["update_style"] + g1_out_mlp = data["g1_out_mlp"] linear2 = data.pop("linear2", None) proj_g1g2 = data.pop("proj_g1g2", None) @@ -1328,6 +1417,7 @@ def deserialize(cls, data: dict) -> "RepformerLayer": attn2_lm = data.pop("attn2_lm", None) attn2_ev_apply = data.pop("attn2_ev_apply", None) loc_attn = data.pop("loc_attn", None) + g1_self_mlp = data.pop("g1_self_mlp", None) g1_residual = data.pop("g1_residual", []) g2_residual = data.pop("g2_residual", []) h2_residual = data.pop("h2_residual", []) @@ -1357,6 +1447,9 @@ def deserialize(cls, data: dict) -> "RepformerLayer": if update_g1_has_attn: assert isinstance(loc_attn, dict) obj.loc_attn = 
LocalAtten.deserialize(loc_attn) + if g1_out_mlp: + assert isinstance(g1_self_mlp, dict) + obj.g1_self_mlp = MLPLayer.deserialize(g1_self_mlp) if update_style == "res_residual": for ii, t in enumerate(obj.g1_residual): t.data = to_torch_tensor(g1_residual[ii]) diff --git a/deepmd/pt/model/descriptor/repformers.py b/deepmd/pt/model/descriptor/repformers.py index bc8c331ec3..f08595bd96 100644 --- a/deepmd/pt/model/descriptor/repformers.py +++ b/deepmd/pt/model/descriptor/repformers.py @@ -105,6 +105,9 @@ def __init__( trainable_ln: bool = True, ln_eps: Optional[float] = 1e-5, seed: Optional[Union[int, List[int]]] = None, + use_sqrt_nnei: bool = False, + g1_out_conv: bool = False, + g1_out_mlp: bool = False, old_impl: bool = False, ): r""" @@ -182,6 +185,12 @@ def __init__( For example, when using paddings, there may be zero distances of neighbors, which may make division by zero error during environment matrix calculations without protection. trainable_ln : bool, optional Whether to use trainable shift and scale weights in layer normalization. + use_sqrt_nnei : bool, optional + Whether to use the square root of the number of neighbors for symmetrization_op normalization instead of using the number of neighbors directly. + g1_out_conv : bool, optional + Whether to put the convolutional update of g1 separately outside the concatenated MLP update. + g1_out_mlp : bool, optional + Whether to put the self MLP update of g1 separately outside the concatenated MLP update. ln_eps : float, optional The epsilon value for layer normalization. seed : int, optional @@ -222,6 +231,9 @@ def __init__( self.direct_dist = direct_dist self.act = ActivationFn(activation_function) self.smooth = smooth + self.use_sqrt_nnei = use_sqrt_nnei + self.g1_out_conv = g1_out_conv + self.g1_out_mlp = g1_out_mlp # order matters, placed after the assignment of self.ntypes self.reinit_exclude(exclude_types) self.env_protection = env_protection @@ -296,6 +308,9 @@ def __init__( trainable_ln=self.trainable_ln, ln_eps=self.ln_eps, precision=precision, + use_sqrt_nnei=self.use_sqrt_nnei, + g1_out_conv=self.g1_out_conv, + g1_out_mlp=self.g1_out_mlp, seed=child_seed(child_seed(seed, 1), ii), ) ) @@ -500,7 +515,13 @@ def forward( # nb x nloc x 3 x ng2 h2g2 = RepformerLayer._cal_hg( - g2, h2, nlist_mask, sw, smooth=self.smooth, epsilon=self.epsilon + g2, + h2, + nlist_mask, + sw, + smooth=self.smooth, + epsilon=self.epsilon, + use_sqrt_nnei=self.use_sqrt_nnei, ) # (nb x nloc) x ng2 x 3 rot_mat = torch.permute(h2g2, (0, 1, 3, 2)) diff --git a/deepmd/utils/argcheck.py b/deepmd/utils/argcheck.py index 03176c601a..52acee74da 100644 --- a/deepmd/utils/argcheck.py +++ b/deepmd/utils/argcheck.py @@ -1047,6 +1047,9 @@ def dpa2_repformer_args(): doc_update_g1_has_attn = "Update the g1 rep with the localized self-attention." doc_update_g2_has_g1g1 = "Update the g2 rep with the g1xg1 term." doc_update_g2_has_attn = "Update the g2 rep with the gated self-attention." + doc_use_sqrt_nnei = "Whether to use the square root of the number of neighbors for symmetrization_op normalization instead of using the number of neighbors directly." + doc_g1_out_conv = "Whether to put the convolutional update of g1 separately outside the concatenated MLP update." + doc_g1_out_mlp = "Whether to put the self MLP update of g1 separately outside the concatenated MLP update." doc_update_h2 = "Update the h2 rep." doc_attn1_hidden = ( "The hidden dimension of localized self-attention to update the g1 rep." 
@@ -1167,6 +1170,27 @@ def dpa2_repformer_args(): default=True, doc=doc_update_g2_has_attn, ), + Argument( + "use_sqrt_nnei", + bool, + optional=True, + default=False, + doc=doc_use_sqrt_nnei, + ), + Argument( + "g1_out_conv", + bool, + optional=True, + default=False, + doc=doc_g1_out_conv, + ), + Argument( + "g1_out_mlp", + bool, + optional=True, + default=False, + doc=doc_g1_out_mlp, + ), Argument( "update_h2", bool, diff --git a/source/tests/consistent/descriptor/test_dpa2.py b/source/tests/consistent/descriptor/test_dpa2.py index 9b88b4238a..8a041a4dd0 100644 --- a/source/tests/consistent/descriptor/test_dpa2.py +++ b/source/tests/consistent/descriptor/test_dpa2.py @@ -37,7 +37,7 @@ RepinitArgs, ) from deepmd.utils.argcheck import ( - descrpt_se_atten_args, + descrpt_dpa2_args, ) @@ -59,6 +59,9 @@ (True,), # repformer_set_davg_zero (True,), # repformer_trainable_ln (1e-5,), # repformer_ln_eps + (True,), # repformer_use_sqrt_nnei + (True,), # repformer_g1_out_conv + (True,), # repformer_g1_out_mlp (True, False), # smooth ([], [[0, 1]]), # exclude_types ("float64",), # precision @@ -87,6 +90,9 @@ def data(self) -> dict: repformer_set_davg_zero, repformer_trainable_ln, repformer_ln_eps, + repformer_use_sqrt_nnei, + repformer_g1_out_conv, + repformer_g1_out_mlp, smooth, exclude_types, precision, @@ -141,6 +147,9 @@ def data(self) -> dict: "set_davg_zero": True, "trainable_ln": repformer_trainable_ln, "ln_eps": repformer_ln_eps, + "use_sqrt_nnei": repformer_use_sqrt_nnei, + "g1_out_conv": repformer_g1_out_conv, + "g1_out_mlp": repformer_g1_out_mlp, } ), # kwargs for descriptor @@ -176,6 +185,9 @@ def skip_pt(self) -> bool: repformer_set_davg_zero, repformer_trainable_ln, repformer_ln_eps, + repformer_use_sqrt_nnei, + repformer_g1_out_conv, + repformer_g1_out_mlp, smooth, exclude_types, precision, @@ -205,6 +217,9 @@ def skip_dp(self) -> bool: repformer_set_davg_zero, repformer_trainable_ln, repformer_ln_eps, + repformer_use_sqrt_nnei, + repformer_g1_out_conv, + repformer_g1_out_mlp, smooth, exclude_types, precision, @@ -234,6 +249,9 @@ def skip_tf(self) -> bool: repformer_set_davg_zero, repformer_trainable_ln, repformer_ln_eps, + repformer_use_sqrt_nnei, + repformer_g1_out_conv, + repformer_g1_out_mlp, smooth, exclude_types, precision, @@ -246,7 +264,7 @@ def skip_tf(self) -> bool: tf_class = DescrptDPA2TF dp_class = DescrptDPA2DP pt_class = DescrptDPA2PT - args = descrpt_se_atten_args().append(Argument("ntypes", int, optional=False)) + args = descrpt_dpa2_args().append(Argument("ntypes", int, optional=False)) def setUp(self): CommonTest.setUp(self) @@ -299,6 +317,9 @@ def setUp(self): repformer_set_davg_zero, repformer_trainable_ln, repformer_ln_eps, + repformer_use_sqrt_nnei, + repformer_g1_out_conv, + repformer_g1_out_mlp, smooth, exclude_types, precision, @@ -361,6 +382,9 @@ def rtol(self) -> float: repformer_set_davg_zero, repformer_trainable_ln, repformer_ln_eps, + repformer_use_sqrt_nnei, + repformer_g1_out_conv, + repformer_g1_out_mlp, smooth, exclude_types, precision, @@ -396,6 +420,9 @@ def atol(self) -> float: repformer_set_davg_zero, repformer_trainable_ln, repformer_ln_eps, + repformer_use_sqrt_nnei, + repformer_g1_out_conv, + repformer_g1_out_mlp, smooth, exclude_types, precision, diff --git a/source/tests/universal/dpmodel/descriptor/test_descriptor.py b/source/tests/universal/dpmodel/descriptor/test_descriptor.py index 424dd2ea39..d9290cc80e 100644 --- a/source/tests/universal/dpmodel/descriptor/test_descriptor.py +++ 
b/source/tests/universal/dpmodel/descriptor/test_descriptor.py @@ -337,6 +337,9 @@ def DescriptorParamDPA2( repformer_set_davg_zero=False, repformer_trainable_ln=True, repformer_ln_eps=1e-5, + repformer_use_sqrt_nnei=False, + repformer_g1_out_conv=False, + repformer_g1_out_mlp=False, smooth=True, add_tebd_to_repinit_out=True, use_econf_tebd=False, @@ -392,6 +395,9 @@ def DescriptorParamDPA2( "set_davg_zero": repformer_set_davg_zero, "trainable_ln": repformer_trainable_ln, "ln_eps": repformer_ln_eps, + "use_sqrt_nnei": repformer_use_sqrt_nnei, + "g1_out_conv": repformer_g1_out_conv, + "g1_out_mlp": repformer_g1_out_mlp, } ), # kwargs for descriptor @@ -431,6 +437,9 @@ def DescriptorParamDPA2( "repformer_set_davg_zero": (True,), "repformer_trainable_ln": (True,), "repformer_ln_eps": (1e-5,), + "repformer_use_sqrt_nnei": (True,), + "repformer_g1_out_conv": (True,), + "repformer_g1_out_mlp": (True,), "smooth": (True, False), "exclude_types": ([], [[0, 1]]), "precision": ("float64",),
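
The central numerical change in this patch is the symmetrization normalization: with use_sqrt_nnei=True, _cal_hg divides the neighbor-summed product h^T g by the square root of the neighbor count rather than the count itself (and, in the smooth branch, by sqrt(nnei) instead of nnei). Below is a minimal NumPy sketch of the non-smooth branch; the shapes and random data are made up for illustration, but the masking and normalization mirror the patched _cal_hg.

    import numpy as np

    # Toy shapes, hypothetical and for illustration only.
    nf, nloc, nnei, ng = 1, 2, 4, 3
    rng = np.random.default_rng(0)
    g = rng.normal(size=(nf, nloc, nnei, ng))    # neighbor-wise rep
    h = rng.normal(size=(nf, nloc, nnei, 3))     # pairwise direction part
    nlist_mask = rng.integers(0, 2, size=(nf, nloc, nnei)).astype(bool)
    epsilon = 1e-4

    g = g * nlist_mask[:, :, :, None]            # as in _apply_nlist_mask

    # original normalization: divide by the number of real neighbors
    invnnei = 1.0 / (epsilon + np.sum(nlist_mask, axis=-1))
    # use_sqrt_nnei=True: divide by the square root of that count instead
    invnnei_sqrt = 1.0 / (epsilon + np.sqrt(np.sum(nlist_mask, axis=-1)))

    # nf x nloc x 3 x ng, as in the patched _cal_hg
    hg = np.matmul(np.transpose(h, (0, 1, 3, 2)), g) * invnnei_sqrt[:, :, None, None]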
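
The two routing flags reorganize how g1 is updated. Without them, the self term and the convolutional term are concatenated with the grrg/drrd symmetrization terms and pushed through linear1; with g1_out_mlp and g1_out_conv, the self MLP (the new g1_self_mlp layer) and the convolution become separate residual branches appended to g1_update, each with its own residual weight under update_style == "res_residual", and cal_1_dim shrinks the concatenated input accordingly. Note also that with g1_out_conv the proj_g1g2 layer flips direction: instead of projecting the neighbor rep gg1 from g1_dim to g2_dim, it projects g2 from g2_dim to g1_dim, so the summed convolution already lives in g1_dim. The sketch below paraphrases this routing only; g1_update_schematic and the identity/slice stand-ins are hypothetical, not part of the patch.

    import numpy as np

    def g1_update_schematic(g1, conv_term, grrg_term, g1_out_mlp, g1_out_conv,
                            self_mlp, linear1, act=np.tanh):
        # g1_update collects residual branches; g1_mlp collects the inputs
        # that are concatenated and fed through linear1 (cf. call/forward).
        g1_update = [g1]
        g1_mlp = [] if g1_out_mlp else [g1]
        if g1_out_mlp:
            g1_update.append(act(self_mlp(g1)))   # separate self-MLP branch
        if g1_out_conv:
            g1_update.append(conv_term)           # conv becomes its own branch
        else:
            g1_mlp.append(conv_term)              # conv stays in the concat
        g1_mlp.append(grrg_term)                  # symmetrization terms stay in the concat
        g1_update.append(act(linear1(np.concatenate(g1_mlp, axis=-1))))
        return g1_update                          # combined according to update_style

    # Hypothetical stand-ins: identity self MLP and a slicing "linear" layer.
    g1 = np.ones((1, 2, 8))
    out = g1_update_schematic(
        g1,
        conv_term=np.ones((1, 2, 8)),
        grrg_term=np.ones((1, 2, 16)),
        g1_out_mlp=True,
        g1_out_conv=True,
        self_mlp=lambda x: x,
        linear1=lambda x: x[..., :8],
    )
    assert all(t.shape == (1, 2, 8) for t in out)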
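
At the input-script level, the three options surface in dpa2_repformer_args as optional booleans defaulting to False, so existing configs are unaffected. Mirroring the repformer kwargs used in the updated consistency test, enabling them looks roughly like this (the other required repformer keys are elided here):

    repformer_kwargs = {
        # ... rcut, dimensions, and the remaining repformer keys go here ...
        "set_davg_zero": True,
        "trainable_ln": True,
        "ln_eps": 1e-5,
        "use_sqrt_nnei": True,  # optional, default False
        "g1_out_conv": True,    # optional, default False
        "g1_out_mlp": True,     # optional, default False
    }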