Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

nn.aggr.Set2Set #4762

Merged
merged 9 commits into from
Jun 5, 2022
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added the `bias` vector to the `GCN` model definition in the "Create Message Passing Networks" tutorial ([#4755](https://github.com/pyg-team/pytorch_geometric/pull/4755))
- Added `transforms.RootedSubgraph` interface with two implementations: `RootedEgoNets` and `RootedRWSubgraph` ([#3926](https://github.com/pyg-team/pytorch_geometric/pull/3926))
- Added `ptr` vectors for `follow_batch` attributes within `Batch.from_data_list` ([#4723](https://github.com/pyg-team/pytorch_geometric/pull/4723))
- Added `torch_geometric.nn.aggr` package ([#4687](https://github.com/pyg-team/pytorch_geometric/pull/4687), [#4721](https://github.com/pyg-team/pytorch_geometric/pull/4721), [#4731](https://github.com/pyg-team/pytorch_geometric/pull/4731))
- Added `torch_geometric.nn.aggr` package ([#4687](https://github.com/pyg-team/pytorch_geometric/pull/4687), [#4721](https://github.com/pyg-team/pytorch_geometric/pull/4721), [#4731](https://github.com/pyg-team/pytorch_geometric/pull/4731), [#4762](https://github.com/pyg-team/pytorch_geometric/pull/4762))
- Added the `DimeNet++` model ([#4432](https://github.com/pyg-team/pytorch_geometric/pull/4432), [#4699](https://github.com/pyg-team/pytorch_geometric/pull/4699), [#4700](https://github.com/pyg-team/pytorch_geometric/pull/4700))
- Added an example of using PyG with PyTorch Ignite ([#4487](https://github.com/pyg-team/pytorch_geometric/pull/4487))
- Added `GroupAddRev` module with support for reducing training GPU memory ([#4671](https://github.com/pyg-team/pytorch_geometric/pull/4671), [#4701](https://github.com/pyg-team/pytorch_geometric/pull/4701), [#4715](https://github.com/pyg-team/pytorch_geometric/pull/4715), [#4730](https://github.com/pyg-team/pytorch_geometric/pull/4730))
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,7 @@ It is commonly applied to graph-level tasks, which require combining node featur
<summary><b>Expand to see all implemented pooling layers...</b></summary>

* **[GlobalAttention](https://pytorch-geometric.readthedocs.io/en/latest/modules/nn.html#torch_geometric.nn.glob.GlobalAttention)** from Li *et al.*: [Gated Graph Sequence Neural Networks](https://arxiv.org/abs/1511.05493) (ICLR 2016) [[**Example**](https://github.com/pyg-team/pytorch_geometric/blob/master/benchmark/kernel/global_attention.py)]
* **[Set2Set](https://pytorch-geometric.readthedocs.io/en/latest/modules/nn.html#torch_geometric.nn.glob.Set2Set)** from Vinyals *et al.*: [Order Matters: Sequence to Sequence for Sets](https://arxiv.org/abs/1511.06391) (ICLR 2016) [[**Example**](https://github.com/pyg-team/pytorch_geometric/blob/master/benchmark/kernel/set2set.py)]
* **[Set2Set](https://pytorch-geometric.readthedocs.io/en/latest/modules/nn.html#torch_geometric.nn.aggr.Set2Set)** from Vinyals *et al.*: [Order Matters: Sequence to Sequence for Sets](https://arxiv.org/abs/1511.06391) (ICLR 2016) [[**Example**](https://github.com/pyg-team/pytorch_geometric/blob/master/benchmark/kernel/set2set.py)]
* **[Sort Pool](https://pytorch-geometric.readthedocs.io/en/latest/modules/nn.html#torch_geometric.nn.glob.global_sort_pool)** from Zhang *et al.*: [An End-to-End Deep Learning Architecture for Graph Classification](https://www.cse.wustl.edu/~muhan/papers/AAAI_2018_DGCNN.pdf) (AAAI 2018) [[**Example**](https://github.com/pyg-team/pytorch_geometric/blob/master/benchmark/kernel/sort_pool.py)]
* **[MinCUT Pooling](https://pytorch-geometric.readthedocs.io/en/latest/modules/nn.html#torch_geometric.nn.dense.mincut_pool.dense_mincut_pool)** from Bianchi *et al.*: [MinCUT Pooling in Graph Neural Networks](https://arxiv.org/abs/1907.00481) (CoRR 2019) [[**Example**](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/proteins_mincut_pool.py)]
* **[DMoN Pooling](https://pytorch-geometric.readthedocs.io/en/latest/modules/nn.html#torch_geometric.nn.dense.dmon_pool.DMoNPooling)** from Tsitsulin *et al.*: [Graph Clustering with Graph Neural Networks](https://arxiv.org/abs/2006.16904) (CoRR 2020) [[**Example**](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/proteins_dmon_pool.py)]
Expand Down
3 changes: 0 additions & 3 deletions test/nn/aggr/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,6 @@ def test_validate():

aggr = MeanAggregation()

with pytest.raises(ValueError, match="either 'index' or 'ptr'"):
aggr(x)

with pytest.raises(ValueError, match="invalid dimension"):
aggr(x, index, dim=-3)

Expand Down
2 changes: 0 additions & 2 deletions test/nn/aggr/test_lstm.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,6 @@ def test_lstm_aggregation():
aggr = LSTMAggregation(16, 32)
assert str(aggr) == 'LSTMAggregation(16, 32)'

aggr.reset_parameters()

with pytest.raises(ValueError, match="is not sorted"):
aggr(x, torch.tensor([0, 1, 0, 1, 2, 1]))

Expand Down
File renamed without changes.
2 changes: 2 additions & 0 deletions torch_geometric/nn/aggr/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
PowerMeanAggregation,
)
from .lstm import LSTMAggregation
from .set2set import Set2Set

__all__ = classes = [
'Aggregation',
Expand All @@ -22,4 +23,5 @@
'SoftmaxAggregation',
'PowerMeanAggregation',
'LSTMAggregation',
'Set2Set',
]
62 changes: 47 additions & 15 deletions torch_geometric/nn/aggr/base.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
from abc import ABC, abstractmethod
from typing import Optional
from typing import Optional, Tuple

import torch
from torch import Tensor
from torch_scatter import scatter, segment_csr

from torch_geometric.utils import to_dense_batch


class Aggregation(torch.nn.Module, ABC):
r"""An abstract base class for implementing custom aggregations."""
requires_sorted_index = False

@abstractmethod
def forward(self, x: Tensor, index: Optional[Tensor] = None, *,
ptr: Optional[Tensor] = None, dim_size: Optional[int] = None,
Expand Down Expand Up @@ -39,16 +39,6 @@ def __call__(self, x: Tensor, index: Optional[Tensor] = None, *,
ptr: Optional[Tensor] = None, dim_size: Optional[int] = None,
dim: int = -2) -> Tensor:

if index is None and ptr is None:
raise ValueError(f"Expected that either 'index' or 'ptr' is "
f"passed to '{self.__class__.__name__}'")

if (self.requires_sorted_index and index is not None
and not torch.all(index[:-1] <= index[1:])):
raise ValueError(f"Can not perform aggregation inside "
f"'{self.__class__.__name__}' since the "
f"'index' tensor is not sorted")

if dim >= x.dim() or dim < -x.dim():
raise ValueError(f"Encountered invalid dimension '{dim}' of "
f"source tensor with {x.dim()} dimensions")
Expand All @@ -58,8 +48,43 @@ def __call__(self, x: Tensor, index: Optional[Tensor] = None, *,
raise ValueError(f"Encountered mismatch between 'dim_size' (got "
f"'{dim_size}') and 'ptr' (got '{ptr.size(0)}')")

if index is None and ptr is None:
rusty1s marked this conversation as resolved.
Show resolved Hide resolved
index = x.new_zeros(x.size(dim), dtype=torch.long)

if dim_size is None and ptr is not None:
dim_size = ptr.numel() - 1
elif dim_size is None:
dim_size = int(index.max()) + 1 if index.numel() > 0 else 0

rusty1s marked this conversation as resolved.
Show resolved Hide resolved
return super().__call__(x, index, ptr=ptr, dim_size=dim_size, dim=dim)

def __repr__(self) -> str:
return f'{self.__class__.__name__}()'

# Assertions ##############################################################

def assert_index_present(self, index: Optional[Tensor]):
if index is None:
raise NotImplementedError(f"'{self.__class__.__name__}' requires "
f"'index' to be specified")

def assert_sorted_index(self, index: Optional[Tensor]):
if index is not None and not torch.all(index[:-1] <= index[1:]):
raise ValueError(f"Can not perform aggregation inside "
f"'{self.__class__.__name__}' since the "
f"'index' tensor is not sorted")

def assert_two_dimensional_input(self, x: Tensor, dim: int):
if x.dim() != 2:
raise ValueError(f"'{self.__class__.__name__}' requires "
f"two-dimensional inputs (got '{x.dim()}')")

if dim not in [-2, 0]:
raise ValueError(f"'{self.__class__.__name__}' needs to perform "
f"aggregation in first dimension (got '{dim}')")

# Helper methods ##########################################################

def reduce(self, x: Tensor, index: Optional[Tensor] = None,
ptr: Optional[Tensor] = None, dim_size: Optional[int] = None,
dim: int = -2, reduce: str = 'add') -> Tensor:
Expand All @@ -71,8 +96,15 @@ def reduce(self, x: Tensor, index: Optional[Tensor] = None,
assert index is not None
return scatter(x, index, dim=dim, dim_size=dim_size, reduce=reduce)

def __repr__(self) -> str:
return f'{self.__class__.__name__}()'
def to_dense_batch(self, x: Tensor, index: Optional[Tensor] = None,
ptr: Optional[Tensor] = None,
dim_size: Optional[int] = None,
dim: int = -2) -> Tuple[Tensor, Tensor]:

self.assert_index_present(index) # TODO
rusty1s marked this conversation as resolved.
Show resolved Hide resolved
self.assert_sorted_index(index)
self.assert_two_dimensional_input(x, dim)
return to_dense_batch(x, index, batch_size=dim_size)


###############################################################################
Expand Down
19 changes: 2 additions & 17 deletions torch_geometric/nn/aggr/lstm.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from torch.nn import LSTM

from torch_geometric.nn.aggr import Aggregation
from torch_geometric.utils import to_dense_batch


class LSTMAggregation(Aggregation):
Expand All @@ -22,34 +21,20 @@ class LSTMAggregation(Aggregation):
out_channels (int): Size of each output sample.
**kwargs (optional): Additional arguments of :class:`torch.nn.LSTM`.
"""
requires_sorted_index = True

def __init__(self, in_channels: int, out_channels: int, **kwargs):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.lstm = LSTM(in_channels, out_channels, batch_first=True, **kwargs)
self.reset_parameters()

def reset_parameters(self):
self.lstm.reset_parameters()

def forward(self, x: Tensor, index: Optional[Tensor] = None, *,
ptr: Optional[Tensor] = None, dim_size: Optional[int] = None,
dim: int = -2) -> Tensor:

if index is None: # TODO
raise NotImplementedError(f"'{self.__class__.__name__}' with "
f"'ptr' not yet supported")

if x.dim() != 2:
raise ValueError(f"'{self.__class__.__name__}' requires "
f"two-dimensional inputs (got '{x.dim()}')")

if dim not in [-2, 0]:
raise ValueError(f"'{self.__class__.__name__}' needs to perform "
f"aggregation in first dimension (got '{dim}')")

x, _ = to_dense_batch(x, index, batch_size=dim_size)
x, _ = self.to_dense_batch(x, index, ptr, dim_size, dim)
return self.lstm(x)[0][:, -1]

def __repr__(self) -> str:
Expand Down
66 changes: 66 additions & 0 deletions torch_geometric/nn/aggr/set2set.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
from typing import Optional

import torch
from torch import Tensor

from torch_geometric.nn.aggr import Aggregation
from torch_geometric.utils import softmax


class Set2Set(Aggregation):
r"""The Set2Set aggregation operator based on iterative content-based
attention, as described in the `"Order Matters: Sequence to sequence for
Sets" <https://arxiv.org/abs/1511.06391>`_ paper

.. math::
\mathbf{q}_t &= \mathrm{LSTM}(\mathbf{q}^{*}_{t-1})

\alpha_{i,t} &= \mathrm{softmax}(\mathbf{x}_i \cdot \mathbf{q}_t)

\mathbf{r}_t &= \sum_{i=1}^N \alpha_{i,t} \mathbf{x}_i

\mathbf{q}^{*}_t &= \mathbf{q}_t \, \Vert \, \mathbf{r}_t,

where :math:`\mathbf{q}^{*}_T` defines the output of the layer with twice
the dimensionality as the input.

Args:
in_channels (int): Size of each input sample.
processing_steps (int): Number of iterations :math:`T`.
**kwargs (optional): Additional arguments of :class:`torch.nn.LSTM`.
"""
def __init__(self, in_channels: int, processing_steps: int, **kwargs):
super().__init__()
self.in_channels = in_channels
self.out_channels = 2 * in_channels
self.processing_steps = processing_steps
self.lstm = torch.nn.LSTM(self.out_channels, in_channels, **kwargs)
self.reset_parameters()

def reset_parameters(self):
self.lstm.reset_parameters()

def forward(self, x: Tensor, index: Optional[Tensor] = None, *,
ptr: Optional[Tensor] = None, dim_size: Optional[int] = None,
dim: int = -2) -> Tensor:

self.assert_index_present(index) # TODO
rusty1s marked this conversation as resolved.
Show resolved Hide resolved
self.assert_two_dimensional_input(x, dim)

h = (x.new_zeros((self.lstm.num_layers, dim_size, x.size(-1))),
x.new_zeros((self.lstm.num_layers, dim_size, x.size(-1))))
q_star = x.new_zeros(dim_size, self.out_channels)

for _ in range(self.processing_steps):
q, h = self.lstm(q_star.unsqueeze(0), h)
q = q.view(dim_size, self.in_channels)
e = (x * q[index]).sum(dim=-1, keepdim=True)
a = softmax(e, index, ptr, dim_size, dim)
r = self.reduce(a * x, index, ptr, dim_size, dim, reduce='add')
q_star = torch.cat([q, r], dim=-1)

return q_star

def __repr__(self) -> str:
return (f'{self.__class__.__name__}({self.in_channels}, '
f'{self.out_channels})')
10 changes: 8 additions & 2 deletions torch_geometric/nn/glob/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
from .glob import GlobalPooling
from .sort import global_sort_pool
from .attention import GlobalAttention
from .set2set import Set2Set
from .gmt import GraphMultisetTransformer

__all__ = [
Expand All @@ -12,8 +11,15 @@
'GlobalPooling',
'global_sort_pool',
'GlobalAttention',
'Set2Set',
'GraphMultisetTransformer',
]

classes = __all__

from torch_geometric.deprecation import deprecated # noqa
from torch_geometric.nn.aggr import Set2Set # noqa

Set2Set = deprecated(
details="use 'nn.aggr.Set2Set' instead",
func_name='nn.glob.Set2Set',
)(Set2Set)
90 changes: 0 additions & 90 deletions torch_geometric/nn/glob/set2set.py

This file was deleted.