Commit
From #8022:
- Add new operator `MixHopConv` in `nn.conv`
- Add a test for it
- Add an example for it

Feel free to comment, thanks :)

Co-authored-by: rusty1s <matthias.fey@tu-dortmund.de>
1 parent f06cbcd; commit 3cb3c40
Showing 6 changed files with 262 additions and 1 deletion.
New file (+89 lines): example script training a three-layer MixHop model on the Planetoid Cora dataset.
import os.path as osp

import torch
import torch.nn.functional as F

from torch_geometric.datasets import Planetoid
from torch_geometric.nn import BatchNorm, Linear, MixHopConv

if torch.cuda.is_available():
    device = torch.device('cuda')
elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cpu')

path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'Planetoid')
dataset = Planetoid(path, name='Cora')
data = dataset[0]


class MixHop(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = MixHopConv(dataset.num_features, 60, powers=[0, 1, 2])
        self.norm1 = BatchNorm(3 * 60)

        self.conv2 = MixHopConv(3 * 60, 60, powers=[0, 1, 2])
        self.norm2 = BatchNorm(3 * 60)

        self.conv3 = MixHopConv(3 * 60, 60, powers=[0, 1, 2])
        self.norm3 = BatchNorm(3 * 60)

        self.lin = Linear(3 * 60, dataset.num_classes)

    def forward(self, x, edge_index):
        x = F.dropout(x, p=0.7, training=self.training)

        x = self.conv1(x, edge_index)
        x = self.norm1(x)
        x = F.dropout(x, p=0.9, training=self.training)

        x = self.conv2(x, edge_index)
        x = self.norm2(x)
        x = F.dropout(x, p=0.9, training=self.training)

        x = self.conv3(x, edge_index)
        x = self.norm3(x)
        x = F.dropout(x, p=0.9, training=self.training)

        return self.lin(x)


model, data = MixHop().to(device), data.to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.5, weight_decay=0.005)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=40,
                                            gamma=0.01)


def train():
    model.train()
    optimizer.zero_grad()
    out = model(data.x, data.edge_index)
    loss = F.cross_entropy(out[data.train_mask], data.y[data.train_mask])
    loss.backward()
    optimizer.step()
    scheduler.step()
    return float(loss)


@torch.no_grad()
def test():
    model.eval()
    pred = model(data.x, data.edge_index).argmax(dim=-1)

    accs = []
    for mask in [data.train_mask, data.val_mask, data.test_mask]:
        accs.append(int((pred[mask] == data.y[mask]).sum()) / int(mask.sum()))
    return accs


best_val_acc = test_acc = 0
for epoch in range(1, 101):
    loss = train()
    train_acc, val_acc, tmp_test_acc = test()
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        test_acc = tmp_test_acc
    print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Train: {train_acc:.4f}, '
          f'Val: {best_val_acc:.4f}, Test: {test_acc:.4f}')
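A note on the aggressive schedule above (my reading of the example code, not something stated in the commit): `StepLR(step_size=40, gamma=0.01)`, stepped once per epoch, multiplies the SGD learning rate by 0.01 at epochs 41 and 81, so training runs at 0.5, then 5e-3, then 5e-5. A small standalone check of that arithmetic:

import torch

# Reproduce the example's learning-rate schedule on a dummy parameter.
param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.SGD([param], lr=0.5, weight_decay=0.005)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=40,
                                            gamma=0.01)

for epoch in range(1, 101):
    optimizer.step()
    scheduler.step()

print(scheduler.get_last_lr())  # ~[5e-05]: 0.5 -> 5e-3 -> 5e-5 over 100 epochs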
New file (+44 lines): unit test for `MixHopConv`.
import torch

import torch_geometric.typing
from torch_geometric.nn import MixHopConv
from torch_geometric.testing import is_full_test
from torch_geometric.typing import SparseTensor
from torch_geometric.utils import to_torch_csc_tensor


def test_mixhop_conv():
    x = torch.randn(4, 16)
    edge_index = torch.tensor([[0, 0, 0, 1, 2, 3], [1, 2, 3, 0, 0, 0]])
    value = torch.rand(edge_index.size(1))
    adj1 = to_torch_csc_tensor(edge_index, size=(4, 4))
    adj2 = to_torch_csc_tensor(edge_index, value, size=(4, 4))

    conv = MixHopConv(16, 32, powers=[0, 1, 2, 4])
    assert str(conv) == 'MixHopConv(16, 32, powers=[0, 1, 2, 4])'

    out1 = conv(x, edge_index)
    assert out1.size() == (4, 128)
    assert torch.allclose(conv(x, adj1.t()), out1, atol=1e-6)

    out2 = conv(x, edge_index, value)
    assert out2.size() == (4, 128)
    assert torch.allclose(conv(x, adj2.t()), out2, atol=1e-6)

    if torch_geometric.typing.WITH_TORCH_SPARSE:
        adj3 = SparseTensor.from_edge_index(edge_index, sparse_sizes=(4, 4))
        adj4 = SparseTensor.from_edge_index(edge_index, value, (4, 4))
        assert torch.allclose(conv(x, adj4.t()), out2, atol=1e-6)
        assert torch.allclose(conv(x, adj3.t()), out1, atol=1e-6)

    if is_full_test():
        t = '(Tensor, Tensor, OptTensor) -> Tensor'
        jit = torch.jit.script(conv.jittable(t))
        assert torch.allclose(jit(x, edge_index), out1, atol=1e-6)
        assert torch.allclose(jit(x, edge_index, value), out2, atol=1e-6)

    if is_full_test() and torch_geometric.typing.WITH_TORCH_SPARSE:
        t = '(Tensor, SparseTensor, OptTensor) -> Tensor'
        jit = torch.jit.script(conv.jittable(t))
        assert torch.allclose(jit(x, adj3.t()), out1, atol=1e-6)
        assert torch.allclose(jit(x, adj4.t()), out2, atol=1e-6)
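A note on the `(4, 128)` size assertions and the `powers=[0, 1, 2, 4]` choice above: the layer concatenates one 32-dimensional block per listed power, so the output width is `len(powers) * out_channels = 4 * 32 = 128`, while the skipped power 3 is backed by a `torch.nn.Identity` placeholder and contributes nothing to the output. A quick way to see this (a small sketch, not part of the test file):

import torch
from torch_geometric.nn import MixHopConv

conv = MixHopConv(16, 32, powers=[0, 1, 2, 4])
print(len(conv.lins))  # 5 modules, one per power in range(max(powers) + 1)
print(conv.lins[3])    # Identity(), since power 3 is not requested
print(conv(torch.randn(4, 16),
           torch.tensor([[0, 1, 2, 3], [1, 2, 3, 0]])).size())  # [4, 128]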
New file (+125 lines): the `MixHopConv` operator in `torch_geometric.nn.conv`.
from typing import List, Optional

import torch
from torch import Tensor
from torch.nn import Parameter

from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn.conv.gcn_conv import gcn_norm
from torch_geometric.nn.dense.linear import Linear
from torch_geometric.nn.inits import zeros
from torch_geometric.typing import Adj, OptTensor, SparseTensor
from torch_geometric.utils import spmm


class MixHopConv(MessagePassing):
    r"""The Mix-Hop graph convolutional operator from the
    `"MixHop: Higher-Order Graph Convolutional Architectures via Sparsified
    Neighborhood Mixing" <https://arxiv.org/abs/1905.00067>`_ paper

    .. math::
        \mathbf{X}^{\prime}={\Bigg\Vert}_{p\in P}
        {\left( \mathbf{\hat{D}}^{-1/2} \mathbf{\hat{A}}
        \mathbf{\hat{D}}^{-1/2} \right)}^p \mathbf{X} \mathbf{\Theta},

    where :math:`\mathbf{\hat{A}} = \mathbf{A} + \mathbf{I}` denotes the
    adjacency matrix with inserted self-loops and
    :math:`\hat{D}_{ii} = \sum_{j=0} \hat{A}_{ij}` its diagonal degree matrix.

    Args:
        in_channels (int): Size of each input sample, or :obj:`-1` to derive
            the size from the first input(s) to the forward method.
        out_channels (int): Size of each output sample.
        powers (List[int], optional): The powers of the adjacency matrix to
            use. (default: :obj:`[0, 1, 2]`)
        add_self_loops (bool, optional): If set to :obj:`False`, will not add
            self-loops to the input graph. (default: :obj:`True`)
        bias (bool, optional): If set to :obj:`False`, the layer will not
            learn an additive bias. (default: :obj:`True`)
        **kwargs (optional): Additional arguments of
            :class:`torch_geometric.nn.conv.MessagePassing`.

    Shapes:
        - **input:**
          node features :math:`(|\mathcal{V}|, F_{in})`,
          edge indices :math:`(2, |\mathcal{E}|)`,
          edge weights :math:`(|\mathcal{E}|)` *(optional)*
        - **output:**
          node features :math:`(|\mathcal{V}|, |P| \cdot F_{out})`
    """
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        powers: Optional[List[int]] = None,
        add_self_loops: bool = True,
        bias: bool = True,
        **kwargs,
    ):
        kwargs.setdefault('aggr', 'add')
        super().__init__(**kwargs)

        if powers is None:
            powers = [0, 1, 2]

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.powers = powers
        self.add_self_loops = add_self_loops

        self.lins = torch.nn.ModuleList([
            Linear(in_channels, out_channels, bias=False)
            if p in powers else torch.nn.Identity()
            for p in range(max(powers) + 1)
        ])

        if bias:
            self.bias = Parameter(torch.empty(len(powers) * out_channels))
        else:
            self.register_parameter('bias', None)

        self.reset_parameters()

    def reset_parameters(self):
        for lin in self.lins:
            if hasattr(lin, 'reset_parameters'):
                lin.reset_parameters()
        zeros(self.bias)

    def forward(self, x: Tensor, edge_index: Adj,
                edge_weight: OptTensor = None) -> Tensor:

        if isinstance(edge_index, Tensor):
            edge_index, edge_weight = gcn_norm(  # yapf: disable
                edge_index, edge_weight, x.size(self.node_dim), False,
                self.add_self_loops, self.flow, x.dtype)
        elif isinstance(edge_index, SparseTensor):
            edge_index = gcn_norm(  # yapf: disable
                edge_index, edge_weight, x.size(self.node_dim), False,
                self.add_self_loops, self.flow, x.dtype)

        outs = [self.lins[0](x)]

        for lin in self.lins[1:]:
            # propagate_type: (x: Tensor, edge_weight: OptTensor)
            x = self.propagate(edge_index, x=x, edge_weight=edge_weight,
                               size=None)

            outs.append(lin.forward(x))

        out = torch.cat([outs[p] for p in self.powers], dim=-1)

        if self.bias is not None:
            out = out + self.bias

        return out

    def message(self, x_j: Tensor, edge_weight: OptTensor) -> Tensor:
        return x_j if edge_weight is None else edge_weight.view(-1, 1) * x_j

    def message_and_aggregate(self, adj_t: SparseTensor, x: Tensor) -> Tensor:
        return spmm(adj_t, x, reduce=self.aggr)

    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}({self.in_channels}, '
                f'{self.out_channels}, powers={self.powers})')
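As an end-to-end illustration of the Shapes contract in the docstring (a minimal sketch with arbitrary toy tensors, mirroring the unit test above rather than adding anything new):

import torch
from torch_geometric.nn import MixHopConv

x = torch.randn(4, 16)                              # (|V|, F_in)
edge_index = torch.tensor([[0, 0, 0, 1, 2, 3],
                           [1, 2, 3, 0, 0, 0]])     # (2, |E|)
edge_weight = torch.rand(edge_index.size(1))        # (|E|,), optional

conv = MixHopConv(16, 32, powers=[0, 1, 2])
out = conv(x, edge_index, edge_weight)
print(out.size())  # torch.Size([4, 96]): (|V|, |P| * F_out) with |P| = 3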