Skip to content

Commit

Permalink
Update model.py
Browse files Browse the repository at this point in the history
  • Loading branch information
yuanqidu authored Sep 24, 2023
1 parent 8952f29 commit 83da1a0
Showing 1 changed file with 11 additions and 11 deletions.
22 changes: 11 additions & 11 deletions model.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def __init__(

self.hidden_channels = hidden_channels
self.num_radial = num_radial
self.dir_proj = nn.Sequential(
self.inv_proj = nn.Sequential(
nn.Linear(3 * self.hidden_channels + self.num_radial, self.hidden_channels * 3), nn.SiLU(inplace=True),
nn.Linear(self.hidden_channels * 3, self.hidden_channels * 3), )

Expand Down Expand Up @@ -131,7 +131,7 @@ def forward(self, x, vec, edge_index, edge_rbf, weight, edge_vector):
xh = self.x_proj(x)

rbfh = self.rbf_proj(edge_rbf)
weight = self.dir_proj(weight)
weight = self.inv_proj(weight)
rbfh = rbfh * weight
# propagate_type: (xh: Tensor, vec: Tensor, rbfh_ij: Tensor, r_ij: Tensor)
dx, dvec = self.propagate(
Expand Down Expand Up @@ -177,10 +177,10 @@ def __init__(self, hidden_channels):
super().__init__()
self.hidden_channels = hidden_channels

self.vec_proj = nn.Linear(
self.equi_proj = nn.Linear(
hidden_channels, hidden_channels * 2, bias=False
)
self.xvec_proj = nn.Sequential(
self.xequi_proj = nn.Sequential(
nn.Linear(hidden_channels * 2, hidden_channels),
nn.SiLU(),
nn.Linear(hidden_channels, hidden_channels * 3),
Expand All @@ -192,15 +192,15 @@ def __init__(self, hidden_channels):
self.reset_parameters()

def reset_parameters(self):
nn.init.xavier_uniform_(self.vec_proj.weight)
nn.init.xavier_uniform_(self.xvec_proj[0].weight)
self.xvec_proj[0].bias.data.fill_(0)
nn.init.xavier_uniform_(self.xvec_proj[2].weight)
self.xvec_proj[2].bias.data.fill_(0)
nn.init.xavier_uniform_(self.equi_proj.weight)
nn.init.xavier_uniform_(self.xequi_proj[0].weight)
self.xequi_proj[0].bias.data.fill_(0)
nn.init.xavier_uniform_(self.xequi_proj[2].weight)
self.xequi_proj[2].bias.data.fill_(0)

def forward(self, x, vec, node_frame):

vec = self.vec_proj(vec)
vec = self.equi_proj(vec)
vec1,vec2 = torch.split(
vec, self.hidden_channels, dim=-1
)
Expand All @@ -212,7 +212,7 @@ def forward(self, x, vec, node_frame):
vec_dot = (vec1 * vec2).sum(dim=1)
vec_dot = vec_dot * self.inv_sqrt_h

x_vec_h = self.xvec_proj(
x_vec_h = self.xequi_proj(
torch.cat(
[x, scalar], dim=-1
)
Expand Down

0 comments on commit 83da1a0

Please sign in to comment.