GPSConv: Graph Transformer Layer (#6326)
Showing 6 changed files with 199 additions and 5 deletions.
@@ -0,0 +1,28 @@
import pytest
import torch
from torch_sparse import SparseTensor

from torch_geometric.nn import GPSConv, SAGEConv


@pytest.mark.parametrize('norm', [None, 'batch_norm', 'layer_norm'])
def test_gps_conv(norm):
    x = torch.randn(4, 16)
    edge_index = torch.tensor([[0, 1, 1, 0, 2, 3, 2, 3],
                               [1, 0, 0, 1, 3, 2, 3, 2]])
    row, col = edge_index
    adj_t = SparseTensor(row=col, col=row, sparse_sizes=(4, 4))
    batch = torch.tensor([0, 0, 1, 1])

    conv = GPSConv(16, conv=SAGEConv(16, 16), heads=4, norm=norm)
    conv.reset_parameters()
    assert str(conv) == ('GPSConv(16, conv=SAGEConv(16, 16, aggr=mean), '
                         'heads=4)')

    out = conv(x, edge_index)
    assert out.size() == (4, 16)
    assert torch.allclose(conv(x, adj_t), out)

    out = conv(x, edge_index, batch)
    assert out.size() == (4, 16)
    assert torch.allclose(conv(x, adj_t, batch), out)
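The last two assertions exercise the optional `batch` vector: before running global attention, `GPSConv` densifies node features per graph, so nodes of different graphs never attend to one another. A minimal sketch of that densification step, using only `torch_geometric.utils.to_dense_batch` (context for the test above, not part of this commit):

import torch
from torch_geometric.utils import to_dense_batch

x = torch.randn(4, 16)
batch = torch.tensor([0, 0, 1, 1])  # two graphs with two nodes each

# Dense shape: [num_graphs, max_nodes_per_graph, channels], plus a padding mask.
h, mask = to_dense_batch(x, batch)
assert h.size() == (2, 2, 16)
assert mask.all()  # both graphs have exactly two nodes, so nothing is padded

# GPSConv passes `~mask` as `key_padding_mask` to MultiheadAttention,
# restricting attention to the real nodes within each graph.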
@@ -0,0 +1,162 @@
import inspect
from typing import Any, Dict, Optional

import torch
import torch.nn.functional as F
from torch import Tensor
from torch.nn import Dropout, Linear, Sequential

from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn.resolver import (
    activation_resolver,
    normalization_resolver,
)
from torch_geometric.typing import Adj
from torch_geometric.utils import to_dense_batch

from ..inits import reset


class GPSConv(torch.nn.Module):
    r"""The general, powerful, scalable (GPS) graph transformer layer from the
    `"Recipe for a General, Powerful, Scalable Graph Transformer"
    <https://arxiv.org/abs/2205.12454>`_ paper.

    The GPS layer is based on a 3-part recipe:

    1. Inclusion of positional encodings (PE) and structural encodings (SE)
       into the input features (done in a pre-processing step via
       :class:`torch_geometric.transforms`).
    2. A local message passing layer (MPNN) that operates on the input graph.
    3. A global attention layer that operates on the entire graph.

    .. note::

        For an example of using :class:`GPSConv`, see
        `examples/graph_gps.py
        <https://github.com/pyg-team/pytorch_geometric/blob/master/examples/
        graph_gps.py>`_.

    Args:
        channels (int): Size of each input sample.
        conv (MessagePassing, optional): The local message passing layer.
        heads (int, optional): Number of attention heads.
            (default: :obj:`1`)
        dropout (float, optional): Dropout probability of intermediate
            embeddings. (default: :obj:`0.`)
        attn_dropout (float, optional): Dropout probability of the normalized
            attention coefficients. (default: :obj:`0.`)
        act (str or Callable, optional): The non-linear activation function to
            use. (default: :obj:`"relu"`)
        act_kwargs (Dict[str, Any], optional): Arguments passed to the
            respective activation function defined by :obj:`act`.
            (default: :obj:`None`)
        norm (str or Callable, optional): The normalization function to
            use. (default: :obj:`"batch_norm"`)
        norm_kwargs (Dict[str, Any], optional): Arguments passed to the
            respective normalization function defined by :obj:`norm`.
            (default: :obj:`None`)
    """
    def __init__(
        self,
        channels: int,
        conv: Optional[MessagePassing],
        heads: int = 1,
        dropout: float = 0.0,
        attn_dropout: float = 0.0,
        act: str = 'relu',
        act_kwargs: Optional[Dict[str, Any]] = None,
        norm: Optional[str] = 'batch_norm',
        norm_kwargs: Optional[Dict[str, Any]] = None,
    ):
        super().__init__()

        self.channels = channels
        self.conv = conv
        self.heads = heads
        self.dropout = dropout

        self.attn = torch.nn.MultiheadAttention(
            channels,
            heads,
            dropout=attn_dropout,
            batch_first=True,
        )

        self.mlp = Sequential(
            Linear(channels, channels * 2),
            activation_resolver(act, **(act_kwargs or {})),
            Dropout(dropout),
            Linear(channels * 2, channels),
            Dropout(dropout),
        )

        norm_kwargs = norm_kwargs or {}
        self.norm1 = normalization_resolver(norm, channels, **norm_kwargs)
        self.norm2 = normalization_resolver(norm, channels, **norm_kwargs)
        self.norm3 = normalization_resolver(norm, channels, **norm_kwargs)

        # Some normalization layers (e.g., GraphNorm) accept a `batch` vector:
        self.norm_with_batch = False
        if self.norm1 is not None:
            signature = inspect.signature(self.norm1.forward)
            self.norm_with_batch = 'batch' in signature.parameters

    def reset_parameters(self):
        if self.conv is not None:
            self.conv.reset_parameters()
        self.attn._reset_parameters()
        reset(self.mlp)
        if self.norm1 is not None:
            self.norm1.reset_parameters()
        if self.norm2 is not None:
            self.norm2.reset_parameters()
        if self.norm3 is not None:
            self.norm3.reset_parameters()

    def forward(
        self,
        x: Tensor,
        edge_index: Adj,
        batch: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Tensor:
        """"""
        hs = []
        if self.conv is not None:  # Local MPNN.
            h = self.conv(x, edge_index, **kwargs)
            h = F.dropout(h, p=self.dropout, training=self.training)
            h = h + x  # Residual connection.
            if self.norm1 is not None:
                if self.norm_with_batch:
                    h = self.norm1(h, batch=batch)
                else:
                    h = self.norm1(h)
            hs.append(h)

        # Global attention transformer-style model.
        h, mask = to_dense_batch(x, batch)
        h, _ = self.attn(h, h, h, key_padding_mask=~mask, need_weights=False)
        h = h[mask]
        h = F.dropout(h, p=self.dropout, training=self.training)
        h = h + x  # Residual connection.
        if self.norm2 is not None:
            if self.norm_with_batch:
                h = self.norm2(h, batch=batch)
            else:
                h = self.norm2(h)
        hs.append(h)

        out = sum(hs)  # Combine local and global outputs.

        out = out + self.mlp(out)
        if self.norm3 is not None:
            if self.norm_with_batch:
                out = self.norm3(out, batch=batch)
            else:
                out = self.norm3(out)

        return out

    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}({self.channels}, '
                f'conv={self.conv}, heads={self.heads})')
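To tie the docstring's recipe together, here is a minimal usage sketch (hypothetical, not part of this commit) that stacks two `GPSConv` layers over random node features; the dimensions mirror the test above and the default `norm='batch_norm'` is kept:

import torch
from torch_geometric.nn import GPSConv, SAGEConv

# Hypothetical example: two GPS layers over a 4-node, two-graph batch.
x = torch.randn(4, 16)
edge_index = torch.tensor([[0, 1, 2, 3],
                           [1, 0, 3, 2]])
batch = torch.tensor([0, 0, 1, 1])

layers = torch.nn.ModuleList([
    GPSConv(16, conv=SAGEConv(16, 16), heads=4),
    GPSConv(16, conv=SAGEConv(16, 16), heads=4),
])

for layer in layers:
    # Each layer runs the local MPNN, the global attention block, and the MLP.
    x = layer(x, edge_index, batch)

assert x.size() == (4, 16)

In a full pipeline, positional or structural encodings from `torch_geometric.transforms` (step 1 of the recipe) would be added to the node features beforehand, as in the referenced `examples/graph_gps.py`.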