Skip to content

Commit

Permalink
Update repformer_layer.py
Browse files Browse the repository at this point in the history
  • Loading branch information
iProzd committed May 9, 2024
1 parent bd25aa6 commit f17f40f
Showing 1 changed file with 38 additions and 38 deletions.
76 changes: 38 additions & 38 deletions deepmd/pt/model/descriptor/repformer_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -785,8 +785,8 @@ def _update_g1_conv(

@staticmethod
def _cal_hg(
g: torch.Tensor,
h: torch.Tensor,
g2: torch.Tensor,
h2: torch.Tensor,
nlist_mask: torch.Tensor,
sw: torch.Tensor,
smooth: bool = True,
Expand All @@ -797,9 +797,9 @@ def _cal_hg(
Parameters
----------
g
Neighbor-wise/Pair-wise invariant rep tensors, with shape nb x nloc x nnei x ng.
h
g2
Neighbor-wise/Pair-wise invariant rep tensors, with shape nb x nloc x nnei x ng2.
h2
Neighbor-wise/Pair-wise equivariant rep tensors, with shape nb x nloc x nnei x 3.
nlist_mask
Neighbor list mask, where zero means no neighbor, with shape nb x nloc x nnei.
Expand All @@ -814,61 +814,61 @@ def _cal_hg(
Returns
-------
hg
The transposed rotation matrix, with shape nb x nloc x 3 x ng.
The transposed rotation matrix, with shape nb x nloc x 3 x ng2.
"""
# g: nb x nloc x nnei x ng
# h: nb x nloc x nnei x 3
# g2: nb x nloc x nnei x ng2
# h2: nb x nloc x nnei x 3
# msk: nb x nloc x nnei
nb, nloc, nnei, _ = g.shape
ng = g.shape[-1]
# nb x nloc x nnei x ng
g = _apply_nlist_mask(g, nlist_mask)
nb, nloc, nnei, _ = g2.shape
ng2 = g2.shape[-1]
# nb x nloc x nnei x ng2
g2 = _apply_nlist_mask(g2, nlist_mask)
if not smooth:
# nb x nloc
# must use type_as here to convert bool to float, otherwise there will be numerical difference from numpy
invnnei = 1.0 / (epsilon + torch.sum(nlist_mask.type_as(g), dim=-1))
invnnei = 1.0 / (epsilon + torch.sum(nlist_mask.type_as(g2), dim=-1))
# nb x nloc x 1 x 1
invnnei = invnnei.unsqueeze(-1).unsqueeze(-1)
else:
g = _apply_switch(g, sw)
g2 = _apply_switch(g2, sw)
invnnei = (1.0 / float(nnei)) * torch.ones(
(nb, nloc, 1, 1), dtype=g.dtype, device=g.device
(nb, nloc, 1, 1), dtype=g2.dtype, device=g2.device
)
# nb x nloc x 3 x ng
hg = torch.matmul(torch.transpose(h, -1, -2), g) * invnnei
return hg
h2g2 = torch.matmul(torch.transpose(h2, -1, -2), g2) * invnnei
return h2g2

@staticmethod
def _cal_grrg(hg: torch.Tensor, axis_neuron: int) -> torch.Tensor:
def _cal_grrg(h2g2: torch.Tensor, axis_neuron: int) -> torch.Tensor:
"""
Calculate the atomic invariant rep.
Parameters
----------
hg
The transposed rotation matrix, with shape nb x nloc x 3 x ng.
h2g2
The transposed rotation matrix, with shape nb x nloc x 3 x ng2.
axis_neuron
Size of the submatrix.
Returns
-------
grrg
Atomic invariant rep, with shape nb x nloc x (axis_neuron x ng)
Atomic invariant rep, with shape nb x nloc x (axis_neuron x ng2)
"""
# nb x nloc x 3 x ng
nb, nloc, _, ng = hg.shape
# nb x nloc x 3 x ng2
nb, nloc, _, ng2 = h2g2.shape
# nb x nloc x 3 x axis
hgm = torch.split(hg, axis_neuron, dim=-1)[0]
h2g2m = torch.split(h2g2, axis_neuron, dim=-1)[0]
# nb x nloc x axis_neuron x ng
grrg = torch.matmul(torch.transpose(hgm, -1, -2), hg) / (3.0**1)
grrg = torch.matmul(torch.transpose(h2g2m, -1, -2), h2g2) / (3.0**1)
# nb x nloc x (axis_neuron x ng)
grrg = grrg.view(nb, nloc, axis_neuron * ng)
grrg = grrg.view(nb, nloc, axis_neuron * ng2)
return grrg

def symmetrization_op(
self,
g: torch.Tensor,
h: torch.Tensor,
g2: torch.Tensor,
h2: torch.Tensor,
nlist_mask: torch.Tensor,
sw: torch.Tensor,
axis_neuron: int,
Expand All @@ -880,9 +880,9 @@ def symmetrization_op(
Parameters
----------
g
Neighbor-wise/Pair-wise invariant rep tensors, with shape nb x nloc x nnei x ng.
h
g2
Neighbor-wise/Pair-wise invariant rep tensors, with shape nb x nloc x nnei x ng2.
h2
Neighbor-wise/Pair-wise equivariant rep tensors, with shape nb x nloc x nnei x 3.
nlist_mask
Neighbor list mask, where zero means no neighbor, with shape nb x nloc x nnei.
Expand All @@ -899,16 +899,16 @@ def symmetrization_op(
Returns
-------
grrg
Atomic invariant rep, with shape nb x nloc x (axis_neuron x ng)
Atomic invariant rep, with shape nb x nloc x (axis_neuron x ng2)
"""
# g: nb x nloc x nnei x ng
# h: nb x nloc x nnei x 3
# g2: nb x nloc x nnei x ng2
# h2: nb x nloc x nnei x 3
# msk: nb x nloc x nnei
nb, nloc, nnei, _ = g.shape
nb, nloc, nnei, _ = g2.shape
# nb x nloc x 3 x ng
hg = self._cal_hg(g, h, nlist_mask, sw, smooth=smooth, epsilon=epsilon)
# nb x nloc x (axis_neuron x ng)
grrg = self._cal_grrg(hg, axis_neuron)
h2g2 = self._cal_hg(g2, h2, nlist_mask, sw, smooth=smooth, epsilon=epsilon)
# nb x nloc x (axis_neuron x ng2)
grrg = self._cal_grrg(h2g2, axis_neuron)
return grrg

def _update_g2_g1g1(
Expand Down

0 comments on commit f17f40f

Please sign in to comment.