Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
rusty1s committed Sep 21, 2022
1 parent d7c31b9 commit 2778efc
Showing 1 changed file with 4 additions and 6 deletions.
10 changes: 4 additions & 6 deletions torch_geometric/nn/conv/hgt_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,12 +170,11 @@ def forward(
out_dict[node_type] = None
continue

out = out + self.a_lin[node_type](x_dict[node_type])
# out = self.a_lin[node_type](F.gelu(out))
# if out.size(-1) == x_dict[node_type].size(-1):
out = self.a_lin[node_type](F.gelu(out))
if out.size(-1) == x_dict[node_type].size(-1):

# alpha = self.skip[node_type].sigmoid()
# out = alpha * out + (1 - alpha) * x_dict[node_type]
alpha = self.skip[node_type].sigmoid()
out = alpha * out + (1 - alpha) * x_dict[node_type]
out_dict[node_type] = out

return out_dict
Expand All @@ -187,7 +186,6 @@ def message(self, k_j: Tensor, q_i: Tensor, v_j: Tensor, rel: Tensor,
alpha = (q_i * k_j).sum(dim=-1) * rel
alpha = alpha / math.sqrt(q_i.size(-1))
alpha = softmax(alpha, index, ptr, size_i)
# alpha = alpha.sigmoid()
out = v_j * alpha.view(-1, self.heads, 1)
return out.view(-1, self.out_channels)

Expand Down

0 comments on commit 2778efc

Please sign in to comment.