Add MixHopConv to torch_geometric.nn.conv #8025

Merged 12 commits on Sep 14, 2023
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Added the `MixHopConv` layer and a corresponding example ([#8025](https://github.com/pyg-team/pytorch_geometric/pull/8025))
- Added the option to pass keyword arguments to the underlying normalization layers within `BasicGNN` and `MLP` ([#8024](https://github.com/pyg-team/pytorch_geometric/pull/8024))
- Added `IBMBNodeLoader` and `IBMBBatchLoader` data loaders ([#6230](https://github.com/pyg-team/pytorch_geometric/pull/6230))
- Added the `NeuralFingerprint` model for learning fingerprints of molecules ([#7919](https://github.com/pyg-team/pytorch_geometric/pull/7919))
2 changes: 1 addition & 1 deletion examples/gcn.py
@@ -90,7 +90,7 @@ def test():
    return accs


best_val_acc = final_test_acc = 0
best_val_acc = test_acc = 0
times = []
for epoch in range(1, args.epochs + 1):
    start = time.time()
89 changes: 89 additions & 0 deletions examples/mixhop.py
@@ -0,0 +1,89 @@
import os.path as osp

import torch
import torch.nn.functional as F

from torch_geometric.datasets import Planetoid
from torch_geometric.nn import BatchNorm, Linear, MixHopConv

if torch.cuda.is_available():
    device = torch.device('cuda')
elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cpu')

path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'Planetoid')
dataset = Planetoid(path, name='Cora')
data = dataset[0]


class MixHop(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = MixHopConv(dataset.num_features, 60, powers=[0, 1, 2])
        self.norm1 = BatchNorm(3 * 60)

        self.conv2 = MixHopConv(3 * 60, 60, powers=[0, 1, 2])
        self.norm2 = BatchNorm(3 * 60)

        self.conv3 = MixHopConv(3 * 60, 60, powers=[0, 1, 2])
        self.norm3 = BatchNorm(3 * 60)

        self.lin = Linear(3 * 60, dataset.num_classes)

    def forward(self, x, edge_index):
        x = F.dropout(x, p=0.7, training=self.training)

        x = self.conv1(x, edge_index)
        x = self.norm1(x)
        x = F.dropout(x, p=0.9, training=self.training)

        x = self.conv2(x, edge_index)
        x = self.norm2(x)
        x = F.dropout(x, p=0.9, training=self.training)

        x = self.conv3(x, edge_index)
        x = self.norm3(x)
        x = F.dropout(x, p=0.9, training=self.training)

        return self.lin(x)


model, data = MixHop().to(device), data.to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.5, weight_decay=0.005)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=40,
                                            gamma=0.01)


def train():
    model.train()
    optimizer.zero_grad()
    out = model(data.x, data.edge_index)
    loss = F.cross_entropy(out[data.train_mask], data.y[data.train_mask])
    loss.backward()
    optimizer.step()
    scheduler.step()
    return float(loss)


@torch.no_grad()
def test():
    model.eval()
    pred = model(data.x, data.edge_index).argmax(dim=-1)

    accs = []
    for mask in [data.train_mask, data.val_mask, data.test_mask]:
        accs.append(int((pred[mask] == data.y[mask]).sum()) / int(mask.sum()))
    return accs


best_val_acc = test_acc = 0
for epoch in range(1, 101):
    loss = train()
    train_acc, val_acc, tmp_test_acc = test()
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        test_acc = tmp_test_acc
    print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Train: {train_acc:.4f}, '
          f'Val: {best_val_acc:.4f}, Test: {test_acc:.4f}')
44 changes: 44 additions & 0 deletions test/nn/conv/test_mixhop_conv.py
@@ -0,0 +1,44 @@
import torch

import torch_geometric.typing
from torch_geometric.nn import MixHopConv
from torch_geometric.testing import is_full_test
from torch_geometric.typing import SparseTensor
from torch_geometric.utils import to_torch_csc_tensor


def test_mixhop_conv():
    x = torch.randn(4, 16)
    edge_index = torch.tensor([[0, 0, 0, 1, 2, 3], [1, 2, 3, 0, 0, 0]])
    value = torch.rand(edge_index.size(1))
    adj1 = to_torch_csc_tensor(edge_index, size=(4, 4))
    adj2 = to_torch_csc_tensor(edge_index, value, size=(4, 4))

    conv = MixHopConv(16, 32, powers=[0, 1, 2, 4])
    assert str(conv) == 'MixHopConv(16, 32, powers=[0, 1, 2, 4])'

    out1 = conv(x, edge_index)
    assert out1.size() == (4, 128)
    assert torch.allclose(conv(x, adj1.t()), out1, atol=1e-6)

    out2 = conv(x, edge_index, value)
    assert out2.size() == (4, 128)
    assert torch.allclose(conv(x, adj2.t()), out2, atol=1e-6)

    if torch_geometric.typing.WITH_TORCH_SPARSE:
        adj3 = SparseTensor.from_edge_index(edge_index, sparse_sizes=(4, 4))
        adj4 = SparseTensor.from_edge_index(edge_index, value, (4, 4))
        assert torch.allclose(conv(x, adj4.t()), out2, atol=1e-6)
        assert torch.allclose(conv(x, adj3.t()), out1, atol=1e-6)

    if is_full_test():
        t = '(Tensor, Tensor, OptTensor) -> Tensor'
        jit = torch.jit.script(conv.jittable(t))
        assert torch.allclose(jit(x, edge_index), out1, atol=1e-6)
        assert torch.allclose(jit(x, edge_index, value), out2, atol=1e-6)

    if is_full_test() and torch_geometric.typing.WITH_TORCH_SPARSE:
        t = '(Tensor, SparseTensor, OptTensor) -> Tensor'
        jit = torch.jit.script(conv.jittable(t))
        assert torch.allclose(jit(x, adj3.t()), out1, atol=1e-6)
        assert torch.allclose(jit(x, adj4.t()), out2, atol=1e-6)
2 changes: 2 additions & 0 deletions torch_geometric/nn/conv/__init__.py
@@ -60,6 +60,7 @@
from .gps_conv import GPSConv
from .antisymmetric_conv import AntiSymmetricConv
from .dir_gnn_conv import DirGNNConv
from .mixhop_conv import MixHopConv

__all__ = [
    'MessagePassing',
@@ -127,6 +128,7 @@
    'GPSConv',
    'AntiSymmetricConv',
    'DirGNNConv',
    'MixHopConv',
]

classes = __all__
125 changes: 125 additions & 0 deletions torch_geometric/nn/conv/mixhop_conv.py
@@ -0,0 +1,125 @@
from typing import List, Optional

import torch
from torch import Tensor
from torch.nn import Parameter

from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn.conv.gcn_conv import gcn_norm
from torch_geometric.nn.dense.linear import Linear
from torch_geometric.nn.inits import zeros
from torch_geometric.typing import Adj, OptTensor, SparseTensor
from torch_geometric.utils import spmm


class MixHopConv(MessagePassing):
r"""The Mix-Hop graph convolutional operator from the
`"MixHop: Higher-Order Graph Convolutional Architecturesvia Sparsified
Neighborhood Mixing" <https://arxiv.org/abs/1905.00067>`_ paper

.. math::
\mathbf{X}^{\prime}={\Bigg\Vert}_{p\in P}
{\left( \mathbf{\hat{D}}^{-1/2} \mathbf{\hat{A}}
\mathbf{\hat{D}}^{-1/2} \right)}^p \mathbf{X} \mathbf{\Theta},

where :math:`\mathbf{\hat{A}} = \mathbf{A} + \mathbf{I}` denotes the
adjacency matrix with inserted self-loops and
:math:`\hat{D}_{ii} = \sum_{j=0} \hat{A}_{ij}` its diagonal degree matrix.

Args:
in_channels (int): Size of each input sample, or :obj:`-1` to derive
the size from the first input(s) to the forward method.
out_channels (int): Size of each output sample.
powers (List[int], optional): The powers of the adjacency matrix to
use. (default: :obj:`[0, 1, 2]`)
add_self_loops (bool, optional): If set to :obj:`False`, will not add
self-loops to the input graph. (default: :obj:`True`)
bias (bool, optional): If set to :obj:`False`, the layer will not learn
an additive bias. (default: :obj:`True`)
**kwargs (optional): Additional arguments of
:class:`torch_geometric.nn.conv.MessagePassing`.

Shapes:
- **input:**
node features :math:`(|\mathcal{V}|, F_{in})`,
edge indices :math:`(2, |\mathcal{E}|)`,
edge weights :math:`(|\mathcal{E}|)` *(optional)*
- **output:**
node features :math:`(|\mathcal{V}|, |P| \cdot F_{out})`
"""
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        powers: Optional[List[int]] = None,
        add_self_loops: bool = True,
        bias: bool = True,
        **kwargs,
    ):
        kwargs.setdefault('aggr', 'add')
        super().__init__(**kwargs)

        if powers is None:
            powers = [0, 1, 2]

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.powers = powers
        self.add_self_loops = add_self_loops

        # One linear layer per adjacency power up to `max(powers)`; powers
        # that are not requested get an `Identity` placeholder so that list
        # indices stay aligned with the corresponding power.
        self.lins = torch.nn.ModuleList([
            Linear(in_channels, out_channels, bias=False)
            if p in powers else torch.nn.Identity()
            for p in range(max(powers) + 1)
        ])

        if bias:
            self.bias = Parameter(torch.empty(len(powers) * out_channels))
        else:
            self.register_parameter('bias', None)

        self.reset_parameters()

    def reset_parameters(self):
        for lin in self.lins:
            if hasattr(lin, 'reset_parameters'):
                lin.reset_parameters()
        zeros(self.bias)

    def forward(self, x: Tensor, edge_index: Adj,
                edge_weight: OptTensor = None) -> Tensor:

        if isinstance(edge_index, Tensor):
            edge_index, edge_weight = gcn_norm(  # yapf: disable
                edge_index, edge_weight, x.size(self.node_dim), False,
                self.add_self_loops, self.flow, x.dtype)
        elif isinstance(edge_index, SparseTensor):
            edge_index = gcn_norm(  # yapf: disable
                edge_index, edge_weight, x.size(self.node_dim), False,
                self.add_self_loops, self.flow, x.dtype)

        outs = [self.lins[0](x)]

        # Each iteration applies one more step of normalized propagation, so
        # after the p-th pass `x` holds the p-th adjacency power applied to
        # the input; `outs[p]` stores it after the p-th (possibly Identity)
        # linear, and only the requested powers are concatenated below.
        for lin in self.lins[1:]:
            # propagate_type: (x: Tensor, edge_weight: OptTensor)
            x = self.propagate(edge_index, x=x, edge_weight=edge_weight,
                               size=None)

            outs.append(lin.forward(x))

        out = torch.cat([outs[p] for p in self.powers], dim=-1)

        if self.bias is not None:
            out = out + self.bias

        return out

    def message(self, x_j: Tensor, edge_weight: OptTensor) -> Tensor:
        return x_j if edge_weight is None else edge_weight.view(-1, 1) * x_j

    def message_and_aggregate(self, adj_t: SparseTensor, x: Tensor) -> Tensor:
        return spmm(adj_t, x, reduce=self.aggr)

    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}({self.in_channels}, '
                f'{self.out_channels}, powers={self.powers})')
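
For reference, a minimal usage sketch (not part of the diff; the toy graph and channel sizes below are made up for illustration) showing that the output width is `len(powers) * out_channels`, since only the requested powers are concatenated:

import torch
from torch_geometric.nn import MixHopConv

x = torch.randn(4, 16)
edge_index = torch.tensor([[0, 1, 2, 3], [1, 2, 3, 0]])

conv = MixHopConv(16, 8, powers=[0, 2])  # power 1 is computed but not kept
out = conv(x, edge_index)
assert out.size() == (4, 2 * 8)  # len(powers) * out_channels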