
Support for multiple aggregations in SAGEConv #5033

Merged · 13 commits merged on Jul 27, 2022
2 changes: 1 addition & 1 deletion CHANGELOG.md
@@ -39,7 +39,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added the `bias` vector to the `GCN` model definition in the "Create Message Passing Networks" tutorial ([#4755](https://github.com/pyg-team/pytorch_geometric/pull/4755))
- Added `transforms.RootedSubgraph` interface with two implementations: `RootedEgoNets` and `RootedRWSubgraph` ([#3926](https://github.com/pyg-team/pytorch_geometric/pull/3926))
- Added `ptr` vectors for `follow_batch` attributes within `Batch.from_data_list` ([#4723](https://github.com/pyg-team/pytorch_geometric/pull/4723))
- Added `torch_geometric.nn.aggr` package ([#4687](https://github.com/pyg-team/pytorch_geometric/pull/4687), [#4721](https://github.com/pyg-team/pytorch_geometric/pull/4721), [#4731](https://github.com/pyg-team/pytorch_geometric/pull/4731), [#4762](https://github.com/pyg-team/pytorch_geometric/pull/4762), [#4749](https://github.com/pyg-team/pytorch_geometric/pull/4749), [#4779](https://github.com/pyg-team/pytorch_geometric/pull/4779), [#4863](https://github.com/pyg-team/pytorch_geometric/pull/4863), [#4864](https://github.com/pyg-team/pytorch_geometric/pull/4864), [#4865](https://github.com/pyg-team/pytorch_geometric/pull/4865), [#4866](https://github.com/pyg-team/pytorch_geometric/pull/4866), [#4872](https://github.com/pyg-team/pytorch_geometric/pull/4872), [#4934](https://github.com/pyg-team/pytorch_geometric/pull/4934), [#4935](https://github.com/pyg-team/pytorch_geometric/pull/4935), [#4957](https://github.com/pyg-team/pytorch_geometric/pull/4957), [#4973](https://github.com/pyg-team/pytorch_geometric/pull/4973), [#4973](https://github.com/pyg-team/pytorch_geometric/pull/4973), [#4986](https://github.com/pyg-team/pytorch_geometric/pull/4986), [#4995](https://github.com/pyg-team/pytorch_geometric/pull/4995), [#5000](https://github.com/pyg-team/pytorch_geometric/pull/5000), [#5034](https://github.com/pyg-team/pytorch_geometric/pull/5034), [#5036](https://github.com/pyg-team/pytorch_geometric/pull/5036), [#5039](https://github.com/pyg-team/pytorch_geometric/issues/5039), [#4522](https://github.com/pyg-team/pytorch_geometric/pull/4522))
- Added `torch_geometric.nn.aggr` package ([#4687](https://github.com/pyg-team/pytorch_geometric/pull/4687), [#4721](https://github.com/pyg-team/pytorch_geometric/pull/4721), [#4731](https://github.com/pyg-team/pytorch_geometric/pull/4731), [#4762](https://github.com/pyg-team/pytorch_geometric/pull/4762), [#4749](https://github.com/pyg-team/pytorch_geometric/pull/4749), [#4779](https://github.com/pyg-team/pytorch_geometric/pull/4779), [#4863](https://github.com/pyg-team/pytorch_geometric/pull/4863), [#4864](https://github.com/pyg-team/pytorch_geometric/pull/4864), [#4865](https://github.com/pyg-team/pytorch_geometric/pull/4865), [#4866](https://github.com/pyg-team/pytorch_geometric/pull/4866), [#4872](https://github.com/pyg-team/pytorch_geometric/pull/4872), [#4934](https://github.com/pyg-team/pytorch_geometric/pull/4934), [#4935](https://github.com/pyg-team/pytorch_geometric/pull/4935), [#4957](https://github.com/pyg-team/pytorch_geometric/pull/4957), [#4973](https://github.com/pyg-team/pytorch_geometric/pull/4973), [#4973](https://github.com/pyg-team/pytorch_geometric/pull/4973), [#4986](https://github.com/pyg-team/pytorch_geometric/pull/4986), [#4995](https://github.com/pyg-team/pytorch_geometric/pull/4995), [#5000](https://github.com/pyg-team/pytorch_geometric/pull/5000), [#5034](https://github.com/pyg-team/pytorch_geometric/pull/5034), [#5036](https://github.com/pyg-team/pytorch_geometric/pull/5036), [#5039](https://github.com/pyg-team/pytorch_geometric/issues/5039), [#4522](https://github.com/pyg-team/pytorch_geometric/pull/4522), [#5033](https://github.com/pyg-team/pytorch_geometric/pull/5033))
- Added the `DimeNet++` model ([#4432](https://github.com/pyg-team/pytorch_geometric/pull/4432), [#4699](https://github.com/pyg-team/pytorch_geometric/pull/4699), [#4700](https://github.com/pyg-team/pytorch_geometric/pull/4700), [#4800](https://github.com/pyg-team/pytorch_geometric/pull/4800))
- Added an example of using PyG with PyTorch Ignite ([#4487](https://github.com/pyg-team/pytorch_geometric/pull/4487))
- Added `GroupAddRev` module with support for reducing training GPU memory ([#4671](https://github.com/pyg-team/pytorch_geometric/pull/4671), [#4701](https://github.com/pyg-team/pytorch_geometric/pull/4701), [#4715](https://github.com/pyg-team/pytorch_geometric/pull/4715), [#4730](https://github.com/pyg-team/pytorch_geometric/pull/4730))
22 changes: 21 additions & 1 deletion test/nn/conv/test_sage_conv.py
@@ -56,7 +56,7 @@ def test_sage_conv(project):
assert jit((x1, None), adj.t()).tolist() == out2.tolist()


def test_lstm_sage_conv():
def test_lstm_aggr_sage_conv():
x = torch.randn(4, 8)
edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])
row, col = edge_index
@@ -71,3 +71,23 @@ def test_lstm_sage_conv():
edge_index = torch.tensor([[0, 1, 2, 3], [1, 0, 1, 0]])
with pytest.raises(ValueError, match="'index' tensor is not sorted"):
conv(x, edge_index)


@pytest.mark.parametrize('aggr_kwargs', [
dict(mode='cat'),
dict(mode='proj', mode_kwargs=dict(in_channels=8, out_channels=16)),
dict(mode='attn', mode_kwargs=dict(in_channels=8, out_channels=16,
num_heads=4)),
dict(mode='sum'),
])
def test_multi_aggr_sage_conv(aggr_kwargs):
x = torch.randn(4, 8)
edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])
row, col = edge_index
adj = SparseTensor(row=row, col=col, sparse_sizes=(4, 4))
aggr_kwargs['aggrs_kwargs'] = [{}, {}, {}, dict(learn=True, t=1)]
conv = SAGEConv(8, 32, aggr=['mean', 'max', 'sum', 'softmax'],
aggr_kwargs=aggr_kwargs)
out = conv(x, edge_index)
assert out.size() == (4, 32)
assert torch.allclose(conv(x, adj.t()), out)
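The parametrized test above combines four aggregators and expects a width of 32 from 8 input channels under the `'cat'` mode. As a minimal, dependency-free sketch of what concatenated multi-aggregation computes (plain Python on scalar messages; the helper names `AGGRS` and `aggregate_multi` are illustrative, not PyG API):

```python
# Sketch of multi-aggregation with a 'cat'-style combine mode.
# Messages are grouped by their target node, each aggregator runs per
# group, and the per-aggregator results are concatenated.

AGGRS = {
    'mean': lambda xs: sum(xs) / len(xs),
    'max': max,
    'sum': sum,
}

def aggregate_multi(messages, index, num_nodes, aggrs):
    """Group scalar messages by target node, apply each aggregator,
    and concatenate the results per node."""
    groups = [[] for _ in range(num_nodes)]
    for msg, tgt in zip(messages, index):
        groups[tgt].append(msg)
    out = []
    for grp in groups:
        if not grp:
            out.append([0.0] * len(aggrs))  # empty neighborhood -> zeros
        else:
            out.append([AGGRS[a](grp) for a in aggrs])
    return out

# Mirrors the test's edge_index: sources [0, 1, 2, 3], targets [0, 0, 1, 1].
messages = [1.0, 2.0, 3.0, 4.0]
index = [0, 0, 1, 1]
result = aggregate_multi(messages, index, 4, ['mean', 'max', 'sum'])
print(result[0])  # [1.5, 2.0, 3.0]
print(result[1])  # [3.5, 4.0, 7.0]
```

Each node's output width is `len(aggrs)` times the message width, which is why the linear layer that follows must widen accordingly.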
28 changes: 19 additions & 9 deletions torch_geometric/nn/aggr/multi.py
@@ -65,34 +65,36 @@ def __init__(

self.mode = mode
mode_kwargs = mode_kwargs or {}
self.in_channels = mode_kwargs.pop('in_channels', None)
self.out_channels = mode_kwargs.pop('out_channels', None)
if mode == 'proj' or mode == 'attn':
if len(aggrs) == 1:
raise ValueError("Multiple aggregations are required for "
"'proj' or 'attn' combine mode.")
in_channels = mode_kwargs.pop('in_channels', None)
out_channels = mode_kwargs.pop('out_channels', None)
if (in_channels and out_channels) is None:

if (self.in_channels and self.out_channels) is None:
raise ValueError(
f"Combine mode '{mode}' must have `in_channels` "
f"and `out_channels` specified.")

if isinstance(in_channels, int):
in_channels = (in_channels, ) * len(aggrs)
if isinstance(self.in_channels, int):
self.in_channels = (self.in_channels, ) * len(aggrs)

if mode == 'proj':
self.lin = Linear(
sum(in_channels),
out_channels,
sum(self.in_channels),
self.out_channels,
**mode_kwargs,
)

if mode == 'attn':
self.lin_heads = torch.nn.ModuleList([
Linear(channels, out_channels) for channels in in_channels
Linear(channels, self.out_channels)
for channels in self.in_channels
])
num_heads = mode_kwargs.pop('num_heads', 1)
self.multihead_attn = MultiheadAttention(
out_channels,
self.out_channels,
num_heads,
**mode_kwargs,
)
@@ -114,6 +116,14 @@ def reset_parameters(self):
if hasattr(self, 'multihead_attn'):
self.multihead_attn._reset_parameters()

def get_out_channels(self, in_channels: int) -> int:
if self.out_channels is not None:
return self.out_channels
# TODO: Support having customized `out_channels` in each aggregation
if self.mode == 'cat':
return in_channels * len(self.aggrs)
return in_channels

def forward(self, x: Tensor, index: Optional[Tensor] = None,
ptr: Optional[Tensor] = None, dim_size: Optional[int] = None,
dim: int = -2) -> Tensor:
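The new `get_out_channels` helper lets callers ask a `MultiAggregation` how wide its output will be. As a hedged, standalone re-statement of that rule (an illustrative function, not the PyG class itself): an explicit `out_channels` from the `'proj'`/`'attn'` modes wins; `'cat'` concatenates, multiplying the width by the number of aggregators; element-wise modes keep the width unchanged.

```python
# Illustrative sketch of the output-width rule behind get_out_channels.

def multi_aggr_out_channels(mode, in_channels, num_aggrs, out_channels=None):
    """Width of the combined output of a MultiAggregation-style module.

    'proj'/'attn' project to an explicit out_channels; 'cat' concatenates
    the per-aggregator outputs; other modes ('sum', 'mean', ...) combine
    element-wise and keep the input width.
    """
    if out_channels is not None:  # set by the 'proj' / 'attn' modes
        return out_channels
    if mode == 'cat':
        return in_channels * num_aggrs
    return in_channels

# Four aggregators combined by concatenation: 8 -> 32 channels, matching
# the SAGEConv(8, 32, aggr=['mean', 'max', 'sum', 'softmax']) test above.
print(multi_aggr_out_channels('cat', 8, 4))                     # 32
print(multi_aggr_out_channels('sum', 8, 4))                     # 8
print(multi_aggr_out_channels('proj', 8, 4, out_channels=16))   # 16
```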
18 changes: 13 additions & 5 deletions torch_geometric/nn/conv/sage_conv.py
@@ -1,10 +1,11 @@
from typing import Tuple, Union
from typing import List, Optional, Tuple, Union

import torch.nn.functional as F
from torch import Tensor
from torch.nn import LSTM
from torch_sparse import SparseTensor, matmul

from torch_geometric.nn.aggr import Aggregation, MultiAggregation
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn.dense.linear import Linear
from torch_geometric.typing import Adj, OptPairTensor, Size
@@ -66,7 +67,7 @@ def __init__(
self,
in_channels: Union[int, Tuple[int, int]],
out_channels: int,
aggr: str = 'mean',
aggr: Optional[Union[str, List[str], Aggregation]] = "mean",
normalize: bool = False,
root_weight: bool = True,
project: bool = False,
@@ -83,8 +84,9 @@ def __init__(
in_channels = (in_channels, in_channels)

if aggr == 'lstm':
kwargs['aggr_kwargs'] = dict(in_channels=in_channels[0],
out_channels=in_channels[0])
kwargs.setdefault('aggr_kwargs', {})
kwargs['aggr_kwargs'].setdefault('in_channels', in_channels[0])
kwargs['aggr_kwargs'].setdefault('out_channels', in_channels[0])

super().__init__(aggr, **kwargs)

@@ -95,7 +97,13 @@ def __init__(
self.fuse = False # No "fused" message_and_aggregate.
self.lstm = LSTM(in_channels[0], in_channels[0], batch_first=True)

self.lin_l = Linear(in_channels[0], out_channels, bias=bias)
if isinstance(self.aggr_module, MultiAggregation):
aggr_out_channels = self.aggr_module.get_out_channels(
in_channels[0])
else:
aggr_out_channels = in_channels[0]

self.lin_l = Linear(aggr_out_channels, out_channels, bias=bias)
if self.root_weight:
self.lin_r = Linear(in_channels[1], out_channels, bias=False)

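The `'lstm'` branch above switches from overwriting `aggr_kwargs` wholesale to filling in defaults with `setdefault`, so user-supplied values survive. A small sketch of that pattern in isolation (the helper name `resolve_lstm_aggr_kwargs` is illustrative, not part of the PR):

```python
# Sketch of the setdefault pattern adopted for the 'lstm' aggregator:
# user-supplied aggr_kwargs win, and only missing keys fall back to
# in_channels-based defaults (the old code replaced the whole dict).

def resolve_lstm_aggr_kwargs(kwargs, in_channels):
    kwargs.setdefault('aggr_kwargs', {})
    kwargs['aggr_kwargs'].setdefault('in_channels', in_channels)
    kwargs['aggr_kwargs'].setdefault('out_channels', in_channels)
    return kwargs

# No user overrides: both channels default to in_channels.
print(resolve_lstm_aggr_kwargs({}, 8)['aggr_kwargs'])
# {'in_channels': 8, 'out_channels': 8}

# A user-supplied out_channels is preserved.
print(resolve_lstm_aggr_kwargs({'aggr_kwargs': {'out_channels': 16}}, 8)['aggr_kwargs'])
```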