From 2778efc6f76249c7d682648645f30ba917307343 Mon Sep 17 00:00:00 2001 From: rusty1s Date: Wed, 21 Sep 2022 14:10:19 +0000 Subject: [PATCH] update --- torch_geometric/nn/conv/hgt_conv.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/torch_geometric/nn/conv/hgt_conv.py b/torch_geometric/nn/conv/hgt_conv.py index 3d7c24e453f2..d2d156e1e2a2 100644 --- a/torch_geometric/nn/conv/hgt_conv.py +++ b/torch_geometric/nn/conv/hgt_conv.py @@ -170,12 +170,11 @@ def forward( out_dict[node_type] = None continue - out = out + self.a_lin[node_type](x_dict[node_type]) - # out = self.a_lin[node_type](F.gelu(out)) - # if out.size(-1) == x_dict[node_type].size(-1): + out = self.a_lin[node_type](F.gelu(out)) + if out.size(-1) == x_dict[node_type].size(-1): - # alpha = self.skip[node_type].sigmoid() - # out = alpha * out + (1 - alpha) * x_dict[node_type] + alpha = self.skip[node_type].sigmoid() + out = alpha * out + (1 - alpha) * x_dict[node_type] out_dict[node_type] = out return out_dict @@ -187,7 +186,6 @@ def message(self, k_j: Tensor, q_i: Tensor, v_j: Tensor, rel: Tensor, alpha = (q_i * k_j).sum(dim=-1) * rel alpha = alpha / math.sqrt(q_i.size(-1)) alpha = softmax(alpha, index, ptr, size_i) - # alpha = alpha.sigmoid() out = v_j * alpha.view(-1, self.heads, 1) return out.view(-1, self.out_channels)