From 78bba363c806336368b0027ed2a1da40b10e1ea9 Mon Sep 17 00:00:00 2001 From: Matthias Fey Date: Mon, 2 Jan 2023 13:31:11 +0100 Subject: [PATCH] `GPSConv` Graph Transformer Layer (#6326) --- CHANGELOG.md | 1 + README.md | 1 + test/nn/conv/test_gps_conv.py | 28 +++++ test/nn/conv/test_message_passing.py | 10 +- torch_geometric/nn/conv/__init__.py | 2 + torch_geometric/nn/conv/gps_conv.py | 162 +++++++++++++++++++++++++++ 6 files changed, 199 insertions(+), 5 deletions(-) create mode 100644 test/nn/conv/test_gps_conv.py create mode 100644 torch_geometric/nn/conv/gps_conv.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 9aec9f6b5834..ded78ca274bb 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ## [2.3.0] - 2023-MM-DD ### Added +- Added the `GPSConv` Graph Transformer layer ([#6326](https://github.com/pyg-team/pytorch_geometric/pull/6326)) - Added `networkit` conversion utilities ([#6321](https://github.com/pyg-team/pytorch_geometric/pull/6321)) - Added global dataset attribute access via `dataset.{attr_name}` ([#6319](https://github.com/pyg-team/pytorch_geometric/pull/6319)) - Added the `TransE` KGE model and example ([#6314](https://github.com/pyg-team/pytorch_geometric/pull/6314)) diff --git a/README.md b/README.md index c6a1e78986ee..b0468ac23c56 100644 --- a/README.md +++ b/README.md @@ -240,6 +240,7 @@ These GNN layers can be stacked together to create Graph Neural Network models. * A **[MetaLayer](https://pytorch-geometric.readthedocs.io/en/latest/modules/nn.html#torch_geometric.nn.meta.MetaLayer)** for building any kind of graph network similar to the [TensorFlow Graph Nets library](https://github.com/deepmind/graph_nets) from Battaglia *et al.*: [Relational Inductive Biases, Deep Learning, and Graph Networks](https://arxiv.org/abs/1806.01261) (CoRR 2018) * **[SSGConv](https://pytorch-geometric.readthedocs.io/en/latest/modules/nn.html#torch_geometric.nn.conv.SSGConv)** from Zhu *et al.*: [Simple Spectral Graph Convolution](https://openreview.net/forum?id=CYO5T-YjWZV) (ICLR 2021) * **[FusedGATConv](https://pytorch-geometric.readthedocs.io/en/latest/modules/nn.html#torch_geometric.nn.conv.FusedGATConv)** from Zhang *et al.*: [Understanding GNN Computational Graph: A Coordinated Computation, IO, and Memory Perspective](https://proceedings.mlsys.org/paper/2022/file/9a1158154dfa42caddbd0694a4e9bdc8-Paper.pdf) (MLSys 2022) +* **[GPSConv](https://pytorch-geometric.readthedocs.io/en/latest/modules/nn.html#torch_geometric.nn.conv.GPSConv)** from Rampášek *et al.*: [Recipe for a General, Powerful, Scalable Graph Transformer](https://arxiv.org/abs/2205.12454) (NeurIPS 2022) [[**Example**](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/graph_gps.py)] **Pooling layers:** diff --git a/test/nn/conv/test_gps_conv.py b/test/nn/conv/test_gps_conv.py new file mode 100644 index 000000000000..381d28067b5d --- /dev/null +++ b/test/nn/conv/test_gps_conv.py @@ -0,0 +1,28 @@ +import pytest +import torch +from torch_sparse import SparseTensor + +from torch_geometric.nn import GPSConv, SAGEConv + + +@pytest.mark.parametrize('norm', [None, 'batch_norm', 'layer_norm']) +def test_gps_conv(norm): + x = torch.randn(4, 16) + edge_index = torch.tensor([[0, 1, 1, 0, 2, 3, 2, 3], + [1, 0, 0, 1, 3, 2, 3, 2]]) + row, col = edge_index + adj_t = SparseTensor(row=col, col=row, sparse_sizes=(4, 4)) + batch = torch.tensor([0, 0, 1, 1]) + + conv = GPSConv(16, conv=SAGEConv(16, 16), heads=4, norm=norm) + 
conv.reset_parameters() + assert str(conv) == ('GPSConv(16, conv=SAGEConv(16, 16, aggr=mean), ' + 'heads=4)') + + out = conv(x, edge_index) + assert out.size() == (4, 16) + assert torch.allclose(conv(x, adj_t), out) + + out = conv(x, edge_index, batch) + assert out.size() == (4, 16) + assert torch.allclose(conv(x, adj_t, batch), out) diff --git a/test/nn/conv/test_message_passing.py b/test/nn/conv/test_message_passing.py index 96d85e39024b..ca30190664ab 100644 --- a/test/nn/conv/test_message_passing.py +++ b/test/nn/conv/test_message_passing.py @@ -62,10 +62,10 @@ def test_my_conv(): assert out.size() == (4, 32) assert torch.allclose(conv(x1, edge_index, value, (4, 4)), out) assert torch.allclose(conv(x1, adj.t()), out) - assert torch.allclose(conv(x1, torch_adj.t()), out) + assert torch.allclose(conv(x1, torch_adj.t()), out, atol=1e-6) conv.fuse = False assert torch.allclose(conv(x1, adj.t()), out) - assert torch.allclose(conv(x1, torch_adj.t()), out) + assert torch.allclose(conv(x1, torch_adj.t()), out, atol=1e-6) conv.fuse = True adj = adj.sparse_resize((4, 2)) @@ -78,14 +78,14 @@ def test_my_conv(): assert out2.size() == (2, 32) assert torch.allclose(conv((x1, x2), edge_index, value, (4, 2)), out1) assert torch.allclose(conv((x1, x2), adj.t()), out1) - assert torch.allclose(conv((x1, x2), torch_adj.t()), out1) + assert torch.allclose(conv((x1, x2), torch_adj.t()), out1, atol=1e-6) assert torch.allclose(conv((x1, None), adj.t()), out2) assert torch.allclose(conv((x1, None), torch_adj.t()), out2, atol=1e-6) conv.fuse = False assert torch.allclose(conv((x1, x2), adj.t()), out1) - assert torch.allclose(conv((x1, x2), torch_adj.t()), out1) + assert torch.allclose(conv((x1, x2), torch_adj.t()), out1, atol=1e-6) assert torch.allclose(conv((x1, None), adj.t()), out2) - assert torch.allclose(conv((x1, None), torch_adj.t()), out2) + assert torch.allclose(conv((x1, None), torch_adj.t()), out2, atol=1e-6) conv.fuse = True # Test backward compatibility for `torch.sparse` tensors: diff --git a/torch_geometric/nn/conv/__init__.py b/torch_geometric/nn/conv/__init__.py index 1982eb730bcc..3a0548d03516 100644 --- a/torch_geometric/nn/conv/__init__.py +++ b/torch_geometric/nn/conv/__init__.py @@ -53,6 +53,7 @@ from .lg_conv import LGConv from .ssg_conv import SSGConv from .point_gnn_conv import PointGNNConv +from .gps_conv import GPSConv __all__ = [ 'MessagePassing', @@ -115,6 +116,7 @@ 'HANConv', 'LGConv', 'PointGNNConv', + 'GPSConv', ] classes = __all__ diff --git a/torch_geometric/nn/conv/gps_conv.py b/torch_geometric/nn/conv/gps_conv.py new file mode 100644 index 000000000000..d83c39931675 --- /dev/null +++ b/torch_geometric/nn/conv/gps_conv.py @@ -0,0 +1,162 @@ +import inspect +from typing import Any, Dict, Optional + +import torch +import torch.nn.functional as F +from torch import Tensor +from torch.nn import Dropout, Linear, Sequential + +from torch_geometric.nn.conv import MessagePassing +from torch_geometric.nn.resolver import ( + activation_resolver, + normalization_resolver, +) +from torch_geometric.typing import Adj +from torch_geometric.utils import to_dense_batch + +from ..inits import reset + + +class GPSConv(torch.nn.Module): + r"""The general, powerful, scalable (GPS) graph transformer layer from the + `"Recipe for a General, Powerful, Scalable Graph Transformer" + `_ paper. + + The GPS layer is based on a 3-part recipe: + + 1. Inclusion of positional (PE) and structural encodings (SE) to the input + features (done in a pre-processing step via + :class:`torch_geometric.transforms`). 
+ 2. A local message passing layer (MPNN) that operates on the input graph. + 3. A global attention layer that operates on the entire graph. + + .. note:: + + For an example of using :class:`GPSConv`, see + `examples/graph_gps.py + `_. + + Args: + channels (int): Size of each input sample. + conv (MessagePassing, optional): The local message passing layer. + heads (int, optional): Number of multi-head-attentions. + (default: :obj:`1`) + dropout (float, optional): Dropout probability of intermediate + embeddings. (default: :obj:`0.`) + attn_dropout (float, optional): Dropout probability of the normalized + attention coefficients. (default: :obj:`0`) + act (str or Callable, optional): The non-linear activation function to + use. (default: :obj:`"relu"`) + act_kwargs (Dict[str, Any], optional): Arguments passed to the + respective activation function defined by :obj:`act`. + (default: :obj:`None`) + norm (str or Callable, optional): The normalization function to + use. (default: :obj:`"batch_norm"`) + norm_kwargs (Dict[str, Any], optional): Arguments passed to the + respective normalization function defined by :obj:`norm`. + (default: :obj:`None`) + """ + def __init__( + self, + channels: int, + conv: Optional[MessagePassing], + heads: int = 1, + dropout: float = 0.0, + attn_dropout: float = 0.0, + act: str = 'relu', + act_kwargs: Optional[Dict[str, Any]] = None, + norm: Optional[str] = 'batch_norm', + norm_kwargs: Optional[Dict[str, Any]] = None, + ): + super().__init__() + + self.channels = channels + self.conv = conv + self.heads = heads + self.dropout = dropout + + self.attn = torch.nn.MultiheadAttention( + channels, + heads, + dropout=attn_dropout, + batch_first=True, + ) + + self.mlp = Sequential( + Linear(channels, channels * 2), + activation_resolver(act, **(act_kwargs or {})), + Dropout(dropout), + Linear(channels * 2, channels), + Dropout(dropout), + ) + + norm_kwargs = norm_kwargs or {} + self.norm1 = normalization_resolver(norm, channels, **norm_kwargs) + self.norm2 = normalization_resolver(norm, channels, **norm_kwargs) + self.norm3 = normalization_resolver(norm, channels, **norm_kwargs) + + self.norm_with_batch = False + if self.norm1 is not None: + signature = inspect.signature(self.norm1.forward) + self.norm_with_batch = 'batch' in signature.parameters + + def reset_parameters(self): + if self.conv is not None: + self.conv.reset_parameters() + self.attn._reset_parameters() + reset(self.mlp) + if self.norm1 is not None: + self.norm1.reset_parameters() + if self.norm2 is not None: + self.norm2.reset_parameters() + if self.norm3 is not None: + self.norm3.reset_parameters() + + def forward( + self, + x: Tensor, + edge_index: Adj, + batch: Optional[torch.Tensor] = None, + **kwargs, + ) -> Tensor: + """""" + hs = [] + if self.conv is not None: # Local MPNN. + h = self.conv(x, edge_index, **kwargs) + h = F.dropout(h, p=self.dropout, training=self.training) + h = h + x + if self.norm1 is not None: + if self.norm_with_batch: + h = self.norm1(h, batch=batch) + else: + h = self.norm1(h) + hs.append(h) + + # Global attention transformer-style model. + h, mask = to_dense_batch(x, batch) + h, _ = self.attn(h, h, h, key_padding_mask=~mask, need_weights=False) + h = h[mask] + h = F.dropout(h, p=self.dropout, training=self.training) + h = h + x # Residual connection. + if self.norm2 is not None: + if self.norm_with_batch: + h = self.norm2(h, batch=batch) + else: + h = self.norm2(h) + hs.append(h) + + out = sum(hs) # Combine local and global outputs. 
+
+        out = out + self.mlp(out)
+        if self.norm3 is not None:
+            if self.norm_with_batch:
+                out = self.norm3(out, batch=batch)
+            else:
+                out = self.norm3(out)
+
+        return out
+
+    def __repr__(self) -> str:
+        return (f'{self.__class__.__name__}({self.channels}, '
+                f'conv={self.conv}, heads={self.heads})')
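
For context, a minimal usage sketch of the new layer, in the spirit of the `examples/graph_gps.py` example referenced in the README hunk above. The model class, hyperparameters, and toy inputs below are illustrative assumptions and are not part of this patch; in practice, the positional/structural encodings (part 1 of the GPS recipe) would be added to `x` by a pre-transform before the layers are applied.

    import torch

    from torch_geometric.nn import GPSConv, SAGEConv


    class ToyGPS(torch.nn.Module):
        # Hypothetical model: a small stack of GPSConv blocks plus a linear head.
        def __init__(self, channels: int = 16, num_layers: int = 2, heads: int = 4):
            super().__init__()
            # `heads` must evenly divide `channels`, since GPSConv wraps
            # `torch.nn.MultiheadAttention` for its global attention branch.
            self.convs = torch.nn.ModuleList([
                GPSConv(channels, conv=SAGEConv(channels, channels), heads=heads)
                for _ in range(num_layers)
            ])
            self.head = torch.nn.Linear(channels, 1)

        def forward(self, x, edge_index, batch=None):
            # `x` is assumed to already contain positional/structural encodings.
            for conv in self.convs:
                x = conv(x, edge_index, batch)
            return self.head(x)


    model = ToyGPS()
    x = torch.randn(4, 16)
    edge_index = torch.tensor([[0, 1, 2, 3],
                               [1, 0, 3, 2]])
    batch = torch.tensor([0, 0, 1, 1])
    out = model(x, edge_index, batch)  # Shape: [4, 1]

Note that `conv=None` is also accepted, in which case only the global attention branch (plus the MLP) is applied, as handled by the `if self.conv is not None` check in `forward`.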