Skip to content

Commit

Permalink
Integration of nn.aggr within MessagePassing (#4779)
Browse files Browse the repository at this point in the history
* Add sum alias class

* Add message passing integration with nn.aggr

* Cleanup MessagePassing

* Raise errors in resolver

* changelog

* Add aggrs_kwargs support for MultiAggregation

* update

* error

* update

* update

* fix test

* update

* update

* typo

* typo

* typo

* reset

* update

* update

Co-authored-by: rusty1s <matthias.fey@tu-dortmund.de>
  • Loading branch information
lightaime and rusty1s authored Jun 25, 2022
1 parent a5f833c commit d700ddb
Show file tree
Hide file tree
Showing 12 changed files with 159 additions and 107 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added the `bias` vector to the `GCN` model definition in the "Create Message Passing Networks" tutorial ([#4755](https://github.com/pyg-team/pytorch_geometric/pull/4755))
- Added `transforms.RootedSubgraph` interface with two implementations: `RootedEgoNets` and `RootedRWSubgraph` ([#3926](https://github.com/pyg-team/pytorch_geometric/pull/3926))
- Added `ptr` vectors for `follow_batch` attributes within `Batch.from_data_list` ([#4723](https://github.com/pyg-team/pytorch_geometric/pull/4723))
- Added `torch_geometric.nn.aggr` package ([#4687](https://github.com/pyg-team/pytorch_geometric/pull/4687), [#4721](https://github.com/pyg-team/pytorch_geometric/pull/4721), [#4731](https://github.com/pyg-team/pytorch_geometric/pull/4731), [#4762](https://github.com/pyg-team/pytorch_geometric/pull/4762), [#4749](https://github.com/pyg-team/pytorch_geometric/pull/4749))
- Added `torch_geometric.nn.aggr` package ([#4687](https://github.com/pyg-team/pytorch_geometric/pull/4687), [#4721](https://github.com/pyg-team/pytorch_geometric/pull/4721), [#4731](https://github.com/pyg-team/pytorch_geometric/pull/4731), [#4762](https://github.com/pyg-team/pytorch_geometric/pull/4762), [#4749](https://github.com/pyg-team/pytorch_geometric/pull/4749), [#4779](https://github.com/pyg-team/pytorch_geometric/pull/4779))
- Added the `DimeNet++` model ([#4432](https://github.com/pyg-team/pytorch_geometric/pull/4432), [#4699](https://github.com/pyg-team/pytorch_geometric/pull/4699), [#4700](https://github.com/pyg-team/pytorch_geometric/pull/4700), [#4800](https://github.com/pyg-team/pytorch_geometric/pull/4800))
- Added an example of using PyG with PyTorch Ignite ([#4487](https://github.com/pyg-team/pytorch_geometric/pull/4487))
- Added `GroupAddRev` module with support for reducing training GPU memory ([#4671](https://github.com/pyg-team/pytorch_geometric/pull/4671), [#4701](https://github.com/pyg-team/pytorch_geometric/pull/4701), [#4715](https://github.com/pyg-team/pytorch_geometric/pull/4715), [#4730](https://github.com/pyg-team/pytorch_geometric/pull/4730))
Expand Down
32 changes: 31 additions & 1 deletion test/nn/conv/test_message_passing.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from torch_sparse import SparseTensor
from torch_sparse.matmul import spmm

from torch_geometric.nn import MessagePassing
from torch_geometric.nn import MessagePassing, aggr
from torch_geometric.typing import Adj, OptPairTensor, OptTensor, Size


Expand Down Expand Up @@ -487,3 +487,33 @@ def test_explain_message():
conv._edge_mask = torch.tensor([0, 0, 0, 0], dtype=torch.float)
conv._apply_sigmoid = False
assert conv(x, edge_index).abs().sum() == 0.


class MyAggregatorConv(MessagePassing):
    """Minimal message-passing layer used to exercise `nn.aggr` integration.

    All constructor keyword arguments (e.g. ``aggr=...``) are forwarded to
    :class:`MessagePassing`, which resolves them into ``self.aggr_module``.
    """
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def forward(self, x: Tensor, edge_index: Adj) -> Tensor:
        # NOTE: The `propagate_type` comment below is parsed by PyG to infer
        # the argument types of `propagate()` for TorchScript support, so the
        # type name must be spelled exactly (was misspelled 'TEnsor').
        # propagate_type: (x: Tensor)
        return self.propagate(edge_index, x=x, size=None)


@pytest.mark.parametrize('aggr_module', [
    aggr.MeanAggregation(),
    aggr.SumAggregation(),
    aggr.MaxAggregation(),
    aggr.SoftmaxAggregation(),
    aggr.PowerMeanAggregation(),
    aggr.MultiAggregation(['mean', 'max'])
])
def test_message_passing_with_aggr_module(aggr_module):
    """Every `nn.aggr` module plugs into `MessagePassing` and yields the
    same output for dense `edge_index` and sparse-adjacency inputs."""
    num_nodes, num_features = 4, 8
    x = torch.randn(num_nodes, num_features)
    edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])
    adj = SparseTensor(row=edge_index[0], col=edge_index[1],
                       sparse_sizes=(num_nodes, num_nodes))

    conv = MyAggregatorConv(aggr=aggr_module)
    # Whatever is passed via `aggr=` must be resolved into an
    # `Aggregation` instance on the layer.
    assert isinstance(conv.aggr_module, aggr.Aggregation)

    out = conv(x, edge_index)
    # MultiAggregation(['mean', 'max']) concatenates its outputs, doubling
    # the channel count; every other aggregator preserves it.
    assert out.size(0) == num_nodes
    assert out.size(1) in {num_features, 2 * num_features}
    assert torch.allclose(conv(x, adj.t()), out)
1 change: 1 addition & 0 deletions test/nn/test_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ def test_activation_resolver():
@pytest.mark.parametrize('aggr_tuple', [
(torch_geometric.nn.aggr.MeanAggregation, 'mean'),
(torch_geometric.nn.aggr.SumAggregation, 'sum'),
(torch_geometric.nn.aggr.SumAggregation, 'add'),
(torch_geometric.nn.aggr.MaxAggregation, 'max'),
(torch_geometric.nn.aggr.MinAggregation, 'min'),
(torch_geometric.nn.aggr.MulAggregation, 'mul'),
Expand Down
2 changes: 0 additions & 2 deletions torch_geometric/nn/aggr/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from .basic import (
MeanAggregation,
SumAggregation,
AddAggregation,
MaxAggregation,
MinAggregation,
MulAggregation,
Expand All @@ -20,7 +19,6 @@
'MultiAggregation',
'MeanAggregation',
'SumAggregation',
'AddAggregation',
'MaxAggregation',
'MinAggregation',
'MulAggregation',
Expand Down
6 changes: 3 additions & 3 deletions torch_geometric/nn/aggr/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
class Aggregation(torch.nn.Module, ABC):
r"""An abstract base class for implementing custom aggregations."""
@abstractmethod
def forward(self, x: Tensor, index: Optional[Tensor] = None, *,
def forward(self, x: Tensor, index: Optional[Tensor] = None,
ptr: Optional[Tensor] = None, dim_size: Optional[int] = None,
dim: int = -2) -> Tensor:
r"""
Expand All @@ -35,7 +35,7 @@ def forward(self, x: Tensor, index: Optional[Tensor] = None, *,
def reset_parameters(self):
pass

def __call__(self, x: Tensor, index: Optional[Tensor] = None, *,
def __call__(self, x: Tensor, index: Optional[Tensor] = None,
ptr: Optional[Tensor] = None, dim_size: Optional[int] = None,
dim: int = -2) -> Tensor:

Expand All @@ -62,7 +62,7 @@ def __call__(self, x: Tensor, index: Optional[Tensor] = None, *,
f"'{dim_size}' but expected "
f">= '{int(index.max()) + 1}')")

return super().__call__(x, index, ptr=ptr, dim_size=dim_size, dim=dim)
return super().__call__(x, index, ptr, dim_size, dim)

def __repr__(self) -> str:
return f'{self.__class__.__name__}()'
Expand Down
25 changes: 10 additions & 15 deletions torch_geometric/nn/aggr/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,38 +9,35 @@


class MeanAggregation(Aggregation):
def forward(self, x: Tensor, index: Optional[Tensor] = None, *,
def forward(self, x: Tensor, index: Optional[Tensor] = None,
ptr: Optional[Tensor] = None, dim_size: Optional[int] = None,
dim: int = -2) -> Tensor:
return self.reduce(x, index, ptr, dim_size, dim, reduce='mean')


class SumAggregation(Aggregation):
def forward(self, x: Tensor, index: Optional[Tensor] = None, *,
def forward(self, x: Tensor, index: Optional[Tensor] = None,
ptr: Optional[Tensor] = None, dim_size: Optional[int] = None,
dim: int = -2) -> Tensor:
return self.reduce(x, index, ptr, dim_size, dim, reduce='sum')


AddAggregation = SumAggregation # Alias


class MaxAggregation(Aggregation):
def forward(self, x: Tensor, index: Optional[Tensor] = None, *,
def forward(self, x: Tensor, index: Optional[Tensor] = None,
ptr: Optional[Tensor] = None, dim_size: Optional[int] = None,
dim: int = -2) -> Tensor:
return self.reduce(x, index, ptr, dim_size, dim, reduce='max')


class MinAggregation(Aggregation):
def forward(self, x: Tensor, index: Optional[Tensor] = None, *,
def forward(self, x: Tensor, index: Optional[Tensor] = None,
ptr: Optional[Tensor] = None, dim_size: Optional[int] = None,
dim: int = -2) -> Tensor:
return self.reduce(x, index, ptr, dim_size, dim, reduce='min')


class MulAggregation(Aggregation):
def forward(self, x: Tensor, index: Optional[Tensor] = None, *,
def forward(self, x: Tensor, index: Optional[Tensor] = None,
ptr: Optional[Tensor] = None, dim_size: Optional[int] = None,
dim: int = -2) -> Tensor:
# TODO Currently, `mul` reduction can only operate on `index`:
Expand All @@ -49,21 +46,19 @@ def forward(self, x: Tensor, index: Optional[Tensor] = None, *,


class VarAggregation(Aggregation):
def forward(self, x: Tensor, index: Optional[Tensor] = None, *,
def forward(self, x: Tensor, index: Optional[Tensor] = None,
ptr: Optional[Tensor] = None, dim_size: Optional[int] = None,
dim: int = -2) -> Tensor:

mean = self.reduce(x, index, ptr, dim_size, dim, reduce='mean')
mean_2 = self.reduce(x * x, index, ptr, dim_size, dim, reduce='mean')
return mean_2 - mean * mean


class StdAggregation(VarAggregation):
def forward(self, x: Tensor, index: Optional[Tensor] = None, *,
def forward(self, x: Tensor, index: Optional[Tensor] = None,
ptr: Optional[Tensor] = None, dim_size: Optional[int] = None,
dim: int = -2) -> Tensor:

var = super().forward(x, index, ptr=ptr, dim_size=dim_size, dim=dim)
var = super().forward(x, index, ptr, dim_size, dim)
return torch.sqrt(var.relu() + 1e-5)


Expand All @@ -80,7 +75,7 @@ def reset_parameters(self):
if isinstance(self.t, Tensor):
self.t.data.fill_(self._init_t)

def forward(self, x: Tensor, index: Optional[Tensor] = None, *,
def forward(self, x: Tensor, index: Optional[Tensor] = None,
ptr: Optional[Tensor] = None, dim_size: Optional[int] = None,
dim: int = -2) -> Tensor:

Expand All @@ -107,7 +102,7 @@ def reset_parameters(self):
if isinstance(self.p, Tensor):
self.p.data.fill_(self._init_p)

def forward(self, x: Tensor, index: Optional[Tensor] = None, *,
def forward(self, x: Tensor, index: Optional[Tensor] = None,
ptr: Optional[Tensor] = None, dim_size: Optional[int] = None,
dim: int = -2) -> Tensor:

Expand Down
2 changes: 1 addition & 1 deletion torch_geometric/nn/aggr/lstm.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def __init__(self, in_channels: int, out_channels: int, **kwargs):
def reset_parameters(self):
self.lstm.reset_parameters()

def forward(self, x: Tensor, index: Optional[Tensor] = None, *,
def forward(self, x: Tensor, index: Optional[Tensor] = None,
ptr: Optional[Tensor] = None, dim_size: Optional[int] = None,
dim: int = -2) -> Tensor:
x, _ = self.to_dense_batch(x, index, ptr, dim_size, dim)
Expand Down
25 changes: 19 additions & 6 deletions torch_geometric/nn/aggr/multi.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import List, Optional, Union
from typing import Any, Dict, List, Optional, Union

import torch
from torch import Tensor
Expand All @@ -8,7 +8,8 @@


class MultiAggregation(Aggregation):
def __init__(self, aggrs: List[Union[Aggregation, str]]):
def __init__(self, aggrs: List[Union[Aggregation, str]],
aggrs_kwargs: Optional[List[Dict[str, Any]]] = None):
super().__init__()

if not isinstance(aggrs, (list, tuple)):
Expand All @@ -19,14 +20,26 @@ def __init__(self, aggrs: List[Union[Aggregation, str]]):
raise ValueError(f"'aggrs' of '{self.__class__.__name__}' should "
f"not be empty")

self.aggrs = [aggregation_resolver(aggr) for aggr in aggrs]

def forward(self, x: Tensor, index: Optional[Tensor] = None, *,
if aggrs_kwargs is None:
aggrs_kwargs = [{}] * len(aggrs)
elif len(aggrs) != len(aggrs_kwargs):
raise ValueError(f"'aggrs_kwargs' with invalid length passed to "
f"'{self.__class__.__name__}' "
f"(got '{len(aggrs_kwargs)}', "
f"expected '{len(aggrs)}'). Ensure that both "
f"'aggrs' and 'aggrs_kwargs' are consistent")

self.aggrs = torch.nn.ModuleList([
aggregation_resolver(aggr, **aggr_kwargs)
for aggr, aggr_kwargs in zip(aggrs, aggrs_kwargs)
])

def forward(self, x: Tensor, index: Optional[Tensor] = None,
ptr: Optional[Tensor] = None, dim_size: Optional[int] = None,
dim: int = -2) -> Tensor:
outs = []
for aggr in self.aggrs:
outs.append(aggr(x, index, ptr=ptr, dim_size=dim_size, dim=dim))
outs.append(aggr(x, index, ptr, dim_size, dim))
return torch.cat(outs, dim=-1) if len(outs) > 1 else outs[0]

def __repr__(self) -> str:
Expand Down
2 changes: 1 addition & 1 deletion torch_geometric/nn/aggr/set2set.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def __init__(self, in_channels: int, processing_steps: int, **kwargs):
def reset_parameters(self):
self.lstm.reset_parameters()

def forward(self, x: Tensor, index: Optional[Tensor] = None, *,
def forward(self, x: Tensor, index: Optional[Tensor] = None,
ptr: Optional[Tensor] = None, dim_size: Optional[int] = None,
dim: int = -2) -> Tensor:

Expand Down
9 changes: 1 addition & 8 deletions torch_geometric/nn/conv/message_passing.jinja
Original file line number Diff line number Diff line change
Expand Up @@ -234,22 +234,15 @@ class {{cls_name}}({{parent_cls_name}}):
the_size = self.__check_input__(edge_index, size)
in_kwargs = Propagate_{{uid}}({% for k in prop_types.keys() %}{{k}}={{k}}{{ ", " if not loop.last }}{% endfor %})

{% if fuse and single_aggr %}
{% if fuse %}
if isinstance(edge_index, SparseTensor):
out = self.message_and_aggregate(edge_index{% for k in msg_and_aggr_args %}, {{k}}=in_kwargs.{{k}}{% endfor %})
return self.update(out{% for k in update_args %}, {{k}}=in_kwargs.{{k}}{% endfor %})
{% endif %}

kwargs = self.__collect__(edge_index, the_size, in_kwargs)
out = self.message({% for k in msg_args %}{{k}}=kwargs.{{k}}{{ ", " if not loop.last }}{% endfor %})
{% if single_aggr %}
out = self.aggregate(out{% for k in aggr_args %}, {{k}}=kwargs.{{k}}{% endfor %})
{% else %}
outs: List[Tensor] = []
for aggr in self.aggrs:
outs.append(self.aggregate(out{% for k in aggr_args %}, {{k}}=kwargs.{{k}}{% endfor %}, aggr=aggr))
out = self.combine(outs)
{% endif %}
return self.update(out{% for k in update_args %}, {{k}}=kwargs.{{k}}{% endfor %})

{% if edge_updater_types|length > 0 %}
Expand Down
Loading

0 comments on commit d700ddb

Please sign in to comment.