From 6897dd7a73512ea2171490b27ca8ebe987da3888 Mon Sep 17 00:00:00 2001 From: rusty1s Date: Fri, 3 Jun 2022 14:59:44 +0200 Subject: [PATCH 1/9] update --- README.md | 2 +- test/nn/aggr/test_lstm.py | 2 - test/nn/{glob => aggr}/test_set2set.py | 0 torch_geometric/nn/aggr/__init__.py | 2 + torch_geometric/nn/aggr/base.py | 50 ++++++++++---- torch_geometric/nn/aggr/lstm.py | 19 +----- torch_geometric/nn/aggr/set2set.py | 69 ++++++++++++++++++++ torch_geometric/nn/glob/__init__.py | 10 ++- torch_geometric/nn/glob/set2set.py | 90 -------------------------- 9 files changed, 119 insertions(+), 125 deletions(-) rename test/nn/{glob => aggr}/test_set2set.py (100%) create mode 100644 torch_geometric/nn/aggr/set2set.py delete mode 100644 torch_geometric/nn/glob/set2set.py diff --git a/README.md b/README.md index 55b2ca9c896f..ec81e8886913 100644 --- a/README.md +++ b/README.md @@ -254,7 +254,7 @@ It is commonly applied to graph-level tasks, which require combining node featur Expand to see all implemented pooling layers... 
* **[GlobalAttention](https://pytorch-geometric.readthedocs.io/en/latest/modules/nn.html#torch_geometric.nn.glob.GlobalAttention)** from Li *et al.*: [Gated Graph Sequence Neural Networks](https://arxiv.org/abs/1511.05493) (ICLR 2016) [[**Example**](https://github.com/pyg-team/pytorch_geometric/blob/master/benchmark/kernel/global_attention.py)] -* **[Set2Set](https://pytorch-geometric.readthedocs.io/en/latest/modules/nn.html#torch_geometric.nn.glob.Set2Set)** from Vinyals *et al.*: [Order Matters: Sequence to Sequence for Sets](https://arxiv.org/abs/1511.06391) (ICLR 2016) [[**Example**](https://github.com/pyg-team/pytorch_geometric/blob/master/benchmark/kernel/set2set.py)] +* **[Set2Set](https://pytorch-geometric.readthedocs.io/en/latest/modules/nn.html#torch_geometric.nn.aggr.Set2Set)** from Vinyals *et al.*: [Order Matters: Sequence to Sequence for Sets](https://arxiv.org/abs/1511.06391) (ICLR 2016) [[**Example**](https://github.com/pyg-team/pytorch_geometric/blob/master/benchmark/kernel/set2set.py)] * **[Sort Pool](https://pytorch-geometric.readthedocs.io/en/latest/modules/nn.html#torch_geometric.nn.glob.global_sort_pool)** from Zhang *et al.*: [An End-to-End Deep Learning Architecture for Graph Classification](https://www.cse.wustl.edu/~muhan/papers/AAAI_2018_DGCNN.pdf) (AAAI 2018) [[**Example**](https://github.com/pyg-team/pytorch_geometric/blob/master/benchmark/kernel/sort_pool.py)] * **[MinCUT Pooling](https://pytorch-geometric.readthedocs.io/en/latest/modules/nn.html#torch_geometric.nn.dense.mincut_pool.dense_mincut_pool)** from Bianchi *et al.*: [MinCUT Pooling in Graph Neural Networks](https://arxiv.org/abs/1907.00481) (CoRR 2019) [[**Example**](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/proteins_mincut_pool.py)] * **[DMoN Pooling](https://pytorch-geometric.readthedocs.io/en/latest/modules/nn.html#torch_geometric.nn.dense.dmon_pool.DMoNPooling)** from Tsitsulin *et al.*: [Graph Clustering with Graph Neural 
Networks](https://arxiv.org/abs/2006.16904) (CoRR 2020) [[**Example**](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/proteins_dmon_pool.py)] diff --git a/test/nn/aggr/test_lstm.py b/test/nn/aggr/test_lstm.py index 216bc8c377e4..0ec27bf7e05f 100644 --- a/test/nn/aggr/test_lstm.py +++ b/test/nn/aggr/test_lstm.py @@ -11,8 +11,6 @@ def test_lstm_aggregation(): aggr = LSTMAggregation(16, 32) assert str(aggr) == 'LSTMAggregation(16, 32)' - aggr.reset_parameters() - with pytest.raises(ValueError, match="is not sorted"): aggr(x, torch.tensor([0, 1, 0, 1, 2, 1])) diff --git a/test/nn/glob/test_set2set.py b/test/nn/aggr/test_set2set.py similarity index 100% rename from test/nn/glob/test_set2set.py rename to test/nn/aggr/test_set2set.py diff --git a/torch_geometric/nn/aggr/__init__.py b/torch_geometric/nn/aggr/__init__.py index dc43e7bbda21..adc1e7b67e85 100644 --- a/torch_geometric/nn/aggr/__init__.py +++ b/torch_geometric/nn/aggr/__init__.py @@ -10,6 +10,7 @@ PowerMeanAggregation, ) from .lstm import LSTMAggregation +from .set2set import Set2Set __all__ = classes = [ 'Aggregation', @@ -22,4 +23,5 @@ 'SoftmaxAggregation', 'PowerMeanAggregation', 'LSTMAggregation', + 'Set2Set', ] diff --git a/torch_geometric/nn/aggr/base.py b/torch_geometric/nn/aggr/base.py index 2c721fcda2a6..8c085cbd0252 100644 --- a/torch_geometric/nn/aggr/base.py +++ b/torch_geometric/nn/aggr/base.py @@ -1,15 +1,15 @@ from abc import ABC, abstractmethod -from typing import Optional +from typing import Optional, Tuple import torch from torch import Tensor from torch_scatter import scatter, segment_csr +from torch_geometric.utils import to_dense_batch + class Aggregation(torch.nn.Module, ABC): r"""An abstract base class for implementing custom aggregations.""" - requires_sorted_index = False - @abstractmethod def forward(self, x: Tensor, index: Optional[Tensor] = None, *, ptr: Optional[Tensor] = None, dim_size: Optional[int] = None, @@ -39,16 +39,6 @@ def __call__(self, x: Tensor, 
index: Optional[Tensor] = None, *, ptr: Optional[Tensor] = None, dim_size: Optional[int] = None, dim: int = -2) -> Tensor: - if index is None and ptr is None: - raise ValueError(f"Expected that either 'index' or 'ptr' is " - f"passed to '{self.__class__.__name__}'") - - if (self.requires_sorted_index and index is not None - and not torch.all(index[:-1] <= index[1:])): - raise ValueError(f"Can not perform aggregation inside " - f"'{self.__class__.__name__}' since the " - f"'index' tensor is not sorted") - if dim >= x.dim() or dim < -x.dim(): raise ValueError(f"Encountered invalid dimension '{dim}' of " f"source tensor with {x.dim()} dimensions") @@ -58,8 +48,25 @@ def __call__(self, x: Tensor, index: Optional[Tensor] = None, *, raise ValueError(f"Encountered mismatch between 'dim_size' (got " f"'{dim_size}') and 'ptr' (got '{ptr.size(0)}')") + if index is None and ptr is None: + index = x.new_zeros(x.size(dim), dtype=torch.long) + + if dim_size is None and ptr is not None: + dim_size = ptr.numel() - 1 + elif dim_size is None: + dim_size = int(index.max()) + 1 if index.numel() > 0 else 0 + return super().__call__(x, index, ptr=ptr, dim_size=dim_size, dim=dim) + def assert_two_dimensional_input(self, x: Tensor, dim: int = -2): + if x.dim() != 2: + raise ValueError(f"'{self.__class__.__name__}' requires " + f"two-dimensional inputs (got '{x.dim()}')") + + if dim not in [-2, 0]: + raise ValueError(f"'{self.__class__.__name__}' needs to perform " + f"aggregation in first dimension (got '{dim}')") + def reduce(self, x: Tensor, index: Optional[Tensor] = None, ptr: Optional[Tensor] = None, dim_size: Optional[int] = None, dim: int = -2, reduce: str = 'add') -> Tensor: @@ -71,6 +78,23 @@ def reduce(self, x: Tensor, index: Optional[Tensor] = None, assert index is not None return scatter(x, index, dim=dim, dim_size=dim_size, reduce=reduce) + def to_dense_batch(self, x: Tensor, index: Optional[Tensor] = None, + ptr: Optional[Tensor] = None, + dim_size: Optional[int] = None, + 
dim: int = -2) -> Tuple[Tensor, Tensor]: + + if index is None: # TODO + raise NotImplementedError(f"'{self.__class__.__name__}' with " + f"'ptr' not yet supported") + + if torch.all(index[:-1] <= index[1:]): + raise ValueError(f"Can not perform aggregation inside " + f"'{self.__class__.__name__}' since the " + f"'index' tensor is not sorted") + + self.assert_two_dimensional_input(x, dim) + return to_dense_batch(x, index, batch_size=dim_size) + def __repr__(self) -> str: return f'{self.__class__.__name__}()' diff --git a/torch_geometric/nn/aggr/lstm.py b/torch_geometric/nn/aggr/lstm.py index a617d2c49daa..966e4cda1cba 100644 --- a/torch_geometric/nn/aggr/lstm.py +++ b/torch_geometric/nn/aggr/lstm.py @@ -4,7 +4,6 @@ from torch.nn import LSTM from torch_geometric.nn.aggr import Aggregation -from torch_geometric.utils import to_dense_batch class LSTMAggregation(Aggregation): @@ -22,13 +21,12 @@ class LSTMAggregation(Aggregation): out_channels (int): Size of each output sample. **kwargs (optional): Additional arguments of :class:`torch.nn.LSTM`. 
""" - requires_sorted_index = True - def __init__(self, in_channels: int, out_channels: int, **kwargs): super().__init__() self.in_channels = in_channels self.out_channels = out_channels self.lstm = LSTM(in_channels, out_channels, batch_first=True, **kwargs) + self.reset_parameters() def reset_parameters(self): self.lstm.reset_parameters() @@ -36,20 +34,7 @@ def reset_parameters(self): def forward(self, x: Tensor, index: Optional[Tensor] = None, *, ptr: Optional[Tensor] = None, dim_size: Optional[int] = None, dim: int = -2) -> Tensor: - - if index is None: # TODO - raise NotImplementedError(f"'{self.__class__.__name__}' with " - f"'ptr' not yet supported") - - if x.dim() != 2: - raise ValueError(f"'{self.__class__.__name__}' requires " - f"two-dimensional inputs (got '{x.dim()}')") - - if dim not in [-2, 0]: - raise ValueError(f"'{self.__class__.__name__}' needs to perform " - f"aggregation in first dimension (got '{dim}')") - - x, _ = to_dense_batch(x, index, batch_size=dim_size) + x, _ = self.to_dense_batch(x, index, ptr, dim_size, dim) return self.lstm(x)[0][:, -1] def __repr__(self) -> str: diff --git a/torch_geometric/nn/aggr/set2set.py b/torch_geometric/nn/aggr/set2set.py new file mode 100644 index 000000000000..69be1aa6d893 --- /dev/null +++ b/torch_geometric/nn/aggr/set2set.py @@ -0,0 +1,69 @@ +from typing import Optional + +import torch +from torch import Tensor + +from torch_geometric.nn.aggr import Aggregation +from torch_geometric.utils import softmax + + +class Set2Set(Aggregation): + r"""The Set2Set aggregation operator based on iterative content-based + attention, as described in the the `"Order Matters: Sequence to sequence + for Sets" <https://arxiv.org/abs/1511.06391>`_ paper + + .. 
math:: + \mathbf{q}_t &= \mathrm{LSTM}(\mathbf{q}^{*}_{t-1}) + + \alpha_{i,t} &= \mathrm{softmax}(\mathbf{x}_i \cdot \mathbf{q}_t) + + \mathbf{r}_t &= \sum_{i=1}^N \alpha_{i,t} \mathbf{x}_i + + \mathbf{q}^{*}_t &= \mathbf{q}_t \, \Vert \, \mathbf{r}_t, + + where :math:`\mathbf{q}^{*}_T` defines the output of the layer with twice + the dimensionality as the input. + + Args: + in_channels (int): Size of each input sample. + processing_steps (int): Number of iterations :math:`T`. + **kwargs (optional): Additional arguments of :class:`torch.nn.LSTM`. + """ + def __init__(self, in_channels: int, processing_steps: int, **kwargs): + super().__init__() + self.in_channels = in_channels + self.out_channels = 2 * in_channels + self.processing_steps = processing_steps + self.lstm = torch.nn.LSTM(self.out_channels, in_channels, **kwargs) + self.reset_parameters() + + def reset_parameters(self): + self.lstm.reset_parameters() + + def forward(self, x: Tensor, index: Optional[Tensor] = None, *, + ptr: Optional[Tensor] = None, dim_size: Optional[int] = None, + dim: int = -2) -> Tensor: + + self.assert_two_dimensional_input(x, dim) + + if index is None: # TODO + raise NotImplementedError(f"'{self.__class__.__name__}' with " + f"'ptr' not yet supported") + + h = (x.new_zeros((self.lstm.num_layers, dim_size, x.size(-1))), + x.new_zeros((self.lstm.num_layers, dim_size, x.size(-1)))) + q_star = x.new_zeros(dim_size, self.out_channels) + + for _ in range(self.processing_steps): + q, h = self.lstm(q_star.unsqueeze(0), h) + q = q.view(dim_size, self.in_channels) + e = (x * q[index]).sum(dim=-1, keepdim=True) + a = softmax(e, index, ptr, dim_size, dim) + r = self.reduce(a * x, index, ptr, dim_size, dim, reduce='add') + q_star = torch.cat([q, r], dim=-1) + + return q_star + + def __repr__(self) -> str: + return (f'{self.__class__.__name__}({self.in_channels}, ' + f'{self.out_channels})') diff --git a/torch_geometric/nn/glob/__init__.py b/torch_geometric/nn/glob/__init__.py index 
8b911ccf859e..be0b138ffcd2 100644 --- a/torch_geometric/nn/glob/__init__.py +++ b/torch_geometric/nn/glob/__init__.py @@ -2,7 +2,6 @@ from .glob import GlobalPooling from .sort import global_sort_pool from .attention import GlobalAttention -from .set2set import Set2Set from .gmt import GraphMultisetTransformer __all__ = [ @@ -12,8 +11,15 @@ 'GlobalPooling', 'global_sort_pool', 'GlobalAttention', - 'Set2Set', 'GraphMultisetTransformer', ] classes = __all__ + +from torch_geometric.deprecation import deprecated # noqa +from torch_geometric.nn.aggr import Set2Set # noqa + +Set2Set = deprecated( + details="use 'nn.aggr.Set2Set' instead", + func_name='nn.glob.Set2Set', +)(Set2Set) diff --git a/torch_geometric/nn/glob/set2set.py b/torch_geometric/nn/glob/set2set.py deleted file mode 100644 index 383548da7138..000000000000 --- a/torch_geometric/nn/glob/set2set.py +++ /dev/null @@ -1,90 +0,0 @@ -from typing import Optional - -import torch -from torch import Tensor -from torch_scatter import scatter_add - -from torch_geometric.utils import softmax - - -class Set2Set(torch.nn.Module): - r"""The global pooling operator based on iterative content-based attention - from the `"Order Matters: Sequence to sequence for sets" - <https://arxiv.org/abs/1511.06391>`_ paper - - .. math:: - \mathbf{q}_t &= \mathrm{LSTM}(\mathbf{q}^{*}_{t-1}) - - \alpha_{i,t} &= \mathrm{softmax}(\mathbf{x}_i \cdot \mathbf{q}_t) - - \mathbf{r}_t &= \sum_{i=1}^N \alpha_{i,t} \mathbf{x}_i - - \mathbf{q}^{*}_t &= \mathbf{q}_t \, \Vert \, \mathbf{r}_t, - - where :math:`\mathbf{q}^{*}_T` defines the output of the layer with twice - the dimensionality as the input. - - Args: - in_channels (int): Size of each input sample. - processing_steps (int): Number of iterations :math:`T`. - num_layers (int, optional): Number of recurrent layers, *.e.g*, setting - :obj:`num_layers=2` would mean stacking two LSTMs together to form - a stacked LSTM, with the second LSTM taking in outputs of the first - LSTM and computing the final results. 
(default: :obj:`1`) - - Shapes: - - **input:** - node features :math:`(|\mathcal{V}|, F)`, - batch vector :math:`(|\mathcal{V}|)` *(optional)* - - **output:** graph features :math:`(|\mathcal{G}|, 2 * F)` where - :math:`|\mathcal{G}|` denotes the number of graphs in the batch - """ - def __init__(self, in_channels: int, processing_steps: int, - num_layers: int = 1): - super().__init__() - - self.in_channels = in_channels - self.out_channels = 2 * in_channels - self.processing_steps = processing_steps - self.num_layers = num_layers - - self.lstm = torch.nn.LSTM(self.out_channels, self.in_channels, - num_layers) - - self.reset_parameters() - - def reset_parameters(self): - self.lstm.reset_parameters() - - def forward(self, x: Tensor, batch: Optional[Tensor] = None, - size: Optional[int] = None) -> Tensor: - r""" - Args: - x (Tensor): The input node features. - batch (LongTensor, optional): A vector that maps each node to its - respective graph identifier. (default: :obj:`None`) - size (int, optional): The number of graphs in the batch. 
- (default: :obj:`None`) - """ - if batch is None: - batch = x.new_zeros(x.size(0), dtype=torch.int64) - - size = int(batch.max()) + 1 if size is None else size - - h = (x.new_zeros((self.num_layers, size, self.in_channels)), - x.new_zeros((self.num_layers, size, self.in_channels))) - q_star = x.new_zeros(size, self.out_channels) - - for _ in range(self.processing_steps): - q, h = self.lstm(q_star.unsqueeze(0), h) - q = q.view(size, self.in_channels) - e = (x * q.index_select(0, batch)).sum(dim=-1, keepdim=True) - a = softmax(e, batch, num_nodes=size) - r = scatter_add(a * x, batch, dim=0, dim_size=size) - q_star = torch.cat([q, r], dim=-1) - - return q_star - - def __repr__(self) -> str: - return (f'{self.__class__.__name__}({self.in_channels}, ' - f'{self.out_channels})') From 65f6d1d3fe1aee7908c1c7f2dd62b81a39521ee3 Mon Sep 17 00:00:00 2001 From: rusty1s Date: Fri, 3 Jun 2022 15:02:34 +0200 Subject: [PATCH 2/9] update --- CHANGELOG.md | 2 +- torch_geometric/nn/aggr/set2set.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index ddd3bcbeb156..26238f529e92 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,7 +10,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
- Added the `bias` vector to the `GCN` model definition in the "Create Message Passing Networks" tutorial ([#4755](https://github.com/pyg-team/pytorch_geometric/pull/4755)) - Added `transforms.RootedSubgraph` interface with two implementations: `RootedEgoNets` and `RootedRWSubgraph` ([#3926](https://github.com/pyg-team/pytorch_geometric/pull/3926)) - Added `ptr` vectors for `follow_batch` attributes within `Batch.from_data_list` ([#4723](https://github.com/pyg-team/pytorch_geometric/pull/4723)) -- Added `torch_geometric.nn.aggr` package ([#4687](https://github.com/pyg-team/pytorch_geometric/pull/4687), [#4721](https://github.com/pyg-team/pytorch_geometric/pull/4721), [#4731](https://github.com/pyg-team/pytorch_geometric/pull/4731)) +- Added `torch_geometric.nn.aggr` package ([#4687](https://github.com/pyg-team/pytorch_geometric/pull/4687), [#4721](https://github.com/pyg-team/pytorch_geometric/pull/4721), [#4731](https://github.com/pyg-team/pytorch_geometric/pull/4731), [#4762](https://github.com/pyg-team/pytorch_geometric/pull/4762)) - Added the `DimeNet++` model ([#4432](https://github.com/pyg-team/pytorch_geometric/pull/4432), [#4699](https://github.com/pyg-team/pytorch_geometric/pull/4699), [#4700](https://github.com/pyg-team/pytorch_geometric/pull/4700)) - Added an example of using PyG with PyTorch Ignite ([#4487](https://github.com/pyg-team/pytorch_geometric/pull/4487)) - Added `GroupAddRev` module with support for reducing training GPU memory ([#4671](https://github.com/pyg-team/pytorch_geometric/pull/4671), [#4701](https://github.com/pyg-team/pytorch_geometric/pull/4701), [#4715](https://github.com/pyg-team/pytorch_geometric/pull/4715), [#4730](https://github.com/pyg-team/pytorch_geometric/pull/4730)) diff --git a/torch_geometric/nn/aggr/set2set.py b/torch_geometric/nn/aggr/set2set.py index 69be1aa6d893..2286adda4af4 100644 --- a/torch_geometric/nn/aggr/set2set.py +++ b/torch_geometric/nn/aggr/set2set.py @@ -9,8 +9,8 @@ class Set2Set(Aggregation): r"""The 
Set2Set aggregation operator based on iterative content-based - attention, as described in the the `"Order Matters: Sequence to sequence - for Sets" <https://arxiv.org/abs/1511.06391>`_ paper + attention, as described in the `"Order Matters: Sequence to sequence for + Sets" <https://arxiv.org/abs/1511.06391>`_ paper .. math:: \mathbf{q}_t &= \mathrm{LSTM}(\mathbf{q}^{*}_{t-1}) From cc3066aab7cbe6e37e51633dafbb7f0ab9576784 Mon Sep 17 00:00:00 2001 From: rusty1s Date: Fri, 3 Jun 2022 15:05:17 +0200 Subject: [PATCH 3/9] updatE --- torch_geometric/nn/aggr/base.py | 21 ++++++++++++++------- torch_geometric/nn/aggr/set2set.py | 5 +---- 2 files changed, 15 insertions(+), 11 deletions(-) diff --git a/torch_geometric/nn/aggr/base.py b/torch_geometric/nn/aggr/base.py index 8c085cbd0252..1a0fc3fe7c52 100644 --- a/torch_geometric/nn/aggr/base.py +++ b/torch_geometric/nn/aggr/base.py @@ -58,6 +58,16 @@ def __call__(self, x: Tensor, index: Optional[Tensor] = None, *, return super().__call__(x, index, ptr=ptr, dim_size=dim_size, dim=dim) + def __repr__(self) -> str: + return f'{self.__class__.__name__}()' + + # Assertions ############################################################## + + def assert_index_present(self, index: Optional[Tensor]): + if index is None: + raise NotImplementedError(f"'{self.__class__.__name__}' requires " + f"'index' to be specified") + def assert_two_dimensional_input(self, x: Tensor, dim: int = -2): if x.dim() != 2: raise ValueError(f"'{self.__class__.__name__}' requires " @@ -67,6 +77,8 @@ def assert_two_dimensional_input(self, x: Tensor, dim: int = -2): raise ValueError(f"'{self.__class__.__name__}' needs to perform " f"aggregation in first dimension (got '{dim}')") + # Helper methods ########################################################## + def reduce(self, x: Tensor, index: Optional[Tensor] = None, ptr: Optional[Tensor] = None, dim_size: Optional[int] = None, dim: int = -2, reduce: str = 'add') -> Tensor: @@ -83,21 +95,16 @@ def to_dense_batch(self, x: Tensor, index: Optional[Tensor] = None, dim_size: 
Optional[int] = None, dim: int = -2) -> Tuple[Tensor, Tensor]: - if index is None: # TODO - raise NotImplementedError(f"'{self.__class__.__name__}' with " - f"'ptr' not yet supported") + self.assert_index_present(index) # TODO + self.assert_two_dimensional_input(x, dim) if torch.all(index[:-1] <= index[1:]): raise ValueError(f"Can not perform aggregation inside " f"'{self.__class__.__name__}' since the " f"'index' tensor is not sorted") - self.assert_two_dimensional_input(x, dim) return to_dense_batch(x, index, batch_size=dim_size) - def __repr__(self) -> str: - return f'{self.__class__.__name__}()' - ############################################################################### diff --git a/torch_geometric/nn/aggr/set2set.py b/torch_geometric/nn/aggr/set2set.py index 2286adda4af4..263718cbb07c 100644 --- a/torch_geometric/nn/aggr/set2set.py +++ b/torch_geometric/nn/aggr/set2set.py @@ -44,12 +44,9 @@ def forward(self, x: Tensor, index: Optional[Tensor] = None, *, ptr: Optional[Tensor] = None, dim_size: Optional[int] = None, dim: int = -2) -> Tensor: + self.assert_index_present(index) # TODO self.assert_two_dimensional_input(x, dim) - if index is None: # TODO - raise NotImplementedError(f"'{self.__class__.__name__}' with " - f"'ptr' not yet supported") - h = (x.new_zeros((self.lstm.num_layers, dim_size, x.size(-1))), x.new_zeros((self.lstm.num_layers, dim_size, x.size(-1)))) q_star = x.new_zeros(dim_size, self.out_channels) From 6b8344ab7b0b150d0adc493d45cfd01080702ba7 Mon Sep 17 00:00:00 2001 From: rusty1s Date: Fri, 3 Jun 2022 15:10:45 +0200 Subject: [PATCH 4/9] update --- torch_geometric/nn/aggr/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_geometric/nn/aggr/base.py b/torch_geometric/nn/aggr/base.py index 1a0fc3fe7c52..ea1fd4bcec12 100644 --- a/torch_geometric/nn/aggr/base.py +++ b/torch_geometric/nn/aggr/base.py @@ -98,7 +98,7 @@ def to_dense_batch(self, x: Tensor, index: Optional[Tensor] = None, 
self.assert_index_present(index) # TODO self.assert_two_dimensional_input(x, dim) - if torch.all(index[:-1] <= index[1:]): + if not torch.all(index[:-1] <= index[1:]): raise ValueError(f"Can not perform aggregation inside " f"'{self.__class__.__name__}' since the " f"'index' tensor is not sorted") From 65087655788b3ffbeffcacc4acffba80fc67b58b Mon Sep 17 00:00:00 2001 From: rusty1s Date: Fri, 3 Jun 2022 15:11:14 +0200 Subject: [PATCH 5/9] fix test --- test/nn/aggr/test_basic.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/test/nn/aggr/test_basic.py b/test/nn/aggr/test_basic.py index 7214742eb56d..a73d53e18d3e 100644 --- a/test/nn/aggr/test_basic.py +++ b/test/nn/aggr/test_basic.py @@ -20,9 +20,6 @@ def test_validate(): aggr = MeanAggregation() - with pytest.raises(ValueError, match="either 'index' or 'ptr'"): - aggr(x) - with pytest.raises(ValueError, match="invalid dimension"): aggr(x, index, dim=-3) From 6327c22510e3ec9ab9c84d749f0e7d7eff0d59c9 Mon Sep 17 00:00:00 2001 From: rusty1s Date: Fri, 3 Jun 2022 15:16:41 +0200 Subject: [PATCH 6/9] update --- torch_geometric/nn/aggr/base.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/torch_geometric/nn/aggr/base.py b/torch_geometric/nn/aggr/base.py index ea1fd4bcec12..e8c3a2c6a14c 100644 --- a/torch_geometric/nn/aggr/base.py +++ b/torch_geometric/nn/aggr/base.py @@ -68,7 +68,13 @@ def assert_index_present(self, index: Optional[Tensor]): raise NotImplementedError(f"'{self.__class__.__name__}' requires " f"'index' to be specified") - def assert_two_dimensional_input(self, x: Tensor, dim: int = -2): + def assert_sorted_index(self, index: Optional[Tensor]): + if index is not None and not torch.all(index[:-1] <= index[1:]): + raise ValueError(f"Can not perform aggregation inside " + f"'{self.__class__.__name__}' since the " + f"'index' tensor is not sorted") + + def assert_two_dimensional_input(self, x: Tensor, dim: int): if x.dim() != 2: raise 
ValueError(f"'{self.__class__.__name__}' requires " f"two-dimensional inputs (got '{x.dim()}')") @@ -96,13 +102,8 @@ def to_dense_batch(self, x: Tensor, index: Optional[Tensor] = None, dim: int = -2) -> Tuple[Tensor, Tensor]: self.assert_index_present(index) # TODO + self.assert_sorted_index(index) self.assert_two_dimensional_input(x, dim) - - if not torch.all(index[:-1] <= index[1:]): - raise ValueError(f"Can not perform aggregation inside " - f"'{self.__class__.__name__}' since the " - f"'index' tensor is not sorted") - return to_dense_batch(x, index, batch_size=dim_size) From 6fd8f43f85249a8872f797ae9ad7a6703118d3cd Mon Sep 17 00:00:00 2001 From: rusty1s Date: Sun, 5 Jun 2022 08:47:59 +0200 Subject: [PATCH 7/9] update --- torch_geometric/nn/aggr/base.py | 24 +++++++++++++++--------- 1 file changed, 15 insertions(+), 9 deletions(-) diff --git a/torch_geometric/nn/aggr/base.py b/torch_geometric/nn/aggr/base.py index e8c3a2c6a14c..ae24b3d9d11b 100644 --- a/torch_geometric/nn/aggr/base.py +++ b/torch_geometric/nn/aggr/base.py @@ -43,18 +43,24 @@ def __call__(self, x: Tensor, index: Optional[Tensor] = None, *, raise ValueError(f"Encountered invalid dimension '{dim}' of " f"source tensor with {x.dim()} dimensions") - if (ptr is not None and dim_size is not None - and dim_size != ptr.numel() - 1): - raise ValueError(f"Encountered mismatch between 'dim_size' (got " - f"'{dim_size}') and 'ptr' (got '{ptr.size(0)}')") - if index is None and ptr is None: index = x.new_zeros(x.size(dim), dtype=torch.long) - if dim_size is None and ptr is not None: - dim_size = ptr.numel() - 1 - elif dim_size is None: - dim_size = int(index.max()) + 1 if index.numel() > 0 else 0 + if ptr is not None: + if dim_size is None: + dim_size = ptr.numel() - 1 + elif dim_size != ptr.numel() - 1: + raise ValueError(f"Encountered invalid 'dim_size' (got " + f"'{dim_size}' but expected " + f"'{ptr.numel() - 1}')") + + if index is not None: + if dim_size is None: + dim_size = int(index.max()) + 1 if 
index.numel() > 0 else 0 + elif index.numel() > 0 and dim_size <= int(index.max()): + raise ValueError(f"Encountered invalid 'dim_size' (got " + f"'{dim_size}' but expected " + f">= '{int(index.max()) + 1}')") return super().__call__(x, index, ptr=ptr, dim_size=dim_size, dim=dim) From d2b4b8c47531dfa2e3a0cc9aea85a118a77132be Mon Sep 17 00:00:00 2001 From: rusty1s Date: Sun, 5 Jun 2022 08:50:10 +0200 Subject: [PATCH 8/9] add todo --- torch_geometric/nn/aggr/base.py | 6 +++++- torch_geometric/nn/aggr/set2set.py | 3 ++- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/torch_geometric/nn/aggr/base.py b/torch_geometric/nn/aggr/base.py index ae24b3d9d11b..5857fba75d7f 100644 --- a/torch_geometric/nn/aggr/base.py +++ b/torch_geometric/nn/aggr/base.py @@ -70,6 +70,8 @@ def __repr__(self) -> str: # Assertions ############################################################## def assert_index_present(self, index: Optional[Tensor]): + # TODO Currently, not all aggregators support `ptr`. This assert helps + # to ensure that we require `index` to be passed to the computation: if index is None: raise NotImplementedError(f"'{self.__class__.__name__}' requires " f"'index' to be specified") @@ -107,9 +109,11 @@ def to_dense_batch(self, x: Tensor, index: Optional[Tensor] = None, dim_size: Optional[int] = None, dim: int = -2) -> Tuple[Tensor, Tensor]: - self.assert_index_present(index) # TODO + # TODO Currently, `to_dense_batch` can only operate on `index`: + self.assert_index_present(index) self.assert_sorted_index(index) self.assert_two_dimensional_input(x, dim) + return to_dense_batch(x, index, batch_size=dim_size) diff --git a/torch_geometric/nn/aggr/set2set.py b/torch_geometric/nn/aggr/set2set.py index 263718cbb07c..3c9cd00974af 100644 --- a/torch_geometric/nn/aggr/set2set.py +++ b/torch_geometric/nn/aggr/set2set.py @@ -44,7 +44,8 @@ def forward(self, x: Tensor, index: Optional[Tensor] = None, *, ptr: Optional[Tensor] = None, dim_size: Optional[int] = None, dim: int = 
-2) -> Tensor: - self.assert_index_present(index) # TODO + # TODO Currently, `to_dense_batch` can only operate on `index`: + self.assert_index_present(index) self.assert_two_dimensional_input(x, dim) h = (x.new_zeros((self.lstm.num_layers, dim_size, x.size(-1))), From bff0d79d0b255ef91eb57a4ba4ff3770d0b053df Mon Sep 17 00:00:00 2001 From: rusty1s Date: Sun, 5 Jun 2022 08:52:27 +0200 Subject: [PATCH 9/9] fix test --- test/nn/aggr/test_basic.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/nn/aggr/test_basic.py b/test/nn/aggr/test_basic.py index a73d53e18d3e..24227ed25488 100644 --- a/test/nn/aggr/test_basic.py +++ b/test/nn/aggr/test_basic.py @@ -23,7 +23,7 @@ def test_validate(): with pytest.raises(ValueError, match="invalid dimension"): aggr(x, index, dim=-3) - with pytest.raises(ValueError, match="mismatch between"): + with pytest.raises(ValueError, match="invalid 'dim_size'"): aggr(x, ptr=ptr, dim_size=2)