diff --git a/torch_geometric/nn/conv/han_conv.py b/torch_geometric/nn/conv/han_conv.py index e1e0128c7443..d406147e89f2 100644 --- a/torch_geometric/nn/conv/han_conv.py +++ b/torch_geometric/nn/conv/han_conv.py @@ -23,7 +23,7 @@ def group( num_edge_types = len(xs) out = torch.stack(xs) if out.numel() == 0: - return out.view(0, out.size(-1)) + return out.view(0, out.size(-1)), None attn_score = (q * torch.tanh(k_lin(out)).mean(1)).sum(-1) attn = F.softmax(attn_score, dim=0) out = torch.sum(attn.view(num_edge_types, 1, -1) * out, dim=0)