diff --git a/CHANGELOG.md b/CHANGELOG.md index 2585289b0cb7..76878c5c1d29 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -39,7 +39,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added the `bias` vector to the `GCN` model definition in the "Create Message Passing Networks" tutorial ([#4755](https://github.com/pyg-team/pytorch_geometric/pull/4755)) - Added `transforms.RootedSubgraph` interface with two implementations: `RootedEgoNets` and `RootedRWSubgraph` ([#3926](https://github.com/pyg-team/pytorch_geometric/pull/3926)) - Added `ptr` vectors for `follow_batch` attributes within `Batch.from_data_list` ([#4723](https://github.com/pyg-team/pytorch_geometric/pull/4723)) -- Added `torch_geometric.nn.aggr` package ([#4687](https://github.com/pyg-team/pytorch_geometric/pull/4687), [#4721](https://github.com/pyg-team/pytorch_geometric/pull/4721), [#4731](https://github.com/pyg-team/pytorch_geometric/pull/4731), [#4762](https://github.com/pyg-team/pytorch_geometric/pull/4762), [#4749](https://github.com/pyg-team/pytorch_geometric/pull/4749), [#4779](https://github.com/pyg-team/pytorch_geometric/pull/4779), [#4863](https://github.com/pyg-team/pytorch_geometric/pull/4863), [#4864](https://github.com/pyg-team/pytorch_geometric/pull/4864), [#4865](https://github.com/pyg-team/pytorch_geometric/pull/4865), [#4866](https://github.com/pyg-team/pytorch_geometric/pull/4866), [#4872](https://github.com/pyg-team/pytorch_geometric/pull/4872), [#4934](https://github.com/pyg-team/pytorch_geometric/pull/4934), [#4935](https://github.com/pyg-team/pytorch_geometric/pull/4935), [#4957](https://github.com/pyg-team/pytorch_geometric/pull/4957), [#4973](https://github.com/pyg-team/pytorch_geometric/pull/4973), [#4973](https://github.com/pyg-team/pytorch_geometric/pull/4973), [#4986](https://github.com/pyg-team/pytorch_geometric/pull/4986), [#4995](https://github.com/pyg-team/pytorch_geometric/pull/4995), 
[#5000](https://github.com/pyg-team/pytorch_geometric/pull/5000), [#5034](https://github.com/pyg-team/pytorch_geometric/pull/5034), [#5036](https://github.com/pyg-team/pytorch_geometric/pull/5036), [#5039](https://github.com/pyg-team/pytorch_geometric/issues/5039), [#4522](https://github.com/pyg-team/pytorch_geometric/pull/4522)) +- Added `torch_geometric.nn.aggr` package ([#4687](https://github.com/pyg-team/pytorch_geometric/pull/4687), [#4721](https://github.com/pyg-team/pytorch_geometric/pull/4721), [#4731](https://github.com/pyg-team/pytorch_geometric/pull/4731), [#4762](https://github.com/pyg-team/pytorch_geometric/pull/4762), [#4749](https://github.com/pyg-team/pytorch_geometric/pull/4749), [#4779](https://github.com/pyg-team/pytorch_geometric/pull/4779), [#4863](https://github.com/pyg-team/pytorch_geometric/pull/4863), [#4864](https://github.com/pyg-team/pytorch_geometric/pull/4864), [#4865](https://github.com/pyg-team/pytorch_geometric/pull/4865), [#4866](https://github.com/pyg-team/pytorch_geometric/pull/4866), [#4872](https://github.com/pyg-team/pytorch_geometric/pull/4872), [#4934](https://github.com/pyg-team/pytorch_geometric/pull/4934), [#4935](https://github.com/pyg-team/pytorch_geometric/pull/4935), [#4957](https://github.com/pyg-team/pytorch_geometric/pull/4957), [#4973](https://github.com/pyg-team/pytorch_geometric/pull/4973), [#4973](https://github.com/pyg-team/pytorch_geometric/pull/4973), [#4986](https://github.com/pyg-team/pytorch_geometric/pull/4986), [#4995](https://github.com/pyg-team/pytorch_geometric/pull/4995), [#5000](https://github.com/pyg-team/pytorch_geometric/pull/5000), [#5034](https://github.com/pyg-team/pytorch_geometric/pull/5034), [#5036](https://github.com/pyg-team/pytorch_geometric/pull/5036), [#5039](https://github.com/pyg-team/pytorch_geometric/issues/5039), [#4522](https://github.com/pyg-team/pytorch_geometric/pull/4522), [#5033](https://github.com/pyg-team/pytorch_geometric/pull/5033)) - Added the `DimeNet++` model 
([#4432](https://github.com/pyg-team/pytorch_geometric/pull/4432), [#4699](https://github.com/pyg-team/pytorch_geometric/pull/4699), [#4700](https://github.com/pyg-team/pytorch_geometric/pull/4700), [#4800](https://github.com/pyg-team/pytorch_geometric/pull/4800)) - Added an example of using PyG with PyTorch Ignite ([#4487](https://github.com/pyg-team/pytorch_geometric/pull/4487)) - Added `GroupAddRev` module with support for reducing training GPU memory ([#4671](https://github.com/pyg-team/pytorch_geometric/pull/4671), [#4701](https://github.com/pyg-team/pytorch_geometric/pull/4701), [#4715](https://github.com/pyg-team/pytorch_geometric/pull/4715), [#4730](https://github.com/pyg-team/pytorch_geometric/pull/4730)) diff --git a/test/nn/conv/test_sage_conv.py b/test/nn/conv/test_sage_conv.py index 58afa7ae53bd..c54838350678 100644 --- a/test/nn/conv/test_sage_conv.py +++ b/test/nn/conv/test_sage_conv.py @@ -56,7 +56,7 @@ def test_sage_conv(project): assert jit((x1, None), adj.t()).tolist() == out2.tolist() -def test_lstm_sage_conv(): +def test_lstm_aggr_sage_conv(): x = torch.randn(4, 8) edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]]) row, col = edge_index @@ -71,3 +71,23 @@ def test_lstm_sage_conv(): edge_index = torch.tensor([[0, 1, 2, 3], [1, 0, 1, 0]]) with pytest.raises(ValueError, match="'index' tensor is not sorted"): conv(x, edge_index) + + +@pytest.mark.parametrize('aggr_kwargs', [ + dict(mode='cat'), + dict(mode='proj', mode_kwargs=dict(in_channels=8, out_channels=16)), + dict(mode='attn', mode_kwargs=dict(in_channels=8, out_channels=16, + num_heads=4)), + dict(mode='sum'), +]) +def test_multi_aggr_sage_conv(aggr_kwargs): + x = torch.randn(4, 8) + edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]]) + row, col = edge_index + adj = SparseTensor(row=row, col=col, sparse_sizes=(4, 4)) + aggr_kwargs['aggrs_kwargs'] = [{}, {}, {}, dict(learn=True, t=1)] + conv = SAGEConv(8, 32, aggr=['mean', 'max', 'sum', 'softmax'], + aggr_kwargs=aggr_kwargs) + out = 
conv(x, edge_index) + assert out.size() == (4, 32) + assert torch.allclose(conv(x, adj.t()), out) diff --git a/torch_geometric/nn/aggr/multi.py b/torch_geometric/nn/aggr/multi.py index c601b2300b1c..6a3a67912c15 100644 --- a/torch_geometric/nn/aggr/multi.py +++ b/torch_geometric/nn/aggr/multi.py @@ -65,34 +65,36 @@ def __init__( self.mode = mode mode_kwargs = mode_kwargs or {} + self.in_channels = mode_kwargs.pop('in_channels', None) + self.out_channels = mode_kwargs.pop('out_channels', None) if mode == 'proj' or mode == 'attn': if len(aggrs) == 1: raise ValueError("Multiple aggregations are required for " "'proj' or 'attn' combine mode.") - in_channels = mode_kwargs.pop('in_channels', None) - out_channels = mode_kwargs.pop('out_channels', None) - if (in_channels and out_channels) is None: + + if (self.in_channels and self.out_channels) is None: raise ValueError( f"Combine mode '{mode}' must have `in_channels` " f"and `out_channels` specified.") - if isinstance(in_channels, int): - in_channels = (in_channels, ) * len(aggrs) + if isinstance(self.in_channels, int): + self.in_channels = (self.in_channels, ) * len(aggrs) if mode == 'proj': self.lin = Linear( - sum(in_channels), - out_channels, + sum(self.in_channels), + self.out_channels, **mode_kwargs, ) if mode == 'attn': self.lin_heads = torch.nn.ModuleList([ - Linear(channels, out_channels) for channels in in_channels + Linear(channels, self.out_channels) + for channels in self.in_channels ]) num_heads = mode_kwargs.pop('num_heads', 1) self.multihead_attn = MultiheadAttention( - out_channels, + self.out_channels, num_heads, **mode_kwargs, ) @@ -114,6 +116,14 @@ def reset_parameters(self): if hasattr(self, 'multihead_attn'): self.multihead_attn._reset_parameters() + def get_out_channels(self, in_channels: int) -> int: + if self.out_channels is not None: + return self.out_channels + # TODO: Support having customized `out_channels` in each aggregation + if self.mode == 'cat': + return in_channels * len(self.aggrs) + 
return in_channels + def forward(self, x: Tensor, index: Optional[Tensor] = None, ptr: Optional[Tensor] = None, dim_size: Optional[int] = None, dim: int = -2) -> Tensor: diff --git a/torch_geometric/nn/conv/sage_conv.py b/torch_geometric/nn/conv/sage_conv.py index d9515cc2b9e0..8252399ff65c 100644 --- a/torch_geometric/nn/conv/sage_conv.py +++ b/torch_geometric/nn/conv/sage_conv.py @@ -1,10 +1,11 @@ -from typing import Tuple, Union +from typing import List, Optional, Tuple, Union import torch.nn.functional as F from torch import Tensor from torch.nn import LSTM from torch_sparse import SparseTensor, matmul +from torch_geometric.nn.aggr import Aggregation, MultiAggregation from torch_geometric.nn.conv import MessagePassing from torch_geometric.nn.dense.linear import Linear from torch_geometric.typing import Adj, OptPairTensor, Size @@ -66,7 +67,7 @@ def __init__( self, in_channels: Union[int, Tuple[int, int]], out_channels: int, - aggr: str = 'mean', + aggr: Optional[Union[str, List[str], Aggregation]] = "mean", normalize: bool = False, root_weight: bool = True, project: bool = False, @@ -83,8 +84,9 @@ def __init__( in_channels = (in_channels, in_channels) if aggr == 'lstm': - kwargs['aggr_kwargs'] = dict(in_channels=in_channels[0], - out_channels=in_channels[0]) + kwargs.setdefault('aggr_kwargs', {}) + kwargs['aggr_kwargs'].setdefault('in_channels', in_channels[0]) + kwargs['aggr_kwargs'].setdefault('out_channels', in_channels[0]) super().__init__(aggr, **kwargs) @@ -95,7 +97,13 @@ def __init__( self.fuse = False # No "fused" message_and_aggregate. 
self.lstm = LSTM(in_channels[0], in_channels[0], batch_first=True) - self.lin_l = Linear(in_channels[0], out_channels, bias=bias) + if isinstance(self.aggr_module, MultiAggregation): + aggr_out_channels = self.aggr_module.get_out_channels( + in_channels[0]) + else: + aggr_out_channels = in_channels[0] + + self.lin_l = Linear(aggr_out_channels, out_channels, bias=bias) if self.root_weight: self.lin_r = Linear(in_channels[1], out_channels, bias=False)