GPSConv Graph Transformer Layer (#6326)
rusty1s authored Jan 2, 2023
1 parent abc3ad2 commit 78bba36
Showing 6 changed files with 199 additions and 5 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -5,6 +5,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

## [2.3.0] - 2023-MM-DD
### Added
- Added the `GPSConv` Graph Transformer layer ([#6326](https://github.com/pyg-team/pytorch_geometric/pull/6326))
- Added `networkit` conversion utilities ([#6321](https://github.com/pyg-team/pytorch_geometric/pull/6321))
- Added global dataset attribute access via `dataset.{attr_name}` ([#6319](https://github.com/pyg-team/pytorch_geometric/pull/6319))
- Added the `TransE` KGE model and example ([#6314](https://github.com/pyg-team/pytorch_geometric/pull/6314))
1 change: 1 addition & 0 deletions README.md
@@ -240,6 +240,7 @@ These GNN layers can be stacked together to create Graph Neural Network models.
* A **[MetaLayer](https://pytorch-geometric.readthedocs.io/en/latest/modules/nn.html#torch_geometric.nn.meta.MetaLayer)** for building any kind of graph network similar to the [TensorFlow Graph Nets library](https://github.com/deepmind/graph_nets) from Battaglia *et al.*: [Relational Inductive Biases, Deep Learning, and Graph Networks](https://arxiv.org/abs/1806.01261) (CoRR 2018)
* **[SSGConv](https://pytorch-geometric.readthedocs.io/en/latest/modules/nn.html#torch_geometric.nn.conv.SSGConv)** from Zhu *et al.*: [Simple Spectral Graph Convolution](https://openreview.net/forum?id=CYO5T-YjWZV) (ICLR 2021)
* **[FusedGATConv](https://pytorch-geometric.readthedocs.io/en/latest/modules/nn.html#torch_geometric.nn.conv.FusedGATConv)** from Zhang *et al.*: [Understanding GNN Computational Graph: A Coordinated Computation, IO, and Memory Perspective](https://proceedings.mlsys.org/paper/2022/file/9a1158154dfa42caddbd0694a4e9bdc8-Paper.pdf) (MLSys 2022)
* **[GPSConv](https://pytorch-geometric.readthedocs.io/en/latest/modules/nn.html#torch_geometric.nn.conv.GPSConv)** from Rampášek *et al.*: [Recipe for a General, Powerful, Scalable Graph Transformer](https://arxiv.org/abs/2205.12454) (NeurIPS 2022) [[**Example**](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/graph_gps.py)]
</details>

**Pooling layers:**
28 changes: 28 additions & 0 deletions test/nn/conv/test_gps_conv.py
@@ -0,0 +1,28 @@
import pytest
import torch
from torch_sparse import SparseTensor

from torch_geometric.nn import GPSConv, SAGEConv


@pytest.mark.parametrize('norm', [None, 'batch_norm', 'layer_norm'])
def test_gps_conv(norm):
    x = torch.randn(4, 16)
    edge_index = torch.tensor([[0, 1, 1, 0, 2, 3, 2, 3],
                               [1, 0, 0, 1, 3, 2, 3, 2]])
    row, col = edge_index
    adj_t = SparseTensor(row=col, col=row, sparse_sizes=(4, 4))
    batch = torch.tensor([0, 0, 1, 1])

    conv = GPSConv(16, conv=SAGEConv(16, 16), heads=4, norm=norm)
    conv.reset_parameters()
    assert str(conv) == ('GPSConv(16, conv=SAGEConv(16, 16, aggr=mean), '
                         'heads=4)')

    out = conv(x, edge_index)
    assert out.size() == (4, 16)
    assert torch.allclose(conv(x, adj_t), out)

    out = conv(x, edge_index, batch)
    assert out.size() == (4, 16)
    assert torch.allclose(conv(x, adj_t, batch), out)
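
For context (not part of this commit): the `batch` vector in the test above assigns each node to a graph. `GPSConv` relies on `to_dense_batch` from `torch_geometric.utils` to pad each graph's node features into a dense tensor plus a boolean mask, which it inverts and passes as `key_padding_mask` so global attention never crosses graph boundaries. A minimal sketch with illustrative values:

import torch
from torch_geometric.utils import to_dense_batch

x = torch.randn(4, 16)              # 4 nodes, 16 features (as in the test above)
batch = torch.tensor([0, 0, 1, 1])  # two graphs with two nodes each

# Pads per-graph node features into shape [num_graphs, max_nodes, channels]
# and returns a boolean mask marking the real (non-padded) entries.
h, mask = to_dense_batch(x, batch)
print(h.shape)  # torch.Size([2, 2, 16])
print(mask)     # tensor([[True, True], [True, True]])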
10 changes: 5 additions & 5 deletions test/nn/conv/test_message_passing.py
@@ -62,10 +62,10 @@ def test_my_conv():
     assert out.size() == (4, 32)
     assert torch.allclose(conv(x1, edge_index, value, (4, 4)), out)
     assert torch.allclose(conv(x1, adj.t()), out)
-    assert torch.allclose(conv(x1, torch_adj.t()), out)
+    assert torch.allclose(conv(x1, torch_adj.t()), out, atol=1e-6)
     conv.fuse = False
     assert torch.allclose(conv(x1, adj.t()), out)
-    assert torch.allclose(conv(x1, torch_adj.t()), out)
+    assert torch.allclose(conv(x1, torch_adj.t()), out, atol=1e-6)
     conv.fuse = True

     adj = adj.sparse_resize((4, 2))
@@ -78,14 +78,14 @@ def test_my_conv():
     assert out2.size() == (2, 32)
     assert torch.allclose(conv((x1, x2), edge_index, value, (4, 2)), out1)
     assert torch.allclose(conv((x1, x2), adj.t()), out1)
-    assert torch.allclose(conv((x1, x2), torch_adj.t()), out1)
+    assert torch.allclose(conv((x1, x2), torch_adj.t()), out1, atol=1e-6)
     assert torch.allclose(conv((x1, None), adj.t()), out2)
-    assert torch.allclose(conv((x1, None), torch_adj.t()), out2)
+    assert torch.allclose(conv((x1, None), torch_adj.t()), out2, atol=1e-6)
     conv.fuse = False
     assert torch.allclose(conv((x1, x2), adj.t()), out1)
-    assert torch.allclose(conv((x1, x2), torch_adj.t()), out1)
+    assert torch.allclose(conv((x1, x2), torch_adj.t()), out1, atol=1e-6)
     assert torch.allclose(conv((x1, None), adj.t()), out2)
-    assert torch.allclose(conv((x1, None), torch_adj.t()), out2)
+    assert torch.allclose(conv((x1, None), torch_adj.t()), out2, atol=1e-6)
     conv.fuse = True

     # Test backward compatibility for `torch.sparse` tensors:
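
A brief aside on the loosened tolerances above (the cause is an assumption on my part, not stated in the commit): the `torch.sparse` adjacency path aggregates messages via a sparse matmul while the `edge_index` path uses scatter-based aggregation, and the two accumulate floating-point sums in different orders, so their results can differ by a few multiples of machine epsilon. A tiny sketch of the underlying effect:

import torch

# Floating-point addition is not associative: summing the same values in a
# different order gives slightly different results, hence the small atol when
# comparing scatter-based aggregation against a sparse matmul.
vals = torch.randn(10000)
a = vals.sum()
b = vals.flip(0).sum()
print((a - b).abs().item())  # tiny, but usually not exactly zero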
2 changes: 2 additions & 0 deletions torch_geometric/nn/conv/__init__.py
@@ -53,6 +53,7 @@
from .lg_conv import LGConv
from .ssg_conv import SSGConv
from .point_gnn_conv import PointGNNConv
from .gps_conv import GPSConv

__all__ = [
    'MessagePassing',
@@ -115,6 +116,7 @@
    'HANConv',
    'LGConv',
    'PointGNNConv',
    'GPSConv',
]

classes = __all__
162 changes: 162 additions & 0 deletions torch_geometric/nn/conv/gps_conv.py
@@ -0,0 +1,162 @@
import inspect
from typing import Any, Dict, Optional

import torch
import torch.nn.functional as F
from torch import Tensor
from torch.nn import Dropout, Linear, Sequential

from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn.resolver import (
    activation_resolver,
    normalization_resolver,
)
from torch_geometric.typing import Adj
from torch_geometric.utils import to_dense_batch

from ..inits import reset


class GPSConv(torch.nn.Module):
    r"""The general, powerful, scalable (GPS) graph transformer layer from the
    `"Recipe for a General, Powerful, Scalable Graph Transformer"
    <https://arxiv.org/abs/2205.12454>`_ paper.

    The GPS layer is based on a 3-part recipe:

    1. Inclusion of positional (PE) and structural encodings (SE) to the input
       features (done in a pre-processing step via
       :class:`torch_geometric.transforms`).
    2. A local message passing layer (MPNN) that operates on the input graph.
    3. A global attention layer that operates on the entire graph.

    .. note::

        For an example of using :class:`GPSConv`, see
        `examples/graph_gps.py
        <https://github.com/pyg-team/pytorch_geometric/blob/master/examples/
        graph_gps.py>`_.

    Args:
        channels (int): Size of each input sample.
        conv (MessagePassing, optional): The local message passing layer.
        heads (int, optional): Number of multi-head-attentions.
            (default: :obj:`1`)
        dropout (float, optional): Dropout probability of intermediate
            embeddings. (default: :obj:`0.`)
        attn_dropout (float, optional): Dropout probability of the normalized
            attention coefficients. (default: :obj:`0`)
        act (str or Callable, optional): The non-linear activation function to
            use. (default: :obj:`"relu"`)
        act_kwargs (Dict[str, Any], optional): Arguments passed to the
            respective activation function defined by :obj:`act`.
            (default: :obj:`None`)
        norm (str or Callable, optional): The normalization function to
            use. (default: :obj:`"batch_norm"`)
        norm_kwargs (Dict[str, Any], optional): Arguments passed to the
            respective normalization function defined by :obj:`norm`.
            (default: :obj:`None`)
    """
    def __init__(
        self,
        channels: int,
        conv: Optional[MessagePassing],
        heads: int = 1,
        dropout: float = 0.0,
        attn_dropout: float = 0.0,
        act: str = 'relu',
        act_kwargs: Optional[Dict[str, Any]] = None,
        norm: Optional[str] = 'batch_norm',
        norm_kwargs: Optional[Dict[str, Any]] = None,
    ):
        super().__init__()

        self.channels = channels
        self.conv = conv
        self.heads = heads
        self.dropout = dropout

        self.attn = torch.nn.MultiheadAttention(
            channels,
            heads,
            dropout=attn_dropout,
            batch_first=True,
        )

        self.mlp = Sequential(
            Linear(channels, channels * 2),
            activation_resolver(act, **(act_kwargs or {})),
            Dropout(dropout),
            Linear(channels * 2, channels),
            Dropout(dropout),
        )

        norm_kwargs = norm_kwargs or {}
        self.norm1 = normalization_resolver(norm, channels, **norm_kwargs)
        self.norm2 = normalization_resolver(norm, channels, **norm_kwargs)
        self.norm3 = normalization_resolver(norm, channels, **norm_kwargs)

        self.norm_with_batch = False
        if self.norm1 is not None:
            signature = inspect.signature(self.norm1.forward)
            self.norm_with_batch = 'batch' in signature.parameters

    def reset_parameters(self):
        if self.conv is not None:
            self.conv.reset_parameters()
        self.attn._reset_parameters()
        reset(self.mlp)
        if self.norm1 is not None:
            self.norm1.reset_parameters()
        if self.norm2 is not None:
            self.norm2.reset_parameters()
        if self.norm3 is not None:
            self.norm3.reset_parameters()

    def forward(
        self,
        x: Tensor,
        edge_index: Adj,
        batch: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Tensor:
        """"""
        hs = []
        if self.conv is not None:  # Local MPNN.
            h = self.conv(x, edge_index, **kwargs)
            h = F.dropout(h, p=self.dropout, training=self.training)
            h = h + x
            if self.norm1 is not None:
                if self.norm_with_batch:
                    h = self.norm1(h, batch=batch)
                else:
                    h = self.norm1(h)
            hs.append(h)

        # Global attention transformer-style model.
        h, mask = to_dense_batch(x, batch)
        h, _ = self.attn(h, h, h, key_padding_mask=~mask, need_weights=False)
        h = h[mask]
        h = F.dropout(h, p=self.dropout, training=self.training)
        h = h + x  # Residual connection.
        if self.norm2 is not None:
            if self.norm_with_batch:
                h = self.norm2(h, batch=batch)
            else:
                h = self.norm2(h)
        hs.append(h)

        out = sum(hs)  # Combine local and global outputs.

        out = out + self.mlp(out)
        if self.norm3 is not None:  # Normalize the combined output, not the stale `h`.
            if self.norm_with_batch:
                out = self.norm3(out, batch=batch)
            else:
                out = self.norm3(out)

        return out

    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}({self.channels}, '
                f'conv={self.conv}, heads={self.heads})')
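
As a quick, hypothetical usage sketch of the layer defined above (not part of the commit; the choice of `SAGEConv` as the local MPNN, the channel width, and the number of layers are illustrative only — the complete, official example is `examples/graph_gps.py`):

import torch
from torch_geometric.nn import GPSConv, SAGEConv

# Toy graph: 4 nodes with 16-dimensional features.
x = torch.randn(4, 16)
edge_index = torch.tensor([[0, 1, 1, 2, 2, 3],
                           [1, 0, 2, 1, 3, 2]])

# Stack two GPS layers; each runs a local SAGEConv plus global self-attention.
layers = torch.nn.ModuleList(
    [GPSConv(16, conv=SAGEConv(16, 16), heads=4) for _ in range(2)])

for layer in layers:
    x = layer(x, edge_index)

print(x.shape)  # torch.Size([4, 16])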
