Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MultiAggregation and aggregation_resolver #4749

Merged
merged 21 commits into from
Jun 7, 2022
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
### Added
- Added `transforms.RootedSubgraph` interface with two implementations: `RootedEgoNets` and `RootedRWSubgraph` ([#3926](https://github.com/pyg-team/pytorch_geometric/pull/3926))
- Added `ptr` vectors for `follow_batch` attributes within `Batch.from_data_list` ([#4723](https://github.com/pyg-team/pytorch_geometric/pull/4723))
- Added `torch_geometric.nn.aggr` package ([#4687](https://github.com/pyg-team/pytorch_geometric/pull/4687), [#4721](https://github.com/pyg-team/pytorch_geometric/pull/4721), [#4731](https://github.com/pyg-team/pytorch_geometric/pull/4731))
- Added `torch_geometric.nn.aggr` package ([#4687](https://github.com/pyg-team/pytorch_geometric/pull/4687), [#4721](https://github.com/pyg-team/pytorch_geometric/pull/4721), [#4731](https://github.com/pyg-team/pytorch_geometric/pull/4731), [#4749](https://github.com/pyg-team/pytorch_geometric/pull/4749))
- Added the `DimeNet++` model ([#4432](https://github.com/pyg-team/pytorch_geometric/pull/4432), [#4699](https://github.com/pyg-team/pytorch_geometric/pull/4699), [#4700](https://github.com/pyg-team/pytorch_geometric/pull/4700))
- Added an example of using PyG with PyTorch Ignite ([#4487](https://github.com/pyg-team/pytorch_geometric/pull/4487))
- Added `GroupAddRev` module with support for reducing training GPU memory ([#4671](https://github.com/pyg-team/pytorch_geometric/pull/4671), [#4701](https://github.com/pyg-team/pytorch_geometric/pull/4701), [#4715](https://github.com/pyg-team/pytorch_geometric/pull/4715), [#4730](https://github.com/pyg-team/pytorch_geometric/pull/4730))
Expand Down
11 changes: 8 additions & 3 deletions test/nn/aggr/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
MaxAggregation,
MeanAggregation,
MinAggregation,
MulAggregation,
PowerMeanAggregation,
SoftmaxAggregation,
StdAggregation,
Expand Down Expand Up @@ -32,7 +33,7 @@ def test_validate():

@pytest.mark.parametrize('Aggregation', [
MeanAggregation, SumAggregation, MaxAggregation, MinAggregation,
VarAggregation, StdAggregation
MulAggregation, VarAggregation, StdAggregation
])
def test_basic_aggregation(Aggregation):
x = torch.randn(6, 16)
Expand All @@ -44,7 +45,11 @@ def test_basic_aggregation(Aggregation):

out = aggr(x, index)
assert out.size() == (3, x.size(1))
assert torch.allclose(out, aggr(x, ptr=ptr))
if isinstance(aggr, MulAggregation):
with pytest.raises(AssertionError, match="'index' is None"):
assert torch.allclose(out, aggr(x, ptr=ptr))
else:
assert torch.allclose(out, aggr(x, ptr=ptr))
lightaime marked this conversation as resolved.
Show resolved Hide resolved


@pytest.mark.parametrize('Aggregation',
Expand All @@ -56,7 +61,7 @@ def test_gen_aggregation(Aggregation, learn):
ptr = torch.tensor([0, 2, 5, 6])

aggr = Aggregation(learn=learn)
assert str(aggr) == f'{Aggregation.__name__}()'
assert str(aggr) == f'{Aggregation.__name__}(learn={learn})'

out = aggr(x, index)
assert out.size() == (3, x.size(1))
Expand Down
122 changes: 122 additions & 0 deletions test/nn/aggr/test_mp_interface.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
from itertools import combinations

import pytest
import torch
from torch import Tensor
from torch_sparse import SparseTensor
from torch_sparse.matmul import spmm

from torch_geometric.nn import (
LSTMAggregation,
MaxAggregation,
MeanAggregation,
MessagePassing,
MinAggregation,
MulAggregation,
MultiAggregation,
PowerMeanAggregation,
SoftmaxAggregation,
StdAggregation,
SumAggregation,
VarAggregation,
)
from torch_geometric.typing import Adj


class MyConv(MessagePassing):
def __init__(self, aggr='mean'):
super().__init__(aggr=aggr)

def forward(self, x: Tensor, edge_index: Adj) -> Tensor:
# propagate_type: (x: Tensor)
return self.propagate(edge_index, x=x, size=None)

def message_and_aggregate(self, adj_t: SparseTensor, x: Tensor) -> Tensor:
return spmm(adj_t, x, reduce=self.aggr)


@pytest.mark.parametrize('aggr', [(MeanAggregation, 'mean'),
(SumAggregation, 'sum'),
(MaxAggregation, 'max'),
(MinAggregation, 'min'),
(MulAggregation, 'mul'),
(VarAggregation, 'var'),
(StdAggregation, 'std')])
def test_my_basic_aggr_conv(aggr):
x = torch.randn(4, 16)
edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])
row, col = edge_index
adj = SparseTensor(row=row, col=col, sparse_sizes=(4, 4))

aggr_module, aggr_str = aggr
conv1 = MyConv(aggr=aggr_module())
out1 = conv1(x, edge_index)
assert out1.size() == (4, 16)
assert conv1(x, adj.t()).tolist() == out1.tolist()
conv1.fuse = False
assert conv1(x, adj.t()).tolist() == out1.tolist()
conv1.fuse = True

conv2 = MyConv(aggr=aggr_str)
out2 = conv2(x, edge_index)
assert out2.size() == (4, 16)
assert conv2(x, adj.t()).tolist() == out2.tolist()
conv2.fuse = False
assert conv2(x, adj.t()).tolist() == out2.tolist()
conv2.fuse = True

assert torch.allclose(out1, out2)


@pytest.mark.parametrize('Aggregation',
lightaime marked this conversation as resolved.
Show resolved Hide resolved
[SoftmaxAggregation, PowerMeanAggregation])
@pytest.mark.parametrize('learn', [True, False])
def test_my_gen_aggr_conv(Aggregation, learn):
x = torch.randn(4, 16)
edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])
conv = MyConv(aggr=Aggregation(learn=learn))
out = conv(x, edge_index)
assert out.size() == (4, 16)


def test_my_lstm_aggr_conv():
x = torch.randn(4, 16)
edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])
conv = MyConv(aggr=LSTMAggregation(16, 32))
out = conv(x, edge_index)
assert out.size() == (4, 32)


aggr_list = [
'mean', 'sum', 'max', 'min', 'mul', 'std', 'var',
SoftmaxAggregation(learn=True),
PowerMeanAggregation(learn=True),
LSTMAggregation(16, 16)
]
aggrs = [list(aggr) for aggr in combinations(aggr_list, 3)]


@pytest.mark.parametrize('aggrs', aggrs)
def test_my_multiple_aggr_conv(aggrs):
x = torch.randn(4, 16)
edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])

conv = MyConv(aggr=MultiAggregation(aggrs=aggrs))
lightaime marked this conversation as resolved.
Show resolved Hide resolved
out = conv(x, edge_index)
assert out.size() == (4, 48)
assert not torch.allclose(out[:, 0:16], out[:, 16:32])
assert not torch.allclose(out[:, 0:16], out[:, 32:48])
assert not torch.allclose(out[:, 16:32], out[:, 32:48])


@pytest.mark.parametrize('aggrs', aggrs)
def test_my_list_multiple_aggr_conv(aggrs):
x = torch.randn(4, 16)
edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])

conv = MyConv(aggr=aggrs)
out = conv(x, edge_index)
assert out.size() == (4, 48)
assert not torch.allclose(out[:, 0:16], out[:, 16:32])
assert not torch.allclose(out[:, 0:16], out[:, 32:48])
assert not torch.allclose(out[:, 16:32], out[:, 32:48])
39 changes: 39 additions & 0 deletions test/nn/aggr/test_multi.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
from itertools import combinations

import pytest
import torch

from torch_geometric.nn import (
LSTMAggregation,
MultiAggregation,
PowerMeanAggregation,
SoftmaxAggregation,
)

aggr_list = [
'mean', 'sum', 'max', 'min', 'mul', 'std', 'var',
SoftmaxAggregation(learn=True),
PowerMeanAggregation(learn=True),
LSTMAggregation(16, 16)
]
aggrs = [list(aggr) for aggr in combinations(aggr_list, 3)] + \
[[aggr] for aggr in aggr_list]


@pytest.mark.parametrize('aggrs', aggrs)
def test_multi_aggr(aggrs):
x = torch.randn(6, 16)
index = torch.tensor([0, 0, 1, 1, 1, 2])
aggr = MultiAggregation(aggrs=aggrs)
out = aggr(x, index)
assert out.size() == (3, len(aggrs) * 16)


def test_multi_aggr_repr_():
aggr = MultiAggregation(
aggrs=['max', 'min', PowerMeanAggregation(learn=True)])
assert str(aggr) == ('MultiAggregation([\n'
' MaxAggregation(),\n'
' MinAggregation(),\n'
' PowerMeanAggregation(learn=True)\n'
'])')
68 changes: 67 additions & 1 deletion test/nn/test_resolver.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,22 @@
import pytest
import torch

from torch_geometric.nn.resolver import activation_resolver
from torch_geometric.nn import (
LSTMAggregation,
MaxAggregation,
MeanAggregation,
MinAggregation,
MulAggregation,
PowerMeanAggregation,
SoftmaxAggregation,
StdAggregation,
SumAggregation,
VarAggregation,
)
from torch_geometric.nn.resolver import (
activation_resolver,
aggregation_resolver,
)


def test_activation_resolver():
Expand All @@ -11,3 +27,53 @@ def test_activation_resolver():
assert isinstance(activation_resolver('elu'), torch.nn.ELU)
assert isinstance(activation_resolver('relu'), torch.nn.ReLU)
assert isinstance(activation_resolver('prelu'), torch.nn.PReLU)


@pytest.mark.parametrize('aggr_test', [
(MeanAggregation, 'mean'),
(SumAggregation, 'sum'),
(MaxAggregation, 'max'),
(MinAggregation, 'min'),
(MulAggregation, 'mul'),
(VarAggregation, 'var'),
(StdAggregation, 'std'),
(SoftmaxAggregation, 'softmax'),
(PowerMeanAggregation, 'powermean'),
])
def test_aggregation_resolver_basic(aggr_test):
aggr_module, aggr_str = aggr_test
assert isinstance(aggregation_resolver(aggr_module()), aggr_module)
assert isinstance(aggregation_resolver(aggr_str), aggr_module)


@pytest.mark.parametrize('aggr_test', [
(SoftmaxAggregation, 'softmax', {
'learn': True
}),
(PowerMeanAggregation, 'powermean', {
'learn': True
}),
(LSTMAggregation, 'lstm', {
'in_channels': 16,
'out_channels': 32
}),
])
def test_aggregation_resolver_learnable(aggr_test):
aggr_module, aggr_str, kwargs = aggr_test
aggr = aggr_module(**kwargs)
res_aggr1 = aggregation_resolver(aggr_module(**kwargs))
res_aggr2 = aggregation_resolver(aggr_str, **kwargs)
assert isinstance(res_aggr1, aggr_module)
assert isinstance(res_aggr2, aggr_module)
if issubclass(aggr_module, SoftmaxAggregation):
repr = 'SoftmaxAggregation(learn=True)'
elif issubclass(aggr_module, PowerMeanAggregation):
repr = 'PowerMeanAggregation(learn=True)'
elif issubclass(aggr_module, LSTMAggregation):
repr = 'LSTMAggregation(16, 32)'
else:
raise ValueError("Unknown Aggregation")

assert str(aggr) == repr
assert str(res_aggr1) == repr
assert str(res_aggr2) == repr
4 changes: 4 additions & 0 deletions torch_geometric/nn/aggr/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,23 @@
SumAggregation,
MaxAggregation,
MinAggregation,
MulAggregation,
VarAggregation,
StdAggregation,
SoftmaxAggregation,
PowerMeanAggregation,
)
from .lstm import LSTMAggregation
from .multi import MultiAggregation

__all__ = classes = [
'MultiAggregation',
'Aggregation',
'MeanAggregation',
'SumAggregation',
'MaxAggregation',
'MinAggregation',
'MulAggregation',
'VarAggregation',
'StdAggregation',
'SoftmaxAggregation',
Expand Down
2 changes: 1 addition & 1 deletion torch_geometric/nn/aggr/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def reduce(self, x: Tensor, index: Optional[Tensor] = None,
ptr = expand_left(ptr, dim, dims=x.dim())
return segment_csr(x, ptr, reduce=reduce)

assert index is not None
assert index is not None, "'index' is None"
return scatter(x, index, dim=dim, dim_size=dim_size, reduce=reduce)

def __repr__(self) -> str:
Expand Down
24 changes: 23 additions & 1 deletion torch_geometric/nn/aggr/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@
from torch import Tensor
from torch.nn import Parameter

from torch_geometric.nn.aggr import Aggregation
from torch_geometric.utils import softmax

from .base import Aggregation


class MeanAggregation(Aggregation):
def forward(self, x: Tensor, index: Optional[Tensor] = None, *,
Expand Down Expand Up @@ -36,6 +37,18 @@ def forward(self, x: Tensor, index: Optional[Tensor] = None, *,
return self.reduce(x, index, ptr, dim_size, dim, reduce='min')


class MulAggregation(Aggregation):
def forward(self, x: Tensor, index: Optional[Tensor] = None, *,
ptr: Optional[Tensor] = None, dim_size: Optional[int] = None,
dim: int = -2) -> Tensor:
if ptr is not None: # TODO
import warnings
warnings.warn(f"'{self.__class__.__name__}' with "
rusty1s marked this conversation as resolved.
Show resolved Hide resolved
f"'ptr' not yet supported, use 'index' instead")
ptr = None
return self.reduce(x, index, ptr, dim_size, dim, reduce='mul')


class VarAggregation(Aggregation):
def forward(self, x: Tensor, index: Optional[Tensor] = None, *,
ptr: Optional[Tensor] = None, dim_size: Optional[int] = None,
Expand All @@ -61,6 +74,7 @@ def __init__(self, t: float = 1.0, learn: bool = False):
super().__init__()
self._init_t = t
self.t = Parameter(torch.Tensor(1)) if learn else t
self.learn = learn
self.reset_parameters()

def reset_parameters(self):
Expand All @@ -77,15 +91,20 @@ def forward(self, x: Tensor, index: Optional[Tensor] = None, *,
alpha = softmax(alpha, index, ptr, dim_size, dim)
return self.reduce(x * alpha, index, ptr, dim_size, dim, reduce='sum')

def __repr__(self) -> str:
return (f'{self.__class__.__name__}(learn={self.learn})')


class PowerMeanAggregation(Aggregation):
def __init__(self, p: float = 1.0, learn: bool = False):
# TODO Learn distinct `p` per channel.
super().__init__()
self._init_p = p
self.p = Parameter(torch.Tensor(1)) if learn else p
self.learn = learn
self.reset_parameters()

def reset_parameters(self):
if isinstance(self.p, Tensor):
self.p.data.fill_(self._init_p)

Expand All @@ -97,3 +116,6 @@ def forward(self, x: Tensor, index: Optional[Tensor] = None, *,
if isinstance(self.p, (int, float)) and self.p == 1:
return out
return out.clamp_(min=0, max=100).pow(1. / self.p)

def __repr__(self) -> str:
return (f'{self.__class__.__name__}(learn={self.learn})')
Loading