Support for multiple aggregations in SAGEConv (#5033)
* Support auto-infer channels for multiple aggregations in SAGEConv

* Avoid overriding aggr_kwargs for lstm

* changelog

* update

* Add get_out_channels to multi aggr

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Add TODO

* update

Co-authored-by: Matthias Fey <matthias.fey@tu-dortmund.de>

* update

Co-authored-by: Guohao Li <lighaime@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Matthias Fey <matthias.fey@tu-dortmund.de>
4 people authored Jul 27, 2022
1 parent 333d3d3 commit 23a7be0
Showing 4 changed files with 54 additions and 16 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md

```diff
@@ -39,7 +39,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added the `bias` vector to the `GCN` model definition in the "Create Message Passing Networks" tutorial ([#4755](https://github.com/pyg-team/pytorch_geometric/pull/4755))
 - Added `transforms.RootedSubgraph` interface with two implementations: `RootedEgoNets` and `RootedRWSubgraph` ([#3926](https://github.com/pyg-team/pytorch_geometric/pull/3926))
 - Added `ptr` vectors for `follow_batch` attributes within `Batch.from_data_list` ([#4723](https://github.com/pyg-team/pytorch_geometric/pull/4723))
-- Added `torch_geometric.nn.aggr` package ([#4687](https://github.com/pyg-team/pytorch_geometric/pull/4687), [#4721](https://github.com/pyg-team/pytorch_geometric/pull/4721), [#4731](https://github.com/pyg-team/pytorch_geometric/pull/4731), [#4762](https://github.com/pyg-team/pytorch_geometric/pull/4762), [#4749](https://github.com/pyg-team/pytorch_geometric/pull/4749), [#4779](https://github.com/pyg-team/pytorch_geometric/pull/4779), [#4863](https://github.com/pyg-team/pytorch_geometric/pull/4863), [#4864](https://github.com/pyg-team/pytorch_geometric/pull/4864), [#4865](https://github.com/pyg-team/pytorch_geometric/pull/4865), [#4866](https://github.com/pyg-team/pytorch_geometric/pull/4866), [#4872](https://github.com/pyg-team/pytorch_geometric/pull/4872), [#4934](https://github.com/pyg-team/pytorch_geometric/pull/4934), [#4935](https://github.com/pyg-team/pytorch_geometric/pull/4935), [#4957](https://github.com/pyg-team/pytorch_geometric/pull/4957), [#4973](https://github.com/pyg-team/pytorch_geometric/pull/4973), [#4973](https://github.com/pyg-team/pytorch_geometric/pull/4973), [#4986](https://github.com/pyg-team/pytorch_geometric/pull/4986), [#4995](https://github.com/pyg-team/pytorch_geometric/pull/4995), [#5000](https://github.com/pyg-team/pytorch_geometric/pull/5000), [#5034](https://github.com/pyg-team/pytorch_geometric/pull/5034), [#5036](https://github.com/pyg-team/pytorch_geometric/pull/5036), [#5039](https://github.com/pyg-team/pytorch_geometric/issues/5039), [#4522](https://github.com/pyg-team/pytorch_geometric/pull/4522))
+- Added `torch_geometric.nn.aggr` package ([#4687](https://github.com/pyg-team/pytorch_geometric/pull/4687), [#4721](https://github.com/pyg-team/pytorch_geometric/pull/4721), [#4731](https://github.com/pyg-team/pytorch_geometric/pull/4731), [#4762](https://github.com/pyg-team/pytorch_geometric/pull/4762), [#4749](https://github.com/pyg-team/pytorch_geometric/pull/4749), [#4779](https://github.com/pyg-team/pytorch_geometric/pull/4779), [#4863](https://github.com/pyg-team/pytorch_geometric/pull/4863), [#4864](https://github.com/pyg-team/pytorch_geometric/pull/4864), [#4865](https://github.com/pyg-team/pytorch_geometric/pull/4865), [#4866](https://github.com/pyg-team/pytorch_geometric/pull/4866), [#4872](https://github.com/pyg-team/pytorch_geometric/pull/4872), [#4934](https://github.com/pyg-team/pytorch_geometric/pull/4934), [#4935](https://github.com/pyg-team/pytorch_geometric/pull/4935), [#4957](https://github.com/pyg-team/pytorch_geometric/pull/4957), [#4973](https://github.com/pyg-team/pytorch_geometric/pull/4973), [#4973](https://github.com/pyg-team/pytorch_geometric/pull/4973), [#4986](https://github.com/pyg-team/pytorch_geometric/pull/4986), [#4995](https://github.com/pyg-team/pytorch_geometric/pull/4995), [#5000](https://github.com/pyg-team/pytorch_geometric/pull/5000), [#5034](https://github.com/pyg-team/pytorch_geometric/pull/5034), [#5036](https://github.com/pyg-team/pytorch_geometric/pull/5036), [#5039](https://github.com/pyg-team/pytorch_geometric/issues/5039), [#4522](https://github.com/pyg-team/pytorch_geometric/pull/4522), [#5033](https://github.com/pyg-team/pytorch_geometric/pull/5033))
 - Added the `DimeNet++` model ([#4432](https://github.com/pyg-team/pytorch_geometric/pull/4432), [#4699](https://github.com/pyg-team/pytorch_geometric/pull/4699), [#4700](https://github.com/pyg-team/pytorch_geometric/pull/4700), [#4800](https://github.com/pyg-team/pytorch_geometric/pull/4800))
 - Added an example of using PyG with PyTorch Ignite ([#4487](https://github.com/pyg-team/pytorch_geometric/pull/4487))
 - Added `GroupAddRev` module with support for reducing training GPU memory ([#4671](https://github.com/pyg-team/pytorch_geometric/pull/4671), [#4701](https://github.com/pyg-team/pytorch_geometric/pull/4701), [#4715](https://github.com/pyg-team/pytorch_geometric/pull/4715), [#4730](https://github.com/pyg-team/pytorch_geometric/pull/4730))
```
22 changes: 21 additions & 1 deletion test/nn/conv/test_sage_conv.py

```diff
@@ -56,7 +56,7 @@ def test_sage_conv(project):
     assert jit((x1, None), adj.t()).tolist() == out2.tolist()


-def test_lstm_sage_conv():
+def test_lstm_aggr_sage_conv():
     x = torch.randn(4, 8)
     edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])
     row, col = edge_index
@@ -71,3 +71,23 @@ def test_lstm_aggr_sage_conv():
     edge_index = torch.tensor([[0, 1, 2, 3], [1, 0, 1, 0]])
     with pytest.raises(ValueError, match="'index' tensor is not sorted"):
         conv(x, edge_index)
+
+
+@pytest.mark.parametrize('aggr_kwargs', [
+    dict(mode='cat'),
+    dict(mode='proj', mode_kwargs=dict(in_channels=8, out_channels=16)),
+    dict(mode='attn', mode_kwargs=dict(in_channels=8, out_channels=16,
+                                       num_heads=4)),
+    dict(mode='sum'),
+])
+def test_multi_aggr_sage_conv(aggr_kwargs):
+    x = torch.randn(4, 8)
+    edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])
+    row, col = edge_index
+    adj = SparseTensor(row=row, col=col, sparse_sizes=(4, 4))
+    aggr_kwargs['aggrs_kwargs'] = [{}, {}, {}, dict(learn=True, t=1)]
+    conv = SAGEConv(8, 32, aggr=['mean', 'max', 'sum', 'softmax'],
+                    aggr_kwargs=aggr_kwargs)
+    out = conv(x, edge_index)
+    assert out.size() == (4, 32)
+    assert torch.allclose(conv(x, adj.t()), out)
```
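The new test above combines four aggregators ('mean', 'max', 'sum', 'softmax') and still expects 32 output channels. As a toy illustration of what multi-aggregation does to one node's neighborhood, here is a plain-Python sketch (the helper name `multi_aggregate` is hypothetical; the real logic lives in PyG's `MultiAggregation`):

```python
def multi_aggregate(neighbors, aggrs, mode='cat'):
    """Toy sketch: apply several aggregators to one node's neighbor
    features (a list of floats) and combine the results."""
    fns = {
        'mean': lambda xs: sum(xs) / len(xs),
        'max': max,
        'min': min,
        'sum': sum,
    }
    results = [fns[a](neighbors) for a in aggrs]
    if mode == 'cat':   # concatenate: output width grows with len(aggrs)
        return results
    if mode == 'sum':   # reduce: output width stays the same
        return [sum(results)]
    raise ValueError(f"unknown combine mode: {mode}")


out = multi_aggregate([1.0, 2.0, 3.0], ['mean', 'max', 'sum'])
print(out)  # [2.0, 3.0, 6.0]
```

With `mode='cat'` the result width is `len(aggrs)` times the input width, which is why the downstream linear layer in `SAGEConv` must be sized accordingly.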
28 changes: 19 additions & 9 deletions torch_geometric/nn/aggr/multi.py

```diff
@@ -65,34 +65,36 @@ def __init__(

         self.mode = mode
         mode_kwargs = mode_kwargs or {}
+        self.in_channels = mode_kwargs.pop('in_channels', None)
+        self.out_channels = mode_kwargs.pop('out_channels', None)
         if mode == 'proj' or mode == 'attn':
             if len(aggrs) == 1:
                 raise ValueError("Multiple aggregations are required for "
                                  "'proj' or 'attn' combine mode.")
-            in_channels = mode_kwargs.pop('in_channels', None)
-            out_channels = mode_kwargs.pop('out_channels', None)
-            if (in_channels and out_channels) is None:
+
+            if (self.in_channels and self.out_channels) is None:
                 raise ValueError(
                     f"Combine mode '{mode}' must have `in_channels` "
                     f"and `out_channels` specified.")

-            if isinstance(in_channels, int):
-                in_channels = (in_channels, ) * len(aggrs)
+            if isinstance(self.in_channels, int):
+                self.in_channels = (self.in_channels, ) * len(aggrs)

             if mode == 'proj':
                 self.lin = Linear(
-                    sum(in_channels),
-                    out_channels,
+                    sum(self.in_channels),
+                    self.out_channels,
                     **mode_kwargs,
                 )

             if mode == 'attn':
                 self.lin_heads = torch.nn.ModuleList([
-                    Linear(channels, out_channels) for channels in in_channels
+                    Linear(channels, self.out_channels)
+                    for channels in self.in_channels
                 ])
                 num_heads = mode_kwargs.pop('num_heads', 1)
                 self.multihead_attn = MultiheadAttention(
-                    out_channels,
+                    self.out_channels,
                     num_heads,
                     **mode_kwargs,
                 )
@@ -114,6 +116,14 @@ def reset_parameters(self):
         if hasattr(self, 'multihead_attn'):
             self.multihead_attn._reset_parameters()

+    def get_out_channels(self, in_channels: int) -> int:
+        if self.out_channels is not None:
+            return self.out_channels
+        # TODO: Support having customized `out_channels` in each aggregation
+        if self.mode == 'cat':
+            return in_channels * len(self.aggrs)
+        return in_channels
+
     def forward(self, x: Tensor, index: Optional[Tensor] = None,
                 ptr: Optional[Tensor] = None, dim_size: Optional[int] = None,
                 dim: int = -2) -> Tensor:
```
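The `get_out_channels` method added to `MultiAggregation` above is what lets `SAGEConv` auto-infer the input width of its `lin_l` layer. A minimal standalone sketch of that logic (hypothetical function name; the real method lives on `MultiAggregation` and reads `self.out_channels`, `self.mode`, and `self.aggrs`):

```python
def get_out_channels_sketch(in_channels, num_aggrs, mode, out_channels=None):
    """Mirror of MultiAggregation.get_out_channels: an explicit
    out_channels (set for 'proj'/'attn' combine modes) wins; 'cat'
    multiplies the width by the number of aggregators; every other
    combine mode preserves the input width."""
    if out_channels is not None:
        return out_channels
    if mode == 'cat':
        return in_channels * num_aggrs
    return in_channels


# Four 8-channel aggregations, as in the new test:
print(get_out_channels_sketch(8, 4, 'cat'))        # 32
print(get_out_channels_sketch(8, 4, 'sum'))        # 8
print(get_out_channels_sketch(8, 4, 'proj', 16))   # 16
```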
18 changes: 13 additions & 5 deletions torch_geometric/nn/conv/sage_conv.py

```diff
@@ -1,10 +1,11 @@
-from typing import Tuple, Union
+from typing import List, Optional, Tuple, Union

 import torch.nn.functional as F
 from torch import Tensor
 from torch.nn import LSTM
 from torch_sparse import SparseTensor, matmul

+from torch_geometric.nn.aggr import Aggregation, MultiAggregation
 from torch_geometric.nn.conv import MessagePassing
 from torch_geometric.nn.dense.linear import Linear
 from torch_geometric.typing import Adj, OptPairTensor, Size
@@ -66,7 +67,7 @@ def __init__(
         self,
         in_channels: Union[int, Tuple[int, int]],
         out_channels: int,
-        aggr: str = 'mean',
+        aggr: Optional[Union[str, List[str], Aggregation]] = "mean",
         normalize: bool = False,
         root_weight: bool = True,
         project: bool = False,
@@ -83,8 +84,9 @@ def __init__(
             in_channels = (in_channels, in_channels)

         if aggr == 'lstm':
-            kwargs['aggr_kwargs'] = dict(in_channels=in_channels[0],
-                                         out_channels=in_channels[0])
+            kwargs.setdefault('aggr_kwargs', {})
+            kwargs['aggr_kwargs'].setdefault('in_channels', in_channels[0])
+            kwargs['aggr_kwargs'].setdefault('out_channels', in_channels[0])

         super().__init__(aggr, **kwargs)

@@ -95,7 +97,13 @@ def __init__(
             self.fuse = False  # No "fused" message_and_aggregate.
             self.lstm = LSTM(in_channels[0], in_channels[0], batch_first=True)

-        self.lin_l = Linear(in_channels[0], out_channels, bias=bias)
+        if isinstance(self.aggr_module, MultiAggregation):
+            aggr_out_channels = self.aggr_module.get_out_channels(
+                in_channels[0])
+        else:
+            aggr_out_channels = in_channels[0]
+
+        self.lin_l = Linear(aggr_out_channels, out_channels, bias=bias)
         if self.root_weight:
             self.lin_r = Linear(in_channels[1], out_channels, bias=False)
```
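The `setdefault` change above implements the "Avoid overriding aggr_kwargs for lstm" commit: for `aggr='lstm'`, default channels are filled in only when the caller did not supply them, instead of clobbering user-provided values. A standalone sketch of the pattern (hypothetical function name, mirroring the in-place kwargs mutation in `SAGEConv.__init__`):

```python
def build_lstm_aggr_kwargs(in_channels, user_kwargs=None):
    """Fill in default LSTM aggregation channels without overriding
    anything the caller already set: setdefault only writes a key
    when it is missing."""
    kwargs = dict(user_kwargs or {})
    kwargs.setdefault('aggr_kwargs', {})
    kwargs['aggr_kwargs'].setdefault('in_channels', in_channels)
    kwargs['aggr_kwargs'].setdefault('out_channels', in_channels)
    return kwargs


# A user-supplied value survives; the missing key is filled in:
print(build_lstm_aggr_kwargs(8, {'aggr_kwargs': {'out_channels': 4}}))
# {'aggr_kwargs': {'out_channels': 4, 'in_channels': 8}}
```

The old code built a fresh `dict(...)` and assigned it unconditionally, so any `aggr_kwargs` passed by the user were silently discarded.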
