From 97818d6ea9f59cb9a44e4e47214ff0c6af078521 Mon Sep 17 00:00:00 2001 From: wsad1 Date: Sat, 6 May 2023 04:31:29 +0000 Subject: [PATCH 1/5] removed base pooling class --- torch_geometric/nn/pool/base.py | 94 --------------------------------- 1 file changed, 94 deletions(-) delete mode 100644 torch_geometric/nn/pool/base.py diff --git a/torch_geometric/nn/pool/base.py b/torch_geometric/nn/pool/base.py deleted file mode 100644 index 103728056481..000000000000 --- a/torch_geometric/nn/pool/base.py +++ /dev/null @@ -1,94 +0,0 @@ -from dataclasses import dataclass -from typing import Optional - -import torch -from torch import Tensor - -from torch_geometric.nn.aggr import Aggregation -from torch_geometric.nn.pool.connect import Connect -from torch_geometric.nn.pool.select import Select -from torch_geometric.utils.mixin import CastMixin - - -@dataclass -class PoolingOutput(CastMixin): - r"""The pooling output of a :class:`torch_geometric.nn.pool.Pooling` - module. - - Args: - x (torch.Tensor): The pooled node features. - edge_index (torch.Tensor): The coarsened edge indices. - edge_attr (torch.Tensor, optional): The edge features of the coarsened - graph. (default: :obj:`None`) - batch (torch.Tensor, optional): The batch vector of the pooled nodes. - """ - x: Tensor - edge_index: Tensor - edge_attr: Optional[Tensor] = None - batch: Optional[Tensor] = None - - -class Pooling(torch.nn.Module): - r"""A base class for pooling layers based on the - `"Understanding Pooling in Graph Neural Networks" - `_ paper. - - :class:`Pooling` decomposes a pooling layer into three components: - - #. :class:`Select` defines how input nodes map to supernodes. - - #. :class:`Reduce` defines how input node features are aggregated. - - #. :class:`Connect` decides how the supernodes are connected to each other. - - Args: - select (Select): The node selection operator. - reduce (Select): The node feature aggregation operator. - connect (Connect): The edge connection operator. - """ - def __init__(self, select: Select, reduce: Aggregation, connect: Connect): - super().__init__() - self.select = select - self.reduce = reduce - self.connect = connect - - def reset_parameters(self): - r"""Resets all learnable parameters of the module.""" - self.select.reset_parameters() - self.reduce.reset_parameters() - self.connect.reset_parameters() - - def forward( - self, - x: torch.Tensor, - edge_index: torch.Tensor, - edge_attr: Optional[torch.Tensor] = None, - batch: Optional[torch.Tensor] = None, - ) -> PoolingOutput: - r""" - Args: - x (torch.Tensor): The input node features. - edge_index (torch.Tensor): The edge indices. - edge_attr (torch.Tensor, optional): The edge features. - (default: :obj:`None`) - batch (torch.Tensor, optional): The batch vector - :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns - each node to a specific graph. (default: :obj:`None`) - """ - cluster, num_clusters = self.select(x, edge_index, edge_attr, batch) - x = self.reduce(x, cluster, dim_size=num_clusters) - edge_index, edge_attr = self.connect(cluster, edge_index, edge_attr, - batch) - - if batch is not None: - batch = (torch.arange(num_clusters, device=x.device)).scatter_( - 0, cluster, batch) - - return PoolingOutput(x, edge_index, edge_attr, batch) - - def __repr__(self) -> str: - return (f'{self.__class__.__name__}(\n' - f' select={self.select},\n' - f' reduce={self.reduce},\n' - f' connect={self.connect},\n' - f')') From c617ec33f380d79ed46ae193df7ed32517cdf066 Mon Sep 17 00:00:00 2001 From: wsad1 Date: Sat, 6 May 2023 05:07:16 +0000 Subject: [PATCH 2/5] updated select base --- test/nn/pool/test_pooling_base.py | 52 -------------------------- torch_geometric/nn/pool/select/base.py | 24 +++++++++++- 2 files changed, 22 insertions(+), 54 deletions(-) delete mode 100644 test/nn/pool/test_pooling_base.py diff --git a/test/nn/pool/test_pooling_base.py b/test/nn/pool/test_pooling_base.py deleted file mode 100644 index fc93935188a4..000000000000 --- a/test/nn/pool/test_pooling_base.py +++ /dev/null @@ -1,52 +0,0 @@ -import torch - -from torch_geometric.nn import MaxAggregation -from torch_geometric.nn.pool.base import Pooling, PoolingOutput -from torch_geometric.nn.pool.connect import Connect -from torch_geometric.nn.pool.select import Select -from torch_geometric.utils import scatter - - -class DummySelect(Select): - def forward(self, x, edge_index, edge_attr, batch): - # Pool into a single node for each graph. - if batch is None: - return edge_index.new_zeros(x.size(0), dtype=torch.long), 1 - return batch, int(batch.max()) + 1 - - -class DummyConnect(Connect): - def forward(self, x, edge_index, edge_attr, batch): - # Return empty graph connection: - if edge_attr is not None: - edge_attr = edge_attr.new_empty((0, ) + edge_attr.size()[1:]) - return edge_index.new_empty(2, 0), edge_attr - - -def test_pooling(): - pool = Pooling(DummySelect(), MaxAggregation(), DummyConnect()) - pool.reset_parameters() - assert str(pool) == ('Pooling(\n' - ' select=DummySelect(),\n' - ' reduce=MaxAggregation(),\n' - ' connect=DummyConnect(),\n' - ')') - - x = torch.randn(10, 8) - edge_index = torch.empty((2, 0), dtype=torch.long) - edge_attr = torch.empty(0, 4) - batch = torch.tensor([0, 0, 0, 0, 0, 1, 1, 1, 1, 1]) - - out = pool(x, edge_index) - assert isinstance(out, PoolingOutput) - assert torch.allclose(out.x, x.max(dim=0, keepdim=True)[0]) - assert out.edge_index.size() == (2, 0) - assert out.edge_attr is None - assert out.batch is None - - out = pool(x, edge_index, edge_attr, batch) - assert isinstance(out, PoolingOutput) - assert torch.allclose(out.x, scatter(x, batch, reduce='max')) - assert out.edge_index.size() == (2, 0) - assert out.edge_attr.size() == (0, 4) - assert out.batch.tolist() == [0, 1] diff --git a/torch_geometric/nn/pool/select/base.py b/torch_geometric/nn/pool/select/base.py index c23293ca20d9..09683e3f9a75 100644 --- a/torch_geometric/nn/pool/select/base.py +++ b/torch_geometric/nn/pool/select/base.py @@ -1,9 +1,29 @@ -from typing import Optional, Tuple +from dataclasses import dataclass +from typing import Optional import torch from torch import Tensor +@dataclass +class SelectOutput: + """ + Output of the :meth:`Select.forward` method. + Args: + node_index (Tensor): The indices of the selected nodes. + cluster_index (Tensor): The indices of the clusters each node is + assigned to. Same shape as :obj:`node_index`. + num_clusters (int): The number of clusters. + weight (Tensor, optional): A weight Tensor with values in range + `[0, 1]`, denoting assignment weight of a node to a cluster. + Same shape as :obj:`node_index`. (default: :obj:`None`). + """ + node_index: Tensor + cluster_index: Tensor + num_clusters: int + weight: Optional[Tensor] = None + + class Select(torch.nn.Module): r"""An abstract base class implementing custom node selections that map the nodes of an input graph to supernodes of the pooled one. @@ -22,7 +42,7 @@ def forward( edge_index: Tensor, edge_attr: Optional[Tensor] = None, batch: Optional[Tensor] = None, - ) -> Tuple[Tensor, int]: + ) -> SelectOutput: r""" Args: x (torch.Tensor): The input node features. From 93d4d307b094044b0fa1d37a5561c300f47b9349 Mon Sep 17 00:00:00 2001 From: wsad1 Date: Sat, 6 May 2023 05:29:59 +0000 Subject: [PATCH 3/5] udpated input to connect --- torch_geometric/nn/pool/connect/base.py | 7 +++++-- torch_geometric/nn/pool/select/__init__.py | 6 ++---- torch_geometric/nn/pool/select/base.py | 3 ++- 3 files changed, 9 insertions(+), 7 deletions(-) diff --git a/torch_geometric/nn/pool/connect/base.py b/torch_geometric/nn/pool/connect/base.py index 36c02425768e..4a3f59e08525 100644 --- a/torch_geometric/nn/pool/connect/base.py +++ b/torch_geometric/nn/pool/connect/base.py @@ -3,6 +3,8 @@ import torch from torch import Tensor +from torch_geometric.nn.pool.select import SelectOutput + class Connect(torch.nn.Module): r"""An abstract base class implementing custom edge connection operators. @@ -18,14 +20,15 @@ def reset_parameters(self): def forward( self, - cluster: Tensor, + cluster: SelectOutput, edge_index: Tensor, edge_attr: Optional[Tensor] = None, batch: Optional[Tensor] = None, ) -> Tuple[Tensor, Optional[Tensor]]: r""" Args: - cluster (torch.Tensor): The mapping from nodes to supernodes. + cluster (SelectOutput): The output of `Select`, with a mapping from + nodes to clusters. edge_index (torch.Tensor): The edge indices. edge_attr (torch.Tensor, optional): The edge features. (default: :obj:`None`) diff --git a/torch_geometric/nn/pool/select/__init__.py b/torch_geometric/nn/pool/select/__init__.py index fe695d7563db..4cea5567f771 100644 --- a/torch_geometric/nn/pool/select/__init__.py +++ b/torch_geometric/nn/pool/select/__init__.py @@ -1,5 +1,3 @@ -from .base import Select +from .base import Select, SelectOutput -__all__ = [ - 'Select', -] +__all__ = ['Select', 'SelectOutput'] diff --git a/torch_geometric/nn/pool/select/base.py b/torch_geometric/nn/pool/select/base.py index 09683e3f9a75..9d51988f43c1 100644 --- a/torch_geometric/nn/pool/select/base.py +++ b/torch_geometric/nn/pool/select/base.py @@ -8,7 +8,8 @@ @dataclass class SelectOutput: """ - Output of the :meth:`Select.forward` method. + Output of the :meth:`Select.forward` method. Contains + a mapping from nodes to clusters, and the number of clusters. Args: node_index (Tensor): The indices of the selected nodes. cluster_index (Tensor): The indices of the clusters each node is From 304afb8927e1b45cdb016059ffe313a0d51e2cba Mon Sep 17 00:00:00 2001 From: rusty1s Date: Mon, 8 May 2023 07:57:10 +0000 Subject: [PATCH 4/5] update --- torch_geometric/nn/pool/connect/base.py | 72 ++++++++++++++++- torch_geometric/nn/pool/select/__init__.py | 5 +- torch_geometric/nn/pool/select/base.py | 89 +++++++++++++--------- 3 files changed, 128 insertions(+), 38 deletions(-) diff --git a/torch_geometric/nn/pool/connect/base.py b/torch_geometric/nn/pool/connect/base.py index 4a3f59e08525..5aabf0dbe780 100644 --- a/torch_geometric/nn/pool/connect/base.py +++ b/torch_geometric/nn/pool/connect/base.py @@ -1,18 +1,64 @@ +from dataclasses import dataclass from typing import Optional, Tuple import torch from torch import Tensor from torch_geometric.nn.pool.select import SelectOutput +from torch_geometric.utils.mixin import CastMixin + + +@dataclass(init=False) +class ConnectOutput(CastMixin): + r"""The output of the :class:`Connect` method, which holds the coarsened + graph structure, and optional pooled edge features and batch vectors. + + Args: + edge_index (torch.Tensor): The edge indices of the cooarsened graph. + edge_attr (torch.Tensor, optional): The pooled edge features of the + coarsened graph. (default: :obj:`None`) + batch (torch.Tensor, optional): The pooled batch vector of the + coarsened graph. (default: :obj:`None`) + """ + edge_index: Tensor + edge_attr: Optional[Tensor] = None + batch: Optional[Tensor] = None + + def __init__( + self, + edge_index: Tensor, + edge_attr: Optional[Tensor] = None, + batch: Optional[Tensor] = None, + ): + if edge_index.dim() != 2: + raise ValueError(f"Expected 'edge_index' to be two-dimensional " + f"(got {edge_index.dim()} dimensions)") + + if edge_index.size(0) != 2: + raise ValueError(f"Expected 'edge_index' to have size '2' in the " + f"first dimension (got '{edge_index.size(0)}')") + + if edge_attr is not None and edge_attr.size(0) != edge_index.size(1): + raise ValueError(f"Expected 'edge_index' and 'edge_attr' to " + f"hold the same number of edges (got " + f"{edge_index.size(1)} and {edge_attr.size(0)} " + f"edges)") + + self.edge_index = edge_index + self.edge_attr = edge_attr + self.batch = batch class Connect(torch.nn.Module): - r"""An abstract base class implementing custom edge connection operators. + r"""An abstract base class implementing custom edge connection operators as + described in the `"Understanding Pooling in Graph Neural Networks" + `_ paper. Specifically, :class:`Connect` determines for each pair of supernodes the presence or abscene of an edge based on the existing edges between the nodes in the two supernodes. - The operator also computes new coarsened edge features (if present). + The operator also computes pooled edge features and batch vectors + (if present). """ def reset_parameters(self): r"""Resets all learnable parameters of the module.""" @@ -20,7 +66,7 @@ def reset_parameters(self): def forward( self, - cluster: SelectOutput, + select_output: SelectOutput, edge_index: Tensor, edge_attr: Optional[Tensor] = None, batch: Optional[Tensor] = None, @@ -38,5 +84,25 @@ def forward( """ raise NotImplementedError + @staticmethod + def get_pooled_batch( + select_output: SelectOutput, + batch: Optional[Tensor], + ) -> Optional[Tensor]: + r"""Returns the batch vector of the coarsened graph. + + Args: + select_output (SelectOutput): The output of :class:`Select`. + batch (torch.Tensor, optional): The batch vector + :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns + each element to a specific example. (default: :obj:`None`) + """ + if batch is None: + return batch + + out = torch.arange(select_output.num_clusters, device=batch.device) + return out.scatter_(0, select_output.cluster_index, + batch[select_output.node_index]) + def __repr__(self) -> str: return f'{self.__class__.__name__}()' diff --git a/torch_geometric/nn/pool/select/__init__.py b/torch_geometric/nn/pool/select/__init__.py index 4cea5567f771..2a7278a33954 100644 --- a/torch_geometric/nn/pool/select/__init__.py +++ b/torch_geometric/nn/pool/select/__init__.py @@ -1,3 +1,6 @@ from .base import Select, SelectOutput -__all__ = ['Select', 'SelectOutput'] +__all__ = [ + 'Select', + 'SelectOutput', +] diff --git a/torch_geometric/nn/pool/select/base.py b/torch_geometric/nn/pool/select/base.py index 9d51988f43c1..ee40f17f055f 100644 --- a/torch_geometric/nn/pool/select/base.py +++ b/torch_geometric/nn/pool/select/base.py @@ -4,56 +4,77 @@ import torch from torch import Tensor +from torch_geometric.utils.mixin import CastMixin + + +@dataclass(init=False) +class SelectOutput(CastMixin): + r"""The output of the :class:`Select` method, which holds an assignment + from selected nodes to their respective cluster(s). -@dataclass -class SelectOutput: - """ - Output of the :meth:`Select.forward` method. Contains - a mapping from nodes to clusters, and the number of clusters. Args: - node_index (Tensor): The indices of the selected nodes. - cluster_index (Tensor): The indices of the clusters each node is - assigned to. Same shape as :obj:`node_index`. + node_index (torch.Tensor): The indices of the selected nodes. + cluster_index (torch.Tensor): The indices of the clusters each node in + :obj:`node_index` is assigned to. num_clusters (int): The number of clusters. - weight (Tensor, optional): A weight Tensor with values in range - `[0, 1]`, denoting assignment weight of a node to a cluster. - Same shape as :obj:`node_index`. (default: :obj:`None`). + weight (torch.Tensor, optional): A weight vector, denoting the strength + of the assignment of a node to its cluster. (default: :obj:`None`) """ node_index: Tensor cluster_index: Tensor num_clusters: int weight: Optional[Tensor] = None + def __init__( + self, + node_index: Tensor, + cluster_index: Tensor, + num_clusters: int, + weight: Optional[Tensor] = None, + ): + if node_index.dim() != 1: + raise ValueError(f"Expected 'node_index' to be one-dimensional " + f"(got {node_index.dim()} dimensions)") + + if cluster_index.dim() != 1: + raise ValueError(f"Expected 'cluster_index' to be one-dimensional " + f"(got {cluster_index.dim()} dimensions)") + + if node_index.numel() != cluster_index.numel(): + raise ValueError(f"Expected 'node_index' and 'cluster_index' to " + f"hold the same number of values (got " + f"{node_index.numel()} and " + f"{cluster_index.numel()} values)") + + if weight is not None and weight.dim() != 1: + raise ValueError(f"Expected 'weight' vector to be one-dimensional " + f"(got {weight.dim()} dimensions)") + + if weight is not None and weight.numel() != node_index.numel(): + raise ValueError(f"Expected 'weight' to hold {node_index.numel()} " + f"values (got {weight.numel()} values)") + + self.node_index = node_index + self.cluster_index = cluster_index + self.num_clusters = num_clusters + self.weight = weight + class Select(torch.nn.Module): - r"""An abstract base class implementing custom node selections that map the - nodes of an input graph to supernodes of the pooled one. + r"""An abstract base class for implementing custom node selections as + described in the `"Understanding Pooling in Graph Neural Networks" + `_ paper, which maps the nodes of an + input graph to supernodes in the coarsened graph. - Specifically, :class:`Select` returns a mapping - :math:`\mathbf{c} \in {\{ -1, \ldots, C - 1\}}^N`, which assigns each node - to one of :math:`C` super nodes. - In addition, :class:`Select` returns the number of super nodes. + Specifically, :class:`Select` returns a :class:`SelectOutput` output, which + holds a (sparse) mapping :math:`\mathbf{C} \in {[0, 1]}^{N \times C}` that + assigns selected nodes to one or more of :math:`C` super nodes. """ def reset_parameters(self): + r"""Resets all learnable parameters of the module.""" pass - def forward( - self, - x: Tensor, - edge_index: Tensor, - edge_attr: Optional[Tensor] = None, - batch: Optional[Tensor] = None, - ) -> SelectOutput: - r""" - Args: - x (torch.Tensor): The input node features. - edge_index (torch.Tensor): The edge indices. - edge_attr (torch.Tensor, optional): The edge features. - (default: :obj:`None`) - batch (torch.Tensor, optional): The batch vector - :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns - each node to a specific graph. (default: :obj:`None`) - """ + def forward(self, *args, **kwargs) -> SelectOutput: raise NotImplementedError def __repr__(self) -> str: From d59f37e72be454198d5a07bcc2f50015705a71cf Mon Sep 17 00:00:00 2001 From: rusty1s Date: Mon, 8 May 2023 07:59:06 +0000 Subject: [PATCH 5/5] update --- torch_geometric/nn/pool/connect/__init__.py | 3 ++- torch_geometric/nn/pool/connect/base.py | 9 ++++----- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/torch_geometric/nn/pool/connect/__init__.py b/torch_geometric/nn/pool/connect/__init__.py index aabde59575b4..52b45639fc63 100644 --- a/torch_geometric/nn/pool/connect/__init__.py +++ b/torch_geometric/nn/pool/connect/__init__.py @@ -1,5 +1,6 @@ -from .base import Connect +from .base import Connect, ConnectOutput __all__ = [ 'Connect', + 'ConnectOutput', ] diff --git a/torch_geometric/nn/pool/connect/base.py b/torch_geometric/nn/pool/connect/base.py index 5aabf0dbe780..1b528e29a9be 100644 --- a/torch_geometric/nn/pool/connect/base.py +++ b/torch_geometric/nn/pool/connect/base.py @@ -50,9 +50,9 @@ def __init__( class Connect(torch.nn.Module): - r"""An abstract base class implementing custom edge connection operators as - described in the `"Understanding Pooling in Graph Neural Networks" - `_ paper. + r"""An abstract base class for implementing custom edge connection + operators as described in the `"Understanding Pooling in Graph Neural + Networks" `_ paper. Specifically, :class:`Connect` determines for each pair of supernodes the presence or abscene of an edge based on the existing edges between the @@ -73,8 +73,7 @@ def forward( ) -> Tuple[Tensor, Optional[Tensor]]: r""" Args: - cluster (SelectOutput): The output of `Select`, with a mapping from - nodes to clusters. + select_output (SelectOutput): The output of :class:`Select`. edge_index (torch.Tensor): The edge indices. edge_attr (torch.Tensor, optional): The edge features. (default: :obj:`None`)