Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Re-work GraphMultisetTransformer #6343

Merged
merged 2 commits into from
Jan 5, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added support for dropping nodes in `utils.to_dense_batch` in case `max_num_nodes` is smaller than the number of nodes ([#6124](https://github.com/pyg-team/pytorch_geometric/pull/6124))
- Added the RandLA-Net architecture as an example ([#5117](https://github.com/pyg-team/pytorch_geometric/pull/5117))
### Changed
- Breaking Change: Changed the interface and implementation of `GraphMultisetTransformer` ([#6343](https://github.com/pyg-team/pytorch_geometric/pull/6343))
- Fixed the approximate PPR variant in `transforms.GDC` to not crash on graphs with isolated nodes ([#6242](https://github.com/pyg-team/pytorch_geometric/pull/6242))
- Added a warning when accesing `InMemoryDataset.data` ([#6318](https://github.com/pyg-team/pytorch_geometric/pull/6318))
- Drop `SparseTensor` dependency in `GraphStore` ([#5517](https://github.com/pyg-team/pytorch_geometric/pull/5517))
Expand Down
33 changes: 12 additions & 21 deletions examples/proteins_gmt.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,15 @@

path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'PROTEINS')
dataset = TUDataset(path, name='PROTEINS').shuffle()
avg_num_nodes = int(dataset.data.num_nodes / len(dataset))

n = (len(dataset) + 9) // 10
test_dataset = dataset[:n]
val_dataset = dataset[n:2 * n]
train_dataset = dataset[2 * n:]
test_loader = DataLoader(test_dataset, batch_size=128)
val_loader = DataLoader(val_dataset, batch_size=128)
val_dataset = dataset[n:2 * n]
test_dataset = dataset[:n]

train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=128)
test_loader = DataLoader(test_dataset, batch_size=128)


class Net(torch.nn.Module):
Expand All @@ -28,19 +29,9 @@ def __init__(self):
self.conv2 = GCNConv(32, 32)
self.conv3 = GCNConv(32, 32)

self.pool = GraphMultisetTransformer(
in_channels=96,
hidden_channels=64,
out_channels=32,
Conv=GCNConv,
num_nodes=avg_num_nodes,
pooling_ratio=0.25,
pool_sequences=['GMPool_G', 'SelfAtt', 'GMPool_I'],
num_heads=4,
layer_norm=False,
)

self.lin1 = Linear(32, 16)
self.pool = GraphMultisetTransformer(96, k=10, heads=4)

self.lin1 = Linear(96, 16)
self.lin2 = Linear(16, dataset.num_classes)

def forward(self, x0, edge_index, batch):
Expand All @@ -49,13 +40,13 @@ def forward(self, x0, edge_index, batch):
x3 = self.conv3(x2, edge_index).relu()
x = torch.cat([x1, x2, x3], dim=-1)

x = self.pool(x, batch, edge_index)
x = self.pool(x, batch)

x = self.lin1(x).relu()
x = F.dropout(x, p=0.5, training=self.training)
x = self.lin2(x)

return x.log_softmax(dim=-1)
return x


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
Expand All @@ -71,7 +62,7 @@ def train():
data = data.to(device)
optimizer.zero_grad()
out = model(data.x, data.edge_index, data.batch)
loss = F.nll_loss(out, data.y)
loss = F.cross_entropy(out, data.y)
loss.backward()
total_loss += data.num_graphs * float(loss)
optimizer.step()
Expand Down
102 changes: 13 additions & 89 deletions test/nn/aggr/test_gmt.py
Original file line number Diff line number Diff line change
@@ -1,97 +1,21 @@
import pytest
import torch

from torch_geometric.nn import GATConv, GCNConv, GraphConv
from torch_geometric.nn.aggr import GraphMultisetTransformer
from torch_geometric.testing import is_full_test


@pytest.mark.parametrize('layer_norm', [False, True])
def test_graph_multiset_transformer(layer_norm):
num_avg_nodes = 4
in_channels, hidden_channels, out_channels = 32, 64, 16
def test_graph_multiset_transformer():
x = torch.randn(6, 16)
index = torch.tensor([0, 0, 1, 1, 1, 2])

x = torch.randn((6, in_channels))
edge_index = torch.tensor([[0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 5],
[1, 2, 3, 0, 2, 3, 0, 1, 3, 0, 1, 2, 5, 4]])
index = torch.tensor([0, 0, 0, 0, 1, 1])
aggr = GraphMultisetTransformer(16, k=2, heads=2)
aggr.reset_parameters()
assert str(aggr) == ('GraphMultisetTransformer(16, k=2, heads=2, '
'layer_norm=False)')

for GNN in [GraphConv, GCNConv, GATConv]:
gmt = GraphMultisetTransformer(
in_channels,
hidden_channels,
out_channels,
Conv=GNN,
num_nodes=num_avg_nodes,
pooling_ratio=0.25,
pool_sequences=['GMPool_I'],
num_heads=4,
layer_norm=layer_norm,
)
gmt.reset_parameters()
assert str(gmt) == ("GraphMultisetTransformer(32, 16, "
"pool_sequences=['GMPool_I'])")
assert gmt(x, index, edge_index=edge_index).size() == (2, out_channels)
out = aggr(x, index)
assert out.size() == (3, 16)

gmt = GraphMultisetTransformer(
in_channels,
hidden_channels,
out_channels,
Conv=GNN,
num_nodes=num_avg_nodes,
pooling_ratio=0.25,
pool_sequences=['GMPool_G'],
num_heads=4,
layer_norm=layer_norm,
)
gmt.reset_parameters()
assert str(gmt) == ("GraphMultisetTransformer(32, 16, "
"pool_sequences=['GMPool_G'])")
assert gmt(x, index, edge_index=edge_index).size() == (2, out_channels)

gmt = GraphMultisetTransformer(
in_channels,
hidden_channels,
out_channels,
Conv=GNN,
num_nodes=num_avg_nodes,
pooling_ratio=0.25,
pool_sequences=['GMPool_G', 'GMPool_I'],
num_heads=4,
layer_norm=layer_norm,
)
gmt.reset_parameters()
assert str(gmt) == ("GraphMultisetTransformer(32, 16, "
"pool_sequences=['GMPool_G', 'GMPool_I'])")
assert gmt(x, index, edge_index=edge_index).size() == (2, out_channels)

gmt = GraphMultisetTransformer(
in_channels,
hidden_channels,
out_channels,
Conv=GNN,
num_nodes=num_avg_nodes,
pooling_ratio=0.25,
pool_sequences=['GMPool_G', 'SelfAtt', 'GMPool_I'],
num_heads=4,
layer_norm=layer_norm,
)
gmt.reset_parameters()
assert str(gmt) == ("GraphMultisetTransformer(32, 16, pool_sequences="
"['GMPool_G', 'SelfAtt', 'GMPool_I'])")
assert gmt(x, index, edge_index=edge_index).size() == (2, out_channels)

gmt = GraphMultisetTransformer(
in_channels,
hidden_channels,
out_channels,
Conv=GNN,
num_nodes=num_avg_nodes,
pooling_ratio=0.25,
pool_sequences=['GMPool_G', 'SelfAtt', 'SelfAtt', 'GMPool_I'],
num_heads=4,
layer_norm=layer_norm,
)
gmt.reset_parameters()
assert str(gmt) == ("GraphMultisetTransformer(32, 16, pool_sequences="
"['GMPool_G', 'SelfAtt', 'SelfAtt', 'GMPool_I'])")
assert gmt(x, index, edge_index=edge_index).size() == (2, out_channels)
if is_full_test():
jit = torch.jit.script(aggr)
assert torch.allclose(jit(x, index), out)
Loading