Add MixHopConv to torch_geometric.nn.conv (#8025)
From #8022

- Add new operator `MixHopConv` in `nn.conv`
- Add a test for it
- Add an example for it

Feel free to comment, thanks:)
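For a quick sense of the new layer, a minimal usage sketch (the shapes mirror the ones exercised in the new test; the output has `len(powers) * out_channels` channels because the per-power results are concatenated):

```python
import torch

from torch_geometric.nn import MixHopConv

# Toy graph with 4 nodes and 16-dimensional features (sizes are illustrative).
x = torch.randn(4, 16)
edge_index = torch.tensor([[0, 0, 0, 1, 2, 3],
                           [1, 2, 3, 0, 0, 0]])

conv = MixHopConv(16, 32, powers=[0, 1, 2, 4])
out = conv(x, edge_index)
print(out.shape)  # torch.Size([4, 128]): len(powers) * out_channels = 4 * 32
```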

---------

Co-authored-by: rusty1s <matthias.fey@tu-dortmund.de>
2 people authored and JakubPietrakIntel committed Sep 27, 2023
1 parent f06cbcd commit 3cb3c40
Showing 6 changed files with 262 additions and 1 deletion.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Added the `MixHopConv` layer and a corresponding example ([#8025](https://github.com/pyg-team/pytorch_geometric/pull/8025))
- Added the option to pass keyword arguments to the underlying normalization layers within `BasicGNN` and `MLP` ([#8024](https://github.com/pyg-team/pytorch_geometric/pull/8024))
- Added `IBMBNodeLoader` and `IBMBBatchLoader` data loaders ([#6230](https://github.com/pyg-team/pytorch_geometric/pull/6230))
- Added the `NeuralFingerprint` model for learning fingerprints of molecules ([#7919](https://github.com/pyg-team/pytorch_geometric/pull/7919))
2 changes: 1 addition & 1 deletion examples/gcn.py
@@ -90,7 +90,7 @@ def test():
return accs


-best_val_acc = final_test_acc = 0
+best_val_acc = test_acc = 0
times = []
for epoch in range(1, args.epochs + 1):
start = time.time()
89 changes: 89 additions & 0 deletions examples/mixhop.py
@@ -0,0 +1,89 @@
import os.path as osp

import torch
import torch.nn.functional as F

from torch_geometric.datasets import Planetoid
from torch_geometric.nn import BatchNorm, Linear, MixHopConv

if torch.cuda.is_available():
device = torch.device('cuda')
elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
device = torch.device('mps')
else:
device = torch.device('cpu')

path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'Planetoid')
dataset = Planetoid(path, name='Cora')
data = dataset[0]


class MixHop(torch.nn.Module):
def __init__(self):
super().__init__()
self.conv1 = MixHopConv(dataset.num_features, 60, powers=[0, 1, 2])
self.norm1 = BatchNorm(3 * 60)

self.conv2 = MixHopConv(3 * 60, 60, powers=[0, 1, 2])
self.norm2 = BatchNorm(3 * 60)

self.conv3 = MixHopConv(3 * 60, 60, powers=[0, 1, 2])
self.norm3 = BatchNorm(3 * 60)

self.lin = Linear(3 * 60, dataset.num_classes)

def forward(self, x, edge_index):
x = F.dropout(x, p=0.7, training=self.training)

x = self.conv1(x, edge_index)
x = self.norm1(x)
x = F.dropout(x, p=0.9, training=self.training)

x = self.conv2(x, edge_index)
x = self.norm2(x)
x = F.dropout(x, p=0.9, training=self.training)

x = self.conv3(x, edge_index)
x = self.norm3(x)
x = F.dropout(x, p=0.9, training=self.training)

return self.lin(x)


model, data = MixHop().to(device), data.to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.5, weight_decay=0.005)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=40,
gamma=0.01)


def train():
model.train()
optimizer.zero_grad()
out = model(data.x, data.edge_index)
loss = F.cross_entropy(out[data.train_mask], data.y[data.train_mask])
loss.backward()
optimizer.step()
scheduler.step()
return float(loss)


@torch.no_grad()
def test():
model.eval()
pred = model(data.x, data.edge_index).argmax(dim=-1)

accs = []
for mask in [data.train_mask, data.val_mask, data.test_mask]:
accs.append(int((pred[mask] == data.y[mask]).sum()) / int(mask.sum()))
return accs


best_val_acc = test_acc = 0
for epoch in range(1, 101):
loss = train()
train_acc, val_acc, tmp_test_acc = test()
if val_acc > best_val_acc:
best_val_acc = val_acc
test_acc = tmp_test_acc
print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Train: {train_acc:.4f}, '
f'Val: {best_val_acc:.4f}, Test: {test_acc:.4f}')
44 changes: 44 additions & 0 deletions test/nn/conv/test_mixhop_conv.py
@@ -0,0 +1,44 @@
import torch

import torch_geometric.typing
from torch_geometric.nn import MixHopConv
from torch_geometric.testing import is_full_test
from torch_geometric.typing import SparseTensor
from torch_geometric.utils import to_torch_csc_tensor


def test_mixhop_conv():
x = torch.randn(4, 16)
edge_index = torch.tensor([[0, 0, 0, 1, 2, 3], [1, 2, 3, 0, 0, 0]])
value = torch.rand(edge_index.size(1))
adj1 = to_torch_csc_tensor(edge_index, size=(4, 4))
adj2 = to_torch_csc_tensor(edge_index, value, size=(4, 4))

conv = MixHopConv(16, 32, powers=[0, 1, 2, 4])
assert str(conv) == 'MixHopConv(16, 32, powers=[0, 1, 2, 4])'

out1 = conv(x, edge_index)
assert out1.size() == (4, 128)
assert torch.allclose(conv(x, adj1.t()), out1, atol=1e-6)

out2 = conv(x, edge_index, value)
assert out2.size() == (4, 128)
assert torch.allclose(conv(x, adj2.t()), out2, atol=1e-6)

if torch_geometric.typing.WITH_TORCH_SPARSE:
adj3 = SparseTensor.from_edge_index(edge_index, sparse_sizes=(4, 4))
adj4 = SparseTensor.from_edge_index(edge_index, value, (4, 4))
assert torch.allclose(conv(x, adj4.t()), out2, atol=1e-6)
assert torch.allclose(conv(x, adj3.t()), out1, atol=1e-6)

if is_full_test():
t = '(Tensor, Tensor, OptTensor) -> Tensor'
jit = torch.jit.script(conv.jittable(t))
assert torch.allclose(jit(x, edge_index), out1, atol=1e-6)
assert torch.allclose(jit(x, edge_index, value), out2, atol=1e-6)

if is_full_test() and torch_geometric.typing.WITH_TORCH_SPARSE:
t = '(Tensor, SparseTensor, OptTensor) -> Tensor'
jit = torch.jit.script(conv.jittable(t))
assert torch.allclose(jit(x, adj3.t()), out1, atol=1e-6)
assert torch.allclose(jit(x, adj4.t()), out2, atol=1e-6)
2 changes: 2 additions & 0 deletions torch_geometric/nn/conv/__init__.py
@@ -60,6 +60,7 @@
from .gps_conv import GPSConv
from .antisymmetric_conv import AntiSymmetricConv
from .dir_gnn_conv import DirGNNConv
from .mixhop_conv import MixHopConv

__all__ = [
'MessagePassing',
@@ -127,6 +128,7 @@
'GPSConv',
'AntiSymmetricConv',
'DirGNNConv',
'MixHopConv',
]

classes = __all__
125 changes: 125 additions & 0 deletions torch_geometric/nn/conv/mixhop_conv.py
@@ -0,0 +1,125 @@
from typing import List, Optional

import torch
from torch import Tensor
from torch.nn import Parameter

from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn.conv.gcn_conv import gcn_norm
from torch_geometric.nn.dense.linear import Linear
from torch_geometric.nn.inits import zeros
from torch_geometric.typing import Adj, OptTensor, SparseTensor
from torch_geometric.utils import spmm


class MixHopConv(MessagePassing):
r"""The Mix-Hop graph convolutional operator from the
`"MixHop: Higher-Order Graph Convolutional Architecturesvia Sparsified
Neighborhood Mixing" <https://arxiv.org/abs/1905.00067>`_ paper
.. math::
\mathbf{X}^{\prime}={\Bigg\Vert}_{p\in P}
{\left( \mathbf{\hat{D}}^{-1/2} \mathbf{\hat{A}}
\mathbf{\hat{D}}^{-1/2} \right)}^p \mathbf{X} \mathbf{\Theta},
where :math:`\mathbf{\hat{A}} = \mathbf{A} + \mathbf{I}` denotes the
adjacency matrix with inserted self-loops and
:math:`\hat{D}_{ii} = \sum_{j=0} \hat{A}_{ij}` its diagonal degree matrix.
Args:
in_channels (int): Size of each input sample, or :obj:`-1` to derive
the size from the first input(s) to the forward method.
out_channels (int): Size of each output sample.
powers (List[int], optional): The powers of the adjacency matrix to
use. (default: :obj:`[0, 1, 2]`)
add_self_loops (bool, optional): If set to :obj:`False`, will not add
self-loops to the input graph. (default: :obj:`True`)
bias (bool, optional): If set to :obj:`False`, the layer will not learn
an additive bias. (default: :obj:`True`)
**kwargs (optional): Additional arguments of
:class:`torch_geometric.nn.conv.MessagePassing`.
Shapes:
- **input:**
node features :math:`(|\mathcal{V}|, F_{in})`,
edge indices :math:`(2, |\mathcal{E}|)`,
edge weights :math:`(|\mathcal{E}|)` *(optional)*
- **output:**
node features :math:`(|\mathcal{V}|, |P| \cdot F_{out})`
"""
def __init__(
self,
in_channels: int,
out_channels: int,
powers: Optional[List[int]] = None,
add_self_loops: bool = True,
bias: bool = True,
**kwargs,
):
kwargs.setdefault('aggr', 'add')
super().__init__(**kwargs)

if powers is None:
powers = [0, 1, 2]

self.in_channels = in_channels
self.out_channels = out_channels
self.powers = powers
self.add_self_loops = add_self_loops

self.lins = torch.nn.ModuleList([
Linear(in_channels, out_channels, bias=False)
if p in powers else torch.nn.Identity()
for p in range(max(powers) + 1)
])

if bias:
self.bias = Parameter(torch.empty(len(powers) * out_channels))
else:
self.register_parameter('bias', None)

self.reset_parameters()

def reset_parameters(self):
for lin in self.lins:
if hasattr(lin, 'reset_parameters'):
lin.reset_parameters()
zeros(self.bias)

def forward(self, x: Tensor, edge_index: Adj,
edge_weight: OptTensor = None) -> Tensor:

if isinstance(edge_index, Tensor):
edge_index, edge_weight = gcn_norm( # yapf: disable
edge_index, edge_weight, x.size(self.node_dim), False,
self.add_self_loops, self.flow, x.dtype)
elif isinstance(edge_index, SparseTensor):
edge_index = gcn_norm( # yapf: disable
edge_index, edge_weight, x.size(self.node_dim), False,
self.add_self_loops, self.flow, x.dtype)

outs = [self.lins[0](x)]

for lin in self.lins[1:]:
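            # Each iteration propagates one additional hop; `lin` is a learned
            # `Linear` for powers listed in `self.powers` and an `Identity`
            # placeholder for skipped powers (dropped at the final concat).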
# propagate_type: (x: Tensor, edge_weight: OptTensor)
x = self.propagate(edge_index, x=x, edge_weight=edge_weight,
size=None)

outs.append(lin.forward(x))

out = torch.cat([outs[p] for p in self.powers], dim=-1)

if self.bias is not None:
out = out + self.bias

return out

def message(self, x_j: Tensor, edge_weight: OptTensor) -> Tensor:
return x_j if edge_weight is None else edge_weight.view(-1, 1) * x_j

def message_and_aggregate(self, adj_t: SparseTensor, x: Tensor) -> Tensor:
return spmm(adj_t, x, reduce=self.aggr)

def __repr__(self) -> str:
return (f'{self.__class__.__name__}({self.in_channels}, '
f'{self.out_channels}, powers={self.powers})')
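As a small illustration of the `Shapes` contract documented above (sizes and the power set below are arbitrary): only powers listed in `powers` receive a learned `Linear`, skipped powers are `Identity` placeholders, and the selected outputs are concatenated along the feature dimension.

```python
import torch

from torch_geometric.nn import MixHopConv

conv = MixHopConv(in_channels=8, out_channels=4, powers=[0, 2])
x = torch.randn(10, 8)
edge_index = torch.tensor([[0, 1, 2, 3],
                           [1, 2, 3, 4]])

out = conv(x, edge_index)
# Only the selected powers contribute: len(powers) * out_channels = 2 * 4 = 8.
print(out.shape)  # torch.Size([10, 8])
```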
