
GAT model - last layer incorrect? #4298

Closed
j-adamczyk opened this issue Mar 18, 2022 · 4 comments · Fixed by #4299

@j-adamczyk
Contributor

🐛 Describe the bug

In the GAT paper, the last layer averages the attention heads instead of concatenating them, which corresponds to concat=False in GATConv in PyTorch Geometric. This is also the case in all the examples.

However, in the GAT model the last layer still concatenates, since only the convolution class is passed to BasicGNN. This means that when using GAT, we actually concatenate the attention heads in the last layer, which is incorrect, or at least unexpected behavior.
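
For reference, a minimal sketch (hypothetical sizes) contrasting the two behaviors of GATConv:

import torch
from torch_geometric.nn import GATConv

x = torch.randn(10, 16)                     # 10 nodes, 16 input features
edge_index = torch.randint(0, 10, (2, 40))  # random edges

conv_cat = GATConv(16, 8, heads=4, concat=True)   # what the GAT model currently builds in its last layer
conv_avg = GATConv(16, 8, heads=4, concat=False)  # what the paper's last layer does

print(conv_cat(x, edge_index).shape)  # torch.Size([10, 32]) -- heads concatenated
print(conv_avg(x, edge_index).shape)  # torch.Size([10, 8])  -- heads averaged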

If I'm correct, I see two ways of fixing this:

  • leave as-is, but add a disclaimer to the docs, possibly with an example of how to go from the current output of shape (num_nodes, K * out_channels) (for K attention heads) to (num_nodes, out_channels) by averaging the attention heads (see the sketch after this list)
  • override the .forward() method in GAT and add a concat_last_layer=False option:
    • if True, run as-is
    • if False, use the parent's .forward() up to layer N-1 and run the last layer separately with concat=False
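
A minimal sketch of the first option's post-processing (hypothetical sizes; out stands in for the current GAT output):

import torch

num_nodes, K, out_channels = 10, 4, 8            # hypothetical sizes
out = torch.randn(num_nodes, K * out_channels)   # stand-in for the current concatenated GAT output

# Average the K heads to recover the paper's (num_nodes, out_channels) shape:
out = out.view(num_nodes, K, out_channels).mean(dim=1)
print(out.shape)  # torch.Size([10, 8])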

Environment

  • PyG version: 2.0.4
  • PyTorch version: 1.10
  • OS: Windows 10
  • Python version: 3.9
  • CUDA/cuDNN version: 11.3
  • How you installed PyTorch and PyG (conda, pip, source): pip
  • Any other relevant information (e.g., version of torch-scatter):
j-adamczyk added the bug label Mar 18, 2022
@rusty1s
Member

rusty1s commented Mar 18, 2022

That is interesting, I think you are right that we need to fix the concat option in the last layer of GAT. In my understanding, we only need to fix this in case GAT makes use of the out_channels argument, i.e. when the last GATConv layer actually maps to out_channels. In that case, it might be easiest to just set concat=False.

@j-adamczyk
Contributor Author

I think something like this would work (modifying the existing GAT model class, which already defines init_conv):

import copy


class GAT(BasicGNN):
    def __init__(self, in_channels, hidden_channels, num_layers,
                 out_channels=None, jk=None, **kwargs):
        super().__init__(in_channels, hidden_channels, num_layers,
                         out_channels=out_channels, jk=jk, **kwargs)

        # Rebuild the last GATConv with concat=False, so that its attention
        # heads are averaged instead of concatenated:
        kwargs = copy.copy(kwargs)
        kwargs["concat"] = False
        if out_channels is not None and jk is None:
            self.convs[-1] = self.init_conv(hidden_channels, out_channels, **kwargs)
        else:
            self.convs[-1] = self.init_conv(hidden_channels, hidden_channels, **kwargs)

So we always make sure that the last GATConv layer uses concat=False. This hard-codes the parameter, but according to the paper this makes sense, and previous layers can still use either concatenation or averaging.
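
A quick construction-time check of this sketch (hypothetical sizes), assuming the patched class above stands in for the library's GAT model, which provides init_conv:

model = GAT(16, 64, num_layers=3, out_channels=7, heads=4)
assert model.convs[-1].concat is False  # last layer now averages its heads
assert model.convs[0].concat is True    # earlier layers keep concatenation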

If this seems correct, I can make a PR.

@rusty1s
Member

rusty1s commented Mar 18, 2022

Yeah, that is correct. I was too fast and went ahead and already fixed it, see #4299. I'm sorry :(

@j-adamczyk
Contributor Author

Sure, thanks for fixing this!
