diff --git a/CHANGELOG.md b/CHANGELOG.md
index 10615939597c..1c3142d2b183 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -44,6 +44,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added support for dropping nodes in `utils.to_dense_batch` in case `max_num_nodes` is smaller than the number of nodes ([#6124](https://github.com/pyg-team/pytorch_geometric/pull/6124))
 - Added the RandLA-Net architecture as an example ([#5117](https://github.com/pyg-team/pytorch_geometric/pull/5117))
 ### Changed
+- Breaking Change: Changed the interface and implementation of `GraphMultisetTransformer` ([#6343](https://github.com/pyg-team/pytorch_geometric/pull/6343))
 - Fixed the approximate PPR variant in `transforms.GDC` to not crash on graphs with isolated nodes ([#6242](https://github.com/pyg-team/pytorch_geometric/pull/6242))
 - Added a warning when accesing `InMemoryDataset.data` ([#6318](https://github.com/pyg-team/pytorch_geometric/pull/6318))
 - Drop `SparseTensor` dependency in `GraphStore` ([#5517](https://github.com/pyg-team/pytorch_geometric/pull/5517))
diff --git a/examples/proteins_gmt.py b/examples/proteins_gmt.py
index c13a8522a2f3..0667e93d2785 100644
--- a/examples/proteins_gmt.py
+++ b/examples/proteins_gmt.py
@@ -10,14 +10,15 @@
 path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'PROTEINS')
 dataset = TUDataset(path, name='PROTEINS').shuffle()
 
-avg_num_nodes = int(dataset.data.num_nodes / len(dataset))
+
 n = (len(dataset) + 9) // 10
-test_dataset = dataset[:n]
-val_dataset = dataset[n:2 * n]
 train_dataset = dataset[2 * n:]
-test_loader = DataLoader(test_dataset, batch_size=128)
-val_loader = DataLoader(val_dataset, batch_size=128)
+val_dataset = dataset[n:2 * n]
+test_dataset = dataset[:n]
+
 train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)
+val_loader = DataLoader(val_dataset, batch_size=128)
+test_loader = DataLoader(test_dataset, batch_size=128)
 
 
 class Net(torch.nn.Module):
@@ -28,19 +29,9 @@ def __init__(self):
         self.conv2 = GCNConv(32, 32)
         self.conv3 = GCNConv(32, 32)
 
-        self.pool = GraphMultisetTransformer(
-            in_channels=96,
-            hidden_channels=64,
-            out_channels=32,
-            Conv=GCNConv,
-            num_nodes=avg_num_nodes,
-            pooling_ratio=0.25,
-            pool_sequences=['GMPool_G', 'SelfAtt', 'GMPool_I'],
-            num_heads=4,
-            layer_norm=False,
-        )
-
-        self.lin1 = Linear(32, 16)
+        self.pool = GraphMultisetTransformer(96, k=10, heads=4)
+
+        self.lin1 = Linear(96, 16)
         self.lin2 = Linear(16, dataset.num_classes)
 
     def forward(self, x0, edge_index, batch):
@@ -49,13 +40,13 @@ def forward(self, x0, edge_index, batch):
         x3 = self.conv3(x2, edge_index).relu()
         x = torch.cat([x1, x2, x3], dim=-1)
 
-        x = self.pool(x, batch, edge_index)
+        x = self.pool(x, batch)
 
         x = self.lin1(x).relu()
         x = F.dropout(x, p=0.5, training=self.training)
         x = self.lin2(x)
-        return x.log_softmax(dim=-1)
+        return x
 
 
 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
@@ -71,7 +62,7 @@ def train():
         data = data.to(device)
         optimizer.zero_grad()
         out = model(data.x, data.edge_index, data.batch)
-        loss = F.nll_loss(out, data.y)
+        loss = F.cross_entropy(out, data.y)
         loss.backward()
         total_loss += data.num_graphs * float(loss)
         optimizer.step()
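The example change above is the user-visible part of the breaking change: the pooling operator no longer takes hidden/output sizes, a `Conv` layer, or a `pool_sequences` list, it keeps the input channel count (hence `Linear(96, 16)`), and its `forward()` no longer receives `edge_index`. A minimal usage sketch of the new interface, with toy tensors standing in for a PROTEINS mini-batch:

```python
import torch
from torch_geometric.nn.aggr import GraphMultisetTransformer

# Toy stand-in for a mini-batch: 6 nodes with 96 features across 2 graphs.
x = torch.randn(6, 96)
batch = torch.tensor([0, 0, 0, 1, 1, 1])

pool = GraphMultisetTransformer(96, k=10, heads=4)
out = pool(x, batch)  # no `edge_index` argument anymore
print(out.size())     # torch.Size([2, 96]) -- channels are preserved
```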
diff --git a/test/nn/aggr/test_gmt.py b/test/nn/aggr/test_gmt.py
index cce7d17c3ddb..235465604de3 100644
--- a/test/nn/aggr/test_gmt.py
+++ b/test/nn/aggr/test_gmt.py
@@ -1,97 +1,21 @@
-import pytest
 import torch
 
-from torch_geometric.nn import GATConv, GCNConv, GraphConv
 from torch_geometric.nn.aggr import GraphMultisetTransformer
+from torch_geometric.testing import is_full_test
 
 
-@pytest.mark.parametrize('layer_norm', [False, True])
-def test_graph_multiset_transformer(layer_norm):
-    num_avg_nodes = 4
-    in_channels, hidden_channels, out_channels = 32, 64, 16
+def test_graph_multiset_transformer():
+    x = torch.randn(6, 16)
+    index = torch.tensor([0, 0, 1, 1, 1, 2])
 
-    x = torch.randn((6, in_channels))
-    edge_index = torch.tensor([[0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 5],
-                               [1, 2, 3, 0, 2, 3, 0, 1, 3, 0, 1, 2, 5, 4]])
-    index = torch.tensor([0, 0, 0, 0, 1, 1])
+    aggr = GraphMultisetTransformer(16, k=2, heads=2)
+    aggr.reset_parameters()
+    assert str(aggr) == ('GraphMultisetTransformer(16, k=2, heads=2, '
+                         'layer_norm=False)')
 
-    for GNN in [GraphConv, GCNConv, GATConv]:
-        gmt = GraphMultisetTransformer(
-            in_channels,
-            hidden_channels,
-            out_channels,
-            Conv=GNN,
-            num_nodes=num_avg_nodes,
-            pooling_ratio=0.25,
-            pool_sequences=['GMPool_I'],
-            num_heads=4,
-            layer_norm=layer_norm,
-        )
-        gmt.reset_parameters()
-        assert str(gmt) == ("GraphMultisetTransformer(32, 16, "
-                            "pool_sequences=['GMPool_I'])")
-        assert gmt(x, index, edge_index=edge_index).size() == (2, out_channels)
+    out = aggr(x, index)
+    assert out.size() == (3, 16)
 
-        gmt = GraphMultisetTransformer(
-            in_channels,
-            hidden_channels,
-            out_channels,
-            Conv=GNN,
-            num_nodes=num_avg_nodes,
-            pooling_ratio=0.25,
-            pool_sequences=['GMPool_G'],
-            num_heads=4,
-            layer_norm=layer_norm,
-        )
-        gmt.reset_parameters()
-        assert str(gmt) == ("GraphMultisetTransformer(32, 16, "
-                            "pool_sequences=['GMPool_G'])")
-        assert gmt(x, index, edge_index=edge_index).size() == (2, out_channels)
-
-        gmt = GraphMultisetTransformer(
-            in_channels,
-            hidden_channels,
-            out_channels,
-            Conv=GNN,
-            num_nodes=num_avg_nodes,
-            pooling_ratio=0.25,
-            pool_sequences=['GMPool_G', 'GMPool_I'],
-            num_heads=4,
-            layer_norm=layer_norm,
-        )
-        gmt.reset_parameters()
-        assert str(gmt) == ("GraphMultisetTransformer(32, 16, "
-                            "pool_sequences=['GMPool_G', 'GMPool_I'])")
-        assert gmt(x, index, edge_index=edge_index).size() == (2, out_channels)
-
-        gmt = GraphMultisetTransformer(
-            in_channels,
-            hidden_channels,
-            out_channels,
-            Conv=GNN,
-            num_nodes=num_avg_nodes,
-            pooling_ratio=0.25,
-            pool_sequences=['GMPool_G', 'SelfAtt', 'GMPool_I'],
-            num_heads=4,
-            layer_norm=layer_norm,
-        )
-        gmt.reset_parameters()
-        assert str(gmt) == ("GraphMultisetTransformer(32, 16, pool_sequences="
-                            "['GMPool_G', 'SelfAtt', 'GMPool_I'])")
-        assert gmt(x, index, edge_index=edge_index).size() == (2, out_channels)
-
-        gmt = GraphMultisetTransformer(
-            in_channels,
-            hidden_channels,
-            out_channels,
-            Conv=GNN,
-            num_nodes=num_avg_nodes,
-            pooling_ratio=0.25,
-            pool_sequences=['GMPool_G', 'SelfAtt', 'SelfAtt', 'GMPool_I'],
-            num_heads=4,
-            layer_norm=layer_norm,
-        )
-        gmt.reset_parameters()
-        assert str(gmt) == ("GraphMultisetTransformer(32, 16, pool_sequences="
-                            "['GMPool_G', 'SelfAtt', 'SelfAtt', 'GMPool_I'])")
-        assert gmt(x, index, edge_index=edge_index).size() == (2, out_channels)
+    if is_full_test():
+        jit = torch.jit.script(aggr)
+        assert torch.allclose(jit(x, index), out)
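Two practical notes on the new test above: the aggregation returns one vector per graph with the input channel count, independent of `k` (which only sizes the internal set of representative elements), and since the shared attention utilities appear to build on `torch.nn.MultiheadAttention`, `channels` should be divisible by `heads` (16 % 2 == 0 here). A small, hedged sketch of both points:

```python
import torch
from torch_geometric.nn.aggr import GraphMultisetTransformer

x = torch.randn(6, 16)
index = torch.tensor([0, 0, 1, 1, 1, 2])  # three graphs of sizes 2, 3 and 1

# `k` only changes the internal bottleneck, not the output shape:
assert GraphMultisetTransformer(16, k=2, heads=2)(x, index).size() == (3, 16)
assert GraphMultisetTransformer(16, k=4, heads=4)(x, index).size() == (3, 16)
```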
diff --git a/torch_geometric/nn/aggr/gmt.py b/torch_geometric/nn/aggr/gmt.py
index b61a8945db17..21ca7b8599d1 100644
--- a/torch_geometric/nn/aggr/gmt.py
+++ b/torch_geometric/nn/aggr/gmt.py
@@ -1,136 +1,13 @@
-import math
-from typing import List, Optional, Tuple, Type
+from typing import Optional
 
 import torch
 from torch import Tensor
-from torch.nn import LayerNorm, Linear
 
 from torch_geometric.nn.aggr import Aggregation
-from torch_geometric.utils import to_dense_batch
-
-
-class MAB(torch.nn.Module):
-    r"""Multihead-Attention Block."""
-    def __init__(self, dim_Q: int, dim_K: int, dim_V: int, num_heads: int,
-                 Conv: Optional[Type] = None, layer_norm: bool = False):
-
-        super().__init__()
-        self.dim_V = dim_V
-        self.num_heads = num_heads
-        self.layer_norm = layer_norm
-
-        self.fc_q = Linear(dim_Q, dim_V)
-
-        if Conv is None:
-            self.layer_k = Linear(dim_K, dim_V)
-            self.layer_v = Linear(dim_K, dim_V)
-        else:
-            self.layer_k = Conv(dim_K, dim_V)
-            self.layer_v = Conv(dim_K, dim_V)
-
-        if layer_norm:
-            self.ln0 = LayerNorm(dim_V)
-            self.ln1 = LayerNorm(dim_V)
-
-        self.fc_o = Linear(dim_V, dim_V)
-
-    def reset_parameters(self):
-        self.fc_q.reset_parameters()
-        self.layer_k.reset_parameters()
-        self.layer_v.reset_parameters()
-        if self.layer_norm:
-            self.ln0.reset_parameters()
-            self.ln1.reset_parameters()
-        self.fc_o.reset_parameters()
-        pass
-
-    def forward(
-        self,
-        Q: Tensor,
-        K: Tensor,
-        graph: Optional[Tuple[Tensor, Tensor, Tensor]] = None,
-        mask: Optional[Tensor] = None,
-    ) -> Tensor:
-
-        Q = self.fc_q(Q)
-
-        if graph is not None:
-            x, edge_index, batch = graph
-            K, V = self.layer_k(x, edge_index), self.layer_v(x, edge_index)
-            K, _ = to_dense_batch(K, batch)
-            V, _ = to_dense_batch(V, batch)
-        else:
-            K, V = self.layer_k(K), self.layer_v(K)
-
-        dim_split = self.dim_V // self.num_heads
-        Q_ = torch.cat(Q.split(dim_split, 2), dim=0)
-        K_ = torch.cat(K.split(dim_split, 2), dim=0)
-        V_ = torch.cat(V.split(dim_split, 2), dim=0)
-
-        if mask is not None:
-            mask = torch.cat([mask for _ in range(self.num_heads)], 0)
-            attention_score = Q_.bmm(K_.transpose(1, 2))
-            attention_score = attention_score / math.sqrt(self.dim_V)
-            A = torch.softmax(mask + attention_score, 1)
-        else:
-            A = torch.softmax(
-                Q_.bmm(K_.transpose(1, 2)) / math.sqrt(self.dim_V), 2)
-
-        out = torch.cat((Q_ + A.bmm(V_)).split(Q.size(0), 0), 2)
-
-        if self.layer_norm:
-            out = self.ln0(out)
-
-        out = out + self.fc_o(out).relu()
-
-        if self.layer_norm:
-            out = self.ln1(out)
-
-        return out
-
-
-class SAB(torch.nn.Module):
-    r"""Self-Attention Block."""
-    def __init__(self, in_channels: int, out_channels: int, num_heads: int,
-                 Conv: Optional[Type] = None, layer_norm: bool = False):
-        super().__init__()
-        self.mab = MAB(in_channels, in_channels, out_channels, num_heads,
-                       Conv=Conv, layer_norm=layer_norm)
-
-    def reset_parameters(self):
-        self.mab.reset_parameters()
-
-    def forward(
-        self,
-        x: Tensor,
-        graph: Optional[Tuple[Tensor, Tensor, Tensor]] = None,
-        mask: Optional[Tensor] = None,
-    ) -> Tensor:
-        return self.mab(x, x, graph, mask)
-
-
-class PMA(torch.nn.Module):
-    r"""Graph pooling with Multihead-Attention."""
-    def __init__(self, channels: int, num_heads: int, num_seeds: int,
-                 Conv: Optional[Type] = None, layer_norm: bool = False):
-        super().__init__()
-        self.S = torch.nn.Parameter(torch.Tensor(1, num_seeds, channels))
-        self.mab = MAB(channels, channels, channels, num_heads, Conv=Conv,
-                       layer_norm=layer_norm)
-
-        self.reset_parameters()
-
-    def reset_parameters(self):
-        torch.nn.init.xavier_uniform_(self.S)
-        self.mab.reset_parameters()
-
-    def forward(
-        self,
-        x: Tensor,
-        graph: Optional[Tuple[Tensor, Tensor, Tensor]] = None,
-        mask: Optional[Tensor] = None,
-    ) -> Tensor:
-        return self.mab(self.S.repeat(x.size(0), 1, 1), x, graph, mask)
+from torch_geometric.nn.aggr.utils import (
+    PoolingByMultiheadAttention,
+    SetAttentionBlock,
+)
 
 
 class GraphMultisetTransformer(Aggregation):
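The hunk above deletes the hand-rolled `MAB`/`SAB`/`PMA` blocks in favour of the shared `SetAttentionBlock` and `PoolingByMultiheadAttention` utilities. For readers unfamiliar with the mechanism, the following is a rough sketch of pooling-by-multihead-attention using plain `torch.nn.MultiheadAttention`; it is illustrative only, and the real blocks additionally use residual connections, a feed-forward layer, and optional layer normalization:

```python
import torch


class PMASketch(torch.nn.Module):
    """k learnable seed vectors attend over a padded node set, reducing
    [batch, num_nodes, channels] to [batch, k, channels]."""
    def __init__(self, channels: int, k: int, heads: int):
        super().__init__()
        self.seed = torch.nn.Parameter(torch.randn(1, k, channels))
        self.attn = torch.nn.MultiheadAttention(channels, heads,
                                                batch_first=True)

    def forward(self, x, key_padding_mask=None):
        query = self.seed.expand(x.size(0), -1, -1)  # [batch, k, channels]
        out, _ = self.attn(query, x, x, key_padding_mask=key_padding_mask)
        return out


x = torch.randn(2, 5, 32)  # 2 graphs, padded to 5 nodes, 32 channels
print(PMASketch(32, k=4, heads=4)(x).size())  # torch.Size([2, 4, 32])
```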
@@ -138,115 +15,66 @@ class GraphMultisetTransformer(Aggregation):
     r"""The Graph Multiset Transformer pooling operator from the
     `"Accurate Learning of Graph Representations with Graph Multiset Pooling"
     <https://arxiv.org/abs/2102.11533>`_ paper.
 
-    The Graph Multiset Transformer clusters nodes of the entire graph via
-    attention-based pooling operations (:obj:`"GMPool_G"` or
-    :obj:`"GMPool_I"`).
-    In addition, self-attention (:obj:`"SelfAtt"`) can be used to calculate
-    the inter-relationships among nodes.
+    The :class:`GraphMultisetTransformer` aggregates elements into
+    :math:`k` representative elements via attention-based pooling, computes the
+    interaction among them via :obj:`num_encoder_blocks` self-attention blocks,
+    and finally pools the representative elements via attention-based pooling
+    into a single cluster.
 
     Args:
-        in_channels (int): Size of each input sample.
-        hidden_channels (int): Size of each hidden sample.
-        out_channels (int): Size of each output sample.
-        conv (Type, optional): A graph neural network layer
-            for calculating hidden representations of nodes for
-            :obj:`"GMPool_G"` (one of
-            :class:`~torch_geometric.nn.conv.GCNConv`,
-            :class:`~torch_geometric.nn.conv.GraphConv` or
-            :class:`~torch_geometric.nn.conv.GATConv`).
-            (default: :class:`~torch_geometric.nn.conv.GCNConv`)
-        num_nodes (int, optional): The number of average
-            or maximum nodes. (default: :obj:`300`)
-        pooling_ratio (float, optional): Graph pooling ratio
-            for each pooling. (default: :obj:`0.25`)
-        pool_sequences ([str], optional): A sequence of pooling layers
-            consisting of Graph Multiset Transformer submodules (one of
-            :obj:`["GMPool_I"]`,
-            :obj:`["GMPool_G"]`,
-            :obj:`["GMPool_G", "GMPool_I"]`,
-            :obj:`["GMPool_G", "SelfAtt", "GMPool_I"]` or
-            :obj:`["GMPool_G", "SelfAtt", "SelfAtt", "GMPool_I"]`).
-            (default: :obj:`["GMPool_G", "SelfAtt", "GMPool_I"]`)
-        num_heads (int, optional): Number of attention heads.
-            (default: :obj:`4`)
-        layer_norm (bool, optional): If set to :obj:`True`, will make use of
-            layer normalization. (default: :obj:`False`)
+        channels (int): Size of each input sample.
+        k (int): Number of :math:`k` representative nodes after pooling.
+        num_encoder_blocks (int, optional): Number of Set Attention Blocks
+            (SABs) between the two pooling blocks. (default: :obj:`1`)
+        heads (int, optional): Number of multi-head-attentions.
+            (default: :obj:`1`)
+        layer_norm (bool, optional): If set to :obj:`True`, will apply layer
+            normalization. (default: :obj:`False`)
     """
     def __init__(
         self,
-        in_channels: int,
-        hidden_channels: int,
-        out_channels: int,
-        Conv: Optional[Type] = None,
-        num_nodes: int = 300,
-        pooling_ratio: float = 0.25,
-        pool_sequences: List[str] = ['GMPool_G', 'SelfAtt', 'GMPool_I'],
-        num_heads: int = 4,
+        channels: int,
+        k: int,
+        num_encoder_blocks: int = 1,
+        heads: int = 1,
         layer_norm: bool = False,
     ):
-        from torch_geometric.nn import GCNConv  # noqa: avoid circular import
         super().__init__()
 
-        self.in_channels = in_channels
-        self.hidden_channels = hidden_channels
-        self.out_channels = out_channels
-        self.Conv = Conv or GCNConv
-        self.num_nodes = num_nodes
-        self.pooling_ratio = pooling_ratio
-        self.pool_sequences = pool_sequences
-        self.num_heads = num_heads
-        self.layer_norm = layer_norm
-
-        self.lin1 = Linear(in_channels, hidden_channels)
-        self.lin2 = Linear(hidden_channels, out_channels)
-
-        self.pools = torch.nn.ModuleList()
-        num_out_nodes = math.ceil(num_nodes * pooling_ratio)
-        for i, pool_type in enumerate(pool_sequences):
-            if pool_type not in ['GMPool_G', 'GMPool_I', 'SelfAtt']:
-                raise ValueError("Elements in 'pool_sequences' should be one "
-                                 "of 'GMPool_G', 'GMPool_I', or 'SelfAtt'")
-            if i == len(pool_sequences) - 1:
-                num_out_nodes = 1
-
-            if pool_type == 'GMPool_G':
-                self.pools.append(
-                    PMA(hidden_channels, num_heads, num_out_nodes,
-                        Conv=self.Conv, layer_norm=layer_norm))
-                num_out_nodes = math.ceil(num_out_nodes * self.pooling_ratio)
-
-            elif pool_type == 'GMPool_I':
-                self.pools.append(
-                    PMA(hidden_channels, num_heads, num_out_nodes, Conv=None,
-                        layer_norm=layer_norm))
-                num_out_nodes = math.ceil(num_out_nodes * self.pooling_ratio)
+        self.channels = channels
+        self.k = k
+        self.heads = heads
+        self.layer_norm = layer_norm
 
-            elif pool_type == 'SelfAtt':
-                self.pools.append(
-                    SAB(hidden_channels, hidden_channels, num_heads, Conv=None,
-                        layer_norm=layer_norm))
+        self.pma1 = PoolingByMultiheadAttention(channels, k, heads, layer_norm)
+        self.encoders = torch.nn.ModuleList([
+            SetAttentionBlock(channels, heads, layer_norm)
+            for _ in range(num_encoder_blocks)
+        ])
+        self.pma2 = PoolingByMultiheadAttention(channels, 1, heads, layer_norm)
 
     def reset_parameters(self):
-        self.lin1.reset_parameters()
-        self.lin2.reset_parameters()
-        for pool in self.pools:
-            pool.reset_parameters()
+        self.pma1.reset_parameters()
+        for encoder in self.encoders:
+            encoder.reset_parameters()
+        self.pma2.reset_parameters()
 
     def forward(self, x: Tensor, index: Optional[Tensor] = None,
                 ptr: Optional[Tensor] = None, dim_size: Optional[int] = None,
-                dim: int = -2, edge_index: Optional[Tensor] = None) -> Tensor:
+                dim: int = -2) -> Tensor:
+
+        x, mask = self.to_dense_batch(x, index, ptr, dim_size, dim)
+
+        x = self.pma1(x, mask)
 
-        x = self.lin1(x)
-        batch_x, mask = self.to_dense_batch(x, index, ptr, dim_size, dim)
-        mask = (~mask).unsqueeze(1).to(dtype=x.dtype) * -1e9
+        for encoder in self.encoders:
+            x = encoder(x)
 
-        for i, (name, pool) in enumerate(zip(self.pool_sequences, self.pools)):
-            graph = (x, edge_index, index) if name == 'GMPool_G' else None
-            batch_x = pool(batch_x, graph, mask)
-            mask = None
+        x = self.pma2(x)
 
-        return self.lin2(batch_x.squeeze(1))
+        return x.squeeze(1)
 
     def __repr__(self) -> str:
-        return (f'{self.__class__.__name__}({self.in_channels}, '
-                f'{self.out_channels}, pool_sequences={self.pool_sequences})')
+        return (f'{self.__class__.__name__}({self.channels}, '
+                f'k={self.k}, heads={self.heads}, '
+                f'layer_norm={self.layer_norm})')
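To summarize the new `forward()` above as a shape walk-through (illustrative only; parameter names follow the diff): node features are densely batched, compressed to `k` representative elements per graph, refined by the self-attention encoders, pooled to a single element, and squeezed back to a flat per-graph vector:

```python
import torch
from torch_geometric.nn.aggr import GraphMultisetTransformer

aggr = GraphMultisetTransformer(channels=32, k=3, num_encoder_blocks=2,
                                heads=4)

x = torch.randn(7, 32)                       # 7 nodes, 32 channels
index = torch.tensor([0, 0, 0, 0, 1, 1, 1])  # 2 graphs (sizes 4 and 3)

# to_dense_batch -> x: [2, 4, 32] (padded), mask: [2, 4]
# pma1(x, mask)  -> [2, 3, 32]   (k = 3 representative elements per graph)
# 2 x encoder    -> [2, 3, 32]   (SetAttentionBlock keeps the shape)
# pma2(x)        -> [2, 1, 32]   (pool to a single element)
# squeeze(1)     -> [2, 32]
assert aggr(x, index).size() == (2, 32)
```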
diff --git a/torch_geometric/nn/aggr/set_transformer.py b/torch_geometric/nn/aggr/set_transformer.py
index 2ed6525efbc2..43c9263bd2ac 100644
--- a/torch_geometric/nn/aggr/set_transformer.py
+++ b/torch_geometric/nn/aggr/set_transformer.py
@@ -29,7 +29,7 @@ class SetTransformerAggregation(Aggregation):
         concat (bool, optional): If set to :obj:`False`, the seed embeddings
             are averaged instead of concatenated. (default: :obj:`True`)
         norm (str, optional): If set to :obj:`True`, will apply layer
-            normalization. (default: :obj:`True`)
+            normalization. (default: :obj:`False`)
     """
     def __init__(
         self,
diff --git a/torch_geometric/nn/glob.py b/torch_geometric/nn/glob.py
index d499e7574f79..c0092ff30e34 100644
--- a/torch_geometric/nn/glob.py
+++ b/torch_geometric/nn/glob.py
@@ -4,22 +4,7 @@
     global_max_pool,
     global_mean_pool,
 )
-from torch_geometric.nn.aggr import (
-    AttentionalAggregation,
-    GraphMultisetTransformer,
-    Set2Set,
-    SortAggregation,
-)
-
-Set2Set = deprecated(
-    details="use 'nn.aggr.Set2Set' instead",
-    func_name='nn.glob.Set2Set',
-)(Set2Set)
-
-GraphMultisetTransformer = deprecated(
-    details="use 'nn.aggr.GraphMultisetTransformer' instead",
-    func_name='nn.glob.GraphMultisetTransformer',
-)(GraphMultisetTransformer)
+from torch_geometric.nn.aggr import AttentionalAggregation, SortAggregation
 
 
 @deprecated(