From 5fdecf353c31b3b817384268d2f0178a879cf4e0 Mon Sep 17 00:00:00 2001 From: Thijs Snelleman Date: Mon, 26 Aug 2024 16:58:36 +0200 Subject: [PATCH 01/12] Placing cluster_pool into package --- torch_geometric/nn/pool/cluster_pool.py | 191 ++++++++++++++++++++++++ 1 file changed, 191 insertions(+) create mode 100644 torch_geometric/nn/pool/cluster_pool.py diff --git a/torch_geometric/nn/pool/cluster_pool.py b/torch_geometric/nn/pool/cluster_pool.py new file mode 100644 index 000000000000..e2c7850bf529 --- /dev/null +++ b/torch_geometric/nn/pool/cluster_pool.py @@ -0,0 +1,191 @@ +from typing import Callable, List, NamedTuple, Optional, Tuple + +import torch +import torch.nn.functional as F +from torch import Tensor + +from torch_geometric.utils import coalesce, to_scipy_sparse_matrix, dense_to_sparse, to_dense_adj +from scipy.sparse.csgraph import connected_components + + +class ClusterPooling(torch.nn.Module): + r"""The cluster pooling operator from the `"Edge-Based Graph Component + Pooling" ` paper. + + In short, a score is computed for each edge. + Based on the selected edges, graph clusters are calculated and compressed to one + node using an injective aggregation function (sum). Edges are remapped based on + the node created by each cluster and the original edges. + + Args: + in_channels (int): Size of each input sample. + edge_score_method (function, optional): The function to apply + to compute the edge score from raw edge scores. By default, + this is the tanh over all incoming edges for each node. + This function takes in a :obj:`raw_edge_score` tensor of shape + :obj:`[num_nodes]`, an :obj:`edge_index` tensor and the number of + nodes :obj:`num_nodes`, and produces a new tensor of the same size + as :obj:`raw_edge_score` describing normalized edge scores. + Included functions are + :func:`ClusterPooling.compute_edge_score_tanh`, + :func:`ClusterPooling.compute_edge_score_sigmoid` and + :func:`ClusterPooling.compute_edge_score_logsoftmax`. + (default: :func:`ClusterPooling.compute_edge_score_tanh`) + dropout (float, optional): The probability with + which to drop edge scores during training. (default: :obj:`0`) + """ + unpool_description = NamedTuple( + "UnpoolDescription", + ["edge_index", "batch", "cluster_map"]) + + def __init__(self, + in_channels: int, + edge_score_method: Optional[Callable] = None, + dropout: Optional[float] = 0.0, + threshold: Optional[float] = None, + directed: bool = False): + super().__init__() + self.in_channels = in_channels + if edge_score_method is None: + edge_score_method = self.compute_edge_score_tanh + if threshold is None: + if edge_score_method is self.compute_edge_score_sigmoid: + threshold = 0.5 + else: + threshold = 0.0 + self.compute_edge_score = edge_score_method + self.threshhold = threshold + self.dropout = dropout + self.directed = directed + self.lin = torch.nn.Linear(2 * in_channels, 1) + + self.reset_parameters() + + def reset_parameters(self): + r"""Resets all learnable parameters of the module.""" + self.lin.reset_parameters() + + @staticmethod + def compute_edge_score_tanh(raw_edge_score: Tensor): + r"""Normalizes edge scores via hyperbolic tangent application.""" + return torch.tanh(raw_edge_score) + + @staticmethod + def compute_edge_score_sigmoid(raw_edge_score: Tensor): + r"""Normalizes edge scores via sigmoid application.""" + return torch.sigmoid(raw_edge_score) + + @staticmethod + def compute_edge_score_logsoftmax(raw_edge_score: Tensor): + r"""Normalizes edge scores via logsoftmax application.""" + return torch.nn.functional.log_softmax(raw_edge_score, dim=0) + + def forward( + self, + x: Tensor, + edge_index: Tensor, + batch: Tensor, + ) -> Tuple[Tensor, Tensor, Tensor, NamedTuple]: + r"""Forward pass. + + Args: + x (Tensor): The node features. + edge_index (LongTensor): The edge indices. + batch (LongTensor): Batch vector + :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns + each node to a specific example. + + Return types: + * **x** *(Tensor)* - The pooled node features. + * **edge_index** *(LongTensor)* - The coarsened edge indices. + * **batch** *(LongTensor)* - The coarsened batch vector. + * **unpool_info** *(unpool_description)* - Information that is + consumed by :func:`ClusterPooling.unpool` for unpooling. + """ + #First we drop the self edges as those cannot be clustered + msk = edge_index[0] != edge_index[1] + edge_index = edge_index[:,msk] + if not self.directed: + edge_index = torch.cat([edge_index, edge_index.flip(0)], dim=-1) + # We only evaluate each edge once, so we filter double edges from the list + edge_index = coalesce(edge_index) + + e = torch.cat([x[edge_index[0]], x[edge_index[1]]], dim=-1) # Concatenates the source feature with the target features + e = self.lin(e).view(-1) # Apply linear NN on the node pairs (edges) and reshape to 1 dimension + e = F.dropout(e, p=self.dropout, training=self.training) + + e = self.compute_edge_score(e) #Non linear activation function + x, edge_index, batch, unpool_info = self.__merge_edges__( + x, edge_index, batch, e) + + return x, edge_index, batch, unpool_info + + def __merge_edges__( + self, + X: Tensor, + edge_index: Tensor, + batch: Tensor, + edge_score: Tensor + ) -> Tuple[Tensor, Tensor, Tensor, NamedTuple]: + edges_contract = edge_index[..., edge_score > self.threshhold] + + adj = to_scipy_sparse_matrix(edges_contract, num_nodes=X.size(0)) + _, cluster_index = connected_components(adj, directed=True, connection="weak") + + cluster_index = torch.tensor(cluster_index, dtype=torch.int64, device=X.device) + C = F.one_hot(cluster_index).type(torch.float) + A = to_dense_adj(edge_index, max_num_nodes=X.size(0)).squeeze(0) + S = to_dense_adj(edge_index, edge_attr=edge_score, max_num_nodes=X.size(0)).squeeze(0) + + A_contract = to_dense_adj(edges_contract, max_num_nodes=X.size(0)).type(torch.int).squeeze(0) + nodes_single = ((A_contract.sum(-1) + A_contract.sum(-2))==0).nonzero() + S[nodes_single,nodes_single] = 1 + + X_new = (S @ C).T @ X + edge_index_new, _ = dense_to_sparse((C.T @ A @ C).fill_diagonal_(0)) + + new_batch = X.new_empty(X_new.size(0), dtype=torch.long) + new_batch = new_batch.scatter_(0, cluster_index, batch) + + unpool_info = self.unpool_description(edge_index=edge_index, + batch=batch, + cluster_map=cluster_index) + + return X_new.to(X.device), edge_index_new.to(X.device), new_batch, unpool_info + + def unpool( + self, + x: Tensor, + unpool_info: NamedTuple, + ) -> Tuple[Tensor, Tensor, Tensor]: + r"""Unpools a previous cluster pooling step. + + For unpooling, :obj:`x` should be of same shape as those produced by + this layer's :func:`forward` function. Then, it will produce an + unpooled :obj:`x` in addition to :obj:`edge_index` and :obj:`batch`. + + Args: + x (Tensor): The node features. + unpool_info (unpool_description): Information that has + been produced by :func:`ClusterPooling.forward`. + + Return types: + * **x** *(Tensor)* - The unpooled node features. + * **edge_index** *(LongTensor)* - The new edge indices. + * **batch** *(LongTensor)* - The new batch vector. + """ + # We just copy the cluster feature into every node + node_maps = unpool_info.cluster_map + n_nodes = 0 + for c in node_maps: + node_maps += len(c) + import numpy as np + repack = np.array([-1 for _ in range(n_nodes)]) + for i,c in enumerate(node_maps): + repack[c] = i + new_x = x[repack] + + return new_x, unpool_info.edge_index, unpool_info.batch + + def __repr__(self) -> str: + return f'{self.__class__.__name__}({self.in_channels})' From 25d7f869edad7f796e8707bf766038845a0ec8ff Mon Sep 17 00:00:00 2001 From: Thijs Snelleman Date: Mon, 26 Aug 2024 16:59:42 +0200 Subject: [PATCH 02/12] Removing unaccessed import --- torch_geometric/nn/pool/cluster_pool.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_geometric/nn/pool/cluster_pool.py b/torch_geometric/nn/pool/cluster_pool.py index e2c7850bf529..93f61406cd6c 100644 --- a/torch_geometric/nn/pool/cluster_pool.py +++ b/torch_geometric/nn/pool/cluster_pool.py @@ -1,4 +1,4 @@ -from typing import Callable, List, NamedTuple, Optional, Tuple +from typing import Callable, NamedTuple, Optional, Tuple import torch import torch.nn.functional as F From a0ebe4e19b16380211b56cb7f7aeca1869b09e32 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 26 Aug 2024 15:03:16 +0000 Subject: [PATCH 03/12] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- torch_geometric/nn/pool/cluster_pool.py | 68 ++++++++++++++----------- 1 file changed, 38 insertions(+), 30 deletions(-) diff --git a/torch_geometric/nn/pool/cluster_pool.py b/torch_geometric/nn/pool/cluster_pool.py index 93f61406cd6c..c9205e800d09 100644 --- a/torch_geometric/nn/pool/cluster_pool.py +++ b/torch_geometric/nn/pool/cluster_pool.py @@ -2,10 +2,15 @@ import torch import torch.nn.functional as F +from scipy.sparse.csgraph import connected_components from torch import Tensor -from torch_geometric.utils import coalesce, to_scipy_sparse_matrix, dense_to_sparse, to_dense_adj -from scipy.sparse.csgraph import connected_components +from torch_geometric.utils import ( + coalesce, + dense_to_sparse, + to_dense_adj, + to_scipy_sparse_matrix, +) class ClusterPooling(torch.nn.Module): @@ -16,7 +21,7 @@ class ClusterPooling(torch.nn.Module): Based on the selected edges, graph clusters are calculated and compressed to one node using an injective aggregation function (sum). Edges are remapped based on the node created by each cluster and the original edges. - + Args: in_channels (int): Size of each input sample. edge_score_method (function, optional): The function to apply @@ -34,16 +39,13 @@ class ClusterPooling(torch.nn.Module): dropout (float, optional): The probability with which to drop edge scores during training. (default: :obj:`0`) """ - unpool_description = NamedTuple( - "UnpoolDescription", - ["edge_index", "batch", "cluster_map"]) + unpool_description = NamedTuple("UnpoolDescription", + ["edge_index", "batch", "cluster_map"]) - def __init__(self, - in_channels: int, + def __init__(self, in_channels: int, edge_score_method: Optional[Callable] = None, dropout: Optional[float] = 0.0, - threshold: Optional[float] = None, - directed: bool = False): + threshold: Optional[float] = None, directed: bool = False): super().__init__() self.in_channels = in_channels if edge_score_method is None: @@ -104,42 +106,47 @@ def forward( """ #First we drop the self edges as those cannot be clustered msk = edge_index[0] != edge_index[1] - edge_index = edge_index[:,msk] + edge_index = edge_index[:, msk] if not self.directed: edge_index = torch.cat([edge_index, edge_index.flip(0)], dim=-1) # We only evaluate each edge once, so we filter double edges from the list edge_index = coalesce(edge_index) - - e = torch.cat([x[edge_index[0]], x[edge_index[1]]], dim=-1) # Concatenates the source feature with the target features - e = self.lin(e).view(-1) # Apply linear NN on the node pairs (edges) and reshape to 1 dimension + + e = torch.cat( + [x[edge_index[0]], x[edge_index[1]]], + dim=-1) # Concatenates the source feature with the target features + e = self.lin(e).view( + -1 + ) # Apply linear NN on the node pairs (edges) and reshape to 1 dimension e = F.dropout(e, p=self.dropout, training=self.training) - e = self.compute_edge_score(e) #Non linear activation function + e = self.compute_edge_score(e) #Non linear activation function x, edge_index, batch, unpool_info = self.__merge_edges__( x, edge_index, batch, e) return x, edge_index, batch, unpool_info - + def __merge_edges__( - self, - X: Tensor, - edge_index: Tensor, - batch: Tensor, - edge_score: Tensor - ) -> Tuple[Tensor, Tensor, Tensor, NamedTuple]: + self, X: Tensor, edge_index: Tensor, batch: Tensor, + edge_score: Tensor) -> Tuple[Tensor, Tensor, Tensor, NamedTuple]: edges_contract = edge_index[..., edge_score > self.threshhold] adj = to_scipy_sparse_matrix(edges_contract, num_nodes=X.size(0)) - _, cluster_index = connected_components(adj, directed=True, connection="weak") + _, cluster_index = connected_components(adj, directed=True, + connection="weak") - cluster_index = torch.tensor(cluster_index, dtype=torch.int64, device=X.device) + cluster_index = torch.tensor(cluster_index, dtype=torch.int64, + device=X.device) C = F.one_hot(cluster_index).type(torch.float) A = to_dense_adj(edge_index, max_num_nodes=X.size(0)).squeeze(0) - S = to_dense_adj(edge_index, edge_attr=edge_score, max_num_nodes=X.size(0)).squeeze(0) + S = to_dense_adj(edge_index, edge_attr=edge_score, + max_num_nodes=X.size(0)).squeeze(0) - A_contract = to_dense_adj(edges_contract, max_num_nodes=X.size(0)).type(torch.int).squeeze(0) - nodes_single = ((A_contract.sum(-1) + A_contract.sum(-2))==0).nonzero() - S[nodes_single,nodes_single] = 1 + A_contract = to_dense_adj( + edges_contract, max_num_nodes=X.size(0)).type(torch.int).squeeze(0) + nodes_single = ((A_contract.sum(-1) + + A_contract.sum(-2)) == 0).nonzero() + S[nodes_single, nodes_single] = 1 X_new = (S @ C).T @ X edge_index_new, _ = dense_to_sparse((C.T @ A @ C).fill_diagonal_(0)) @@ -151,7 +158,8 @@ def __merge_edges__( batch=batch, cluster_map=cluster_index) - return X_new.to(X.device), edge_index_new.to(X.device), new_batch, unpool_info + return X_new.to(X.device), edge_index_new.to( + X.device), new_batch, unpool_info def unpool( self, @@ -181,7 +189,7 @@ def unpool( node_maps += len(c) import numpy as np repack = np.array([-1 for _ in range(n_nodes)]) - for i,c in enumerate(node_maps): + for i, c in enumerate(node_maps): repack[c] = i new_x = x[repack] From 37f9d34e7744f1bf285998ffee555eec5528690e Mon Sep 17 00:00:00 2001 From: Thijs Snelleman Date: Mon, 26 Aug 2024 17:25:28 +0200 Subject: [PATCH 04/12] PR fixes --- torch_geometric/nn/pool/cluster_pool.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/torch_geometric/nn/pool/cluster_pool.py b/torch_geometric/nn/pool/cluster_pool.py index c9205e800d09..b400c33e4699 100644 --- a/torch_geometric/nn/pool/cluster_pool.py +++ b/torch_geometric/nn/pool/cluster_pool.py @@ -18,9 +18,9 @@ class ClusterPooling(torch.nn.Module): Pooling" ` paper. In short, a score is computed for each edge. - Based on the selected edges, graph clusters are calculated and compressed to one - node using an injective aggregation function (sum). Edges are remapped based on - the node created by each cluster and the original edges. + Based on the selected edges, graph clusters are calculated and compressed + to one node using an injective aggregation function (sum). Edges are + remapped based on the node created by each cluster and the original edges. Args: in_channels (int): Size of each input sample. @@ -104,23 +104,23 @@ def forward( * **unpool_info** *(unpool_description)* - Information that is consumed by :func:`ClusterPooling.unpool` for unpooling. """ - #First we drop the self edges as those cannot be clustered + # First we drop the self edges as those cannot be clustered msk = edge_index[0] != edge_index[1] edge_index = edge_index[:, msk] if not self.directed: edge_index = torch.cat([edge_index, edge_index.flip(0)], dim=-1) - # We only evaluate each edge once, so we filter double edges from the list + # We only evaluate each edge once, remove double edges from the list edge_index = coalesce(edge_index) e = torch.cat( [x[edge_index[0]], x[edge_index[1]]], - dim=-1) # Concatenates the source feature with the target features + dim=-1) # Concatenates source feature with target features e = self.lin(e).view( -1 - ) # Apply linear NN on the node pairs (edges) and reshape to 1 dimension + ) # Apply linear NN on the node pairs (edges) and reshape e = F.dropout(e, p=self.dropout, training=self.training) - e = self.compute_edge_score(e) #Non linear activation function + e = self.compute_edge_score(e) # Non linear activation function x, edge_index, batch, unpool_info = self.__merge_edges__( x, edge_index, batch, e) @@ -182,7 +182,7 @@ def unpool( * **edge_index** *(LongTensor)* - The new edge indices. * **batch** *(LongTensor)* - The new batch vector. """ - # We just copy the cluster feature into every node + # We copy the cluster features into every node node_maps = unpool_info.cluster_map n_nodes = 0 for c in node_maps: From 44976fa87962e70f08597193bb07f148a9e6d3d3 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 26 Aug 2024 15:27:03 +0000 Subject: [PATCH 05/12] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- torch_geometric/nn/pool/cluster_pool.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/torch_geometric/nn/pool/cluster_pool.py b/torch_geometric/nn/pool/cluster_pool.py index b400c33e4699..846e38a84c75 100644 --- a/torch_geometric/nn/pool/cluster_pool.py +++ b/torch_geometric/nn/pool/cluster_pool.py @@ -116,8 +116,7 @@ def forward( [x[edge_index[0]], x[edge_index[1]]], dim=-1) # Concatenates source feature with target features e = self.lin(e).view( - -1 - ) # Apply linear NN on the node pairs (edges) and reshape + -1) # Apply linear NN on the node pairs (edges) and reshape e = F.dropout(e, p=self.dropout, training=self.training) e = self.compute_edge_score(e) # Non linear activation function From 9ea43394f213fcbd17403bbdf2def1a329f53fb5 Mon Sep 17 00:00:00 2001 From: Thijs Snelleman Date: Tue, 27 Aug 2024 09:28:56 +0200 Subject: [PATCH 06/12] changelog update --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index f58d082ed6e8..8765df263fa9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -39,6 +39,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added `EdgeIndex.sparse_resize_` functionality ([#8983](https://github.com/pyg-team/pytorch_geometric/pull/8983)) - Added approximate `faiss`-based KNN-search ([#8952](https://github.com/pyg-team/pytorch_geometric/pull/8952)) - Added documentation on environment setup on XPU device ([#9407](https://github.com/pyg-team/pytorch_geometric/pull/9407)) +- Added the `torch_geometric.nn.pool.cluster_pool` layer ([#9627](https://github.com/pyg-team/pytorch_geometric/pull/9627)) ### Changed From 06f3e3b3d0faa33cbcdf531bb1dbd3f2428071c4 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 27 Aug 2024 07:30:28 +0000 Subject: [PATCH 07/12] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- README.md | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/README.md b/README.md index 713c3acc7de5..17a51ff6d252 100644 --- a/README.md +++ b/README.md @@ -383,9 +383,9 @@ where `${CUDA}` should be replaced by either `cpu`, `cu118`, `cu121`, or `cu124` | | `cpu` | `cu118` | `cu121` | `cu124` | | ----------- | ----- | ------- | ------- | ------- | -| **Linux** | ✅ | ✅ | ✅ | ✅ | -| **Windows** | ✅ | ✅ | ✅ | ✅ | -| **macOS** | ✅ | | | | +| **Linux** | ✅ | ✅ | ✅ | ✅ | +| **Windows** | ✅ | ✅ | ✅ | ✅ | +| **macOS** | ✅ | | | | #### PyTorch 2.3 @@ -399,9 +399,9 @@ where `${CUDA}` should be replaced by either `cpu`, `cu118`, or `cu121` dependin | | `cpu` | `cu118` | `cu121` | | ----------- | ----- | ------- | ------- | -| **Linux** | ✅ | ✅ | ✅ | -| **Windows** | ✅ | ✅ | ✅ | -| **macOS** | ✅ | | | +| **Linux** | ✅ | ✅ | ✅ | +| **Windows** | ✅ | ✅ | ✅ | +| **macOS** | ✅ | | | **Note:** Binaries of older versions are also provided for PyTorch 1.4.0, PyTorch 1.5.0, PyTorch 1.6.0, PyTorch 1.7.0/1.7.1, PyTorch 1.8.0/1.8.1, PyTorch 1.9.0, PyTorch 1.10.0/1.10.1/1.10.2, PyTorch 1.11.0, PyTorch 1.12.0/1.12.1, PyTorch 1.13.0/1.13.1, PyTorch 2.0.0/2.0.1, PyTorch 2.1.0/2.1.1/2.1.2, and PyTorch 2.2.0/2.2.1/2.2.2 (following the same procedure). **For older versions, you might need to explicitly specify the latest supported version number** or install via `pip install --no-index` in order to prevent a manual installation from source. From e0a87c345c8bf9161e070820c38499427c39ad6a Mon Sep 17 00:00:00 2001 From: rusty1s Date: Tue, 10 Sep 2024 02:15:10 +0000 Subject: [PATCH 08/12] update --- CHANGELOG.md | 2 +- test/nn/pool/test_cluster_pool.py | 30 ++++ torch_geometric/nn/pool/__init__.py | 12 +- torch_geometric/nn/pool/cluster_pool.py | 219 +++++++++--------------- torch_geometric/nn/pool/edge_pool.py | 2 +- 5 files changed, 122 insertions(+), 143 deletions(-) create mode 100644 test/nn/pool/test_cluster_pool.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 754f4bf0ff27..1b44281645fe 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Added +- Added the `torch_geometric.nn.pool.cluster_pool` layer ([#9627](https://github.com/pyg-team/pytorch_geometric/pull/9627)) - Added the `LinkPredMRR` metric ([#9632](https://github.com/pyg-team/pytorch_geometric/pull/9632)) - Added PyTorch 2.4 support ([#9594](https://github.com/pyg-team/pytorch_geometric/pull/9594)) - Added `utils.normalize_edge_index` for symmetric/asymmetric normalization of graph edges ([#9554](https://github.com/pyg-team/pytorch_geometric/pull/9554)) @@ -40,7 +41,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added `EdgeIndex.sparse_resize_` functionality ([#8983](https://github.com/pyg-team/pytorch_geometric/pull/8983)) - Added approximate `faiss`-based KNN-search ([#8952](https://github.com/pyg-team/pytorch_geometric/pull/8952)) - Added documentation on environment setup on XPU device ([#9407](https://github.com/pyg-team/pytorch_geometric/pull/9407)) -- Added the `torch_geometric.nn.pool.cluster_pool` layer ([#9627](https://github.com/pyg-team/pytorch_geometric/pull/9627)) ### Changed diff --git a/test/nn/pool/test_cluster_pool.py b/test/nn/pool/test_cluster_pool.py new file mode 100644 index 000000000000..7adac6100e54 --- /dev/null +++ b/test/nn/pool/test_cluster_pool.py @@ -0,0 +1,30 @@ +import pytest +import torch + +from torch_geometric.nn import ClusterPooling + + +@pytest.mark.parametrize('edge_score_method', [ + 'tanh', + 'sigmoid', + 'log_softmax', +]) +def test_cluster_pooling(edge_score_method): + x = torch.tensor([[0.0], [1.0], [2.0], [3.0], [4.0], [5.0], [-1.0]]) + edge_index = torch.tensor([ + [0, 0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 5, 6], + [1, 2, 3, 6, 0, 2, 3, 0, 1, 3, 0, 1, 2, 5, 4, 0], + ]) + batch = torch.tensor([0, 0, 0, 0, 1, 1, 0]) + + op = ClusterPooling(in_channels=1, edge_score_method=edge_score_method) + assert str(op) == 'ClusterPooling(1)' + op.reset_parameters() + + x, edge_index, batch, unpool_info = op(x, edge_index, batch) + assert x.size(0) <= 7 + assert edge_index.size(0) == 2 + if edge_index.numel() > 0: + assert edge_index.min() >= 0 + assert edge_index.max() < x.size(0) + assert batch.size() == (x.size(0), ) diff --git a/torch_geometric/nn/pool/__init__.py b/torch_geometric/nn/pool/__init__.py index e1665ac0461d..09ef32d8536e 100644 --- a/torch_geometric/nn/pool/__init__.py +++ b/torch_geometric/nn/pool/__init__.py @@ -7,18 +7,19 @@ import torch_geometric.typing from torch_geometric.typing import OptTensor, torch_cluster -from .asap import ASAPooling from .avg_pool import avg_pool, avg_pool_neighbor_x, avg_pool_x -from .edge_pool import EdgePooling from .glob import global_add_pool, global_max_pool, global_mean_pool from .knn import (KNNIndex, L2KNNIndex, MIPSKNNIndex, ApproxL2KNNIndex, ApproxMIPSKNNIndex) from .graclus import graclus from .max_pool import max_pool, max_pool_neighbor_x, max_pool_x -from .mem_pool import MemPooling -from .pan_pool import PANPooling -from .sag_pool import SAGPooling from .topk_pool import TopKPooling +from .sag_pool import SAGPooling +from .edge_pool import EdgePooling +from .cluster_pool import ClusterPooling +from .asap import ASAPooling +from .pan_pool import PANPooling +from .mem_pool import MemPooling from .voxel_grid import voxel_grid from .approx_knn import approx_knn, approx_knn_graph @@ -344,6 +345,7 @@ def nearest( 'TopKPooling', 'SAGPooling', 'EdgePooling', + 'ClusterPooling', 'ASAPooling', 'PANPooling', 'MemPooling', diff --git a/torch_geometric/nn/pool/cluster_pool.py b/torch_geometric/nn/pool/cluster_pool.py index 846e38a84c75..1dbc5840b9f5 100644 --- a/torch_geometric/nn/pool/cluster_pool.py +++ b/torch_geometric/nn/pool/cluster_pool.py @@ -1,93 +1,74 @@ -from typing import Callable, NamedTuple, Optional, Tuple +from typing import NamedTuple, Optional, Tuple import torch import torch.nn.functional as F -from scipy.sparse.csgraph import connected_components from torch import Tensor from torch_geometric.utils import ( - coalesce, dense_to_sparse, + one_hot, to_dense_adj, to_scipy_sparse_matrix, ) +class UnpoolInfo(NamedTuple): + edge_index: Tensor + cluster: Tensor + batch: Tensor + + class ClusterPooling(torch.nn.Module): r"""The cluster pooling operator from the `"Edge-Based Graph Component - Pooling" ` paper. + Pooling" `_ paper. - In short, a score is computed for each edge. + :class:`ClusterPooling` computes a score for each edge. Based on the selected edges, graph clusters are calculated and compressed - to one node using an injective aggregation function (sum). Edges are - remapped based on the node created by each cluster and the original edges. + to one node using the injective :obj:`"sum" aggregation function. + Edges are remapped based on the nodes created by each cluster and the + original edges. Args: in_channels (int): Size of each input sample. - edge_score_method (function, optional): The function to apply - to compute the edge score from raw edge scores. By default, - this is the tanh over all incoming edges for each node. - This function takes in a :obj:`raw_edge_score` tensor of shape - :obj:`[num_nodes]`, an :obj:`edge_index` tensor and the number of - nodes :obj:`num_nodes`, and produces a new tensor of the same size - as :obj:`raw_edge_score` describing normalized edge scores. - Included functions are - :func:`ClusterPooling.compute_edge_score_tanh`, - :func:`ClusterPooling.compute_edge_score_sigmoid` and - :func:`ClusterPooling.compute_edge_score_logsoftmax`. - (default: :func:`ClusterPooling.compute_edge_score_tanh`) + edge_score_method (str, optional): The function to apply + to compute the edge score from raw edge scores (:obj:`"tanh"`, + "sigmoid", :obj:`"log_softmax"`). (default: :obj:`"tanh"`) dropout (float, optional): The probability with - which to drop edge scores during training. (default: :obj:`0`) + which to drop edge scores during training. (default: :obj:`0.0`) + threshold (float, optional): The threshold of edge scores. If set to + :obj:`None`, will be automatically inferred depending on + :obj:`edge_score_method`. (default: :obj:`None`) """ - unpool_description = NamedTuple("UnpoolDescription", - ["edge_index", "batch", "cluster_map"]) - - def __init__(self, in_channels: int, - edge_score_method: Optional[Callable] = None, - dropout: Optional[float] = 0.0, - threshold: Optional[float] = None, directed: bool = False): + def __init__( + self, + in_channels: int, + edge_score_method: str = 'tanh', + dropout: float = 0.0, + threshold: Optional[float] = None, + ): super().__init__() - self.in_channels = in_channels - if edge_score_method is None: - edge_score_method = self.compute_edge_score_tanh + assert edge_score_method in ['tanh', 'sigmoid', 'log_softmax'] + if threshold is None: - if edge_score_method is self.compute_edge_score_sigmoid: - threshold = 0.5 - else: - threshold = 0.0 - self.compute_edge_score = edge_score_method - self.threshhold = threshold + threshold = 0.5 if edge_score_method == 'sigmoid' else 0.0 + + self.in_channels = in_channels + self.edge_score_method = edge_score_method self.dropout = dropout - self.directed = directed - self.lin = torch.nn.Linear(2 * in_channels, 1) + self.threshhold = threshold - self.reset_parameters() + self.lin = torch.nn.Linear(2 * in_channels, 1) def reset_parameters(self): r"""Resets all learnable parameters of the module.""" self.lin.reset_parameters() - @staticmethod - def compute_edge_score_tanh(raw_edge_score: Tensor): - r"""Normalizes edge scores via hyperbolic tangent application.""" - return torch.tanh(raw_edge_score) - - @staticmethod - def compute_edge_score_sigmoid(raw_edge_score: Tensor): - r"""Normalizes edge scores via sigmoid application.""" - return torch.sigmoid(raw_edge_score) - - @staticmethod - def compute_edge_score_logsoftmax(raw_edge_score: Tensor): - r"""Normalizes edge scores via logsoftmax application.""" - return torch.nn.functional.log_softmax(raw_edge_score, dim=0) - def forward( self, x: Tensor, edge_index: Tensor, batch: Tensor, - ) -> Tuple[Tensor, Tensor, Tensor, NamedTuple]: + ) -> Tuple[Tensor, Tensor, Tensor, UnpoolInfo]: r"""Forward pass. Args: @@ -104,95 +85,61 @@ def forward( * **unpool_info** *(unpool_description)* - Information that is consumed by :func:`ClusterPooling.unpool` for unpooling. """ - # First we drop the self edges as those cannot be clustered - msk = edge_index[0] != edge_index[1] - edge_index = edge_index[:, msk] - if not self.directed: - edge_index = torch.cat([edge_index, edge_index.flip(0)], dim=-1) - # We only evaluate each edge once, remove double edges from the list - edge_index = coalesce(edge_index) - - e = torch.cat( - [x[edge_index[0]], x[edge_index[1]]], - dim=-1) # Concatenates source feature with target features - e = self.lin(e).view( - -1) # Apply linear NN on the node pairs (edges) and reshape - e = F.dropout(e, p=self.dropout, training=self.training) - - e = self.compute_edge_score(e) # Non linear activation function - x, edge_index, batch, unpool_info = self.__merge_edges__( - x, edge_index, batch, e) - - return x, edge_index, batch, unpool_info - - def __merge_edges__( - self, X: Tensor, edge_index: Tensor, batch: Tensor, - edge_score: Tensor) -> Tuple[Tensor, Tensor, Tensor, NamedTuple]: - edges_contract = edge_index[..., edge_score > self.threshhold] - - adj = to_scipy_sparse_matrix(edges_contract, num_nodes=X.size(0)) - _, cluster_index = connected_components(adj, directed=True, - connection="weak") - - cluster_index = torch.tensor(cluster_index, dtype=torch.int64, - device=X.device) - C = F.one_hot(cluster_index).type(torch.float) - A = to_dense_adj(edge_index, max_num_nodes=X.size(0)).squeeze(0) - S = to_dense_adj(edge_index, edge_attr=edge_score, - max_num_nodes=X.size(0)).squeeze(0) + mask = edge_index[0] != edge_index[1] + edge_index = edge_index[:, mask] - A_contract = to_dense_adj( - edges_contract, max_num_nodes=X.size(0)).type(torch.int).squeeze(0) - nodes_single = ((A_contract.sum(-1) + - A_contract.sum(-2)) == 0).nonzero() - S[nodes_single, nodes_single] = 1 - - X_new = (S @ C).T @ X - edge_index_new, _ = dense_to_sparse((C.T @ A @ C).fill_diagonal_(0)) + edge_attr = torch.cat( + [x[edge_index[0]], x[edge_index[1]]], + dim=-1, + ) + edge_score = self.lin(edge_attr).view(-1) + edge_score = F.dropout(edge_score, p=self.dropout, + training=self.training) + + if self.edge_score_method == 'tanh': + edge_score = edge_score.tanh() + elif self.edge_score_method == 'sigmoid': + edge_score = edge_score.sigmoid() + else: + assert self.edge_score_method == 'log_softmax' + edge_score = F.log_softmax(edge_score, dim=0) + + return self._merge_edges(x, edge_index, batch, edge_score) + + def _merge_edges( + self, + x: Tensor, + edge_index: Tensor, + batch: Tensor, + edge_score: Tensor, + ) -> Tuple[Tensor, Tensor, Tensor, UnpoolInfo]: - new_batch = X.new_empty(X_new.size(0), dtype=torch.long) - new_batch = new_batch.scatter_(0, cluster_index, batch) + from scipy.sparse.csgraph import connected_components - unpool_info = self.unpool_description(edge_index=edge_index, - batch=batch, - cluster_map=cluster_index) + edge_contract = edge_index[:, edge_score > self.threshhold] - return X_new.to(X.device), edge_index_new.to( - X.device), new_batch, unpool_info + adj = to_scipy_sparse_matrix(edge_contract, num_nodes=x.size(0)) + _, cluster_np = connected_components(adj, directed=True, + connection="weak") - def unpool( - self, - x: Tensor, - unpool_info: NamedTuple, - ) -> Tuple[Tensor, Tensor, Tensor]: - r"""Unpools a previous cluster pooling step. + cluster = torch.tensor(cluster_np, dtype=torch.long, device=x.device) + C = one_hot(cluster) + A = to_dense_adj(edge_index, max_num_nodes=x.size(0)).squeeze(0) + S = to_dense_adj(edge_index, edge_attr=edge_score, + max_num_nodes=x.size(0)).squeeze(0) - For unpooling, :obj:`x` should be of same shape as those produced by - this layer's :func:`forward` function. Then, it will produce an - unpooled :obj:`x` in addition to :obj:`edge_index` and :obj:`batch`. + A_contract = to_dense_adj(edge_contract, + max_num_nodes=x.size(0)).squeeze(0) + nodes_single = ((A_contract.sum(dim=-1) + + A_contract.sum(dim=-2)) == 0).nonzero() + S[nodes_single, nodes_single] = 1.0 - Args: - x (Tensor): The node features. - unpool_info (unpool_description): Information that has - been produced by :func:`ClusterPooling.forward`. + x_out = (S @ C).t() @ x + edge_index_out, _ = dense_to_sparse((C.T @ A @ C).fill_diagonal_(0)) + batch_out = batch.new_empty(x_out.size(0)).scatter_(0, cluster, batch) + unpool_info = UnpoolInfo(edge_index, cluster, batch) - Return types: - * **x** *(Tensor)* - The unpooled node features. - * **edge_index** *(LongTensor)* - The new edge indices. - * **batch** *(LongTensor)* - The new batch vector. - """ - # We copy the cluster features into every node - node_maps = unpool_info.cluster_map - n_nodes = 0 - for c in node_maps: - node_maps += len(c) - import numpy as np - repack = np.array([-1 for _ in range(n_nodes)]) - for i, c in enumerate(node_maps): - repack[c] = i - new_x = x[repack] - - return new_x, unpool_info.edge_index, unpool_info.batch + return x_out, edge_index_out, batch_out, unpool_info def __repr__(self) -> str: return f'{self.__class__.__name__}({self.in_channels})' diff --git a/torch_geometric/nn/pool/edge_pool.py b/torch_geometric/nn/pool/edge_pool.py index 7d7c9db36e89..a9270a16edf4 100644 --- a/torch_geometric/nn/pool/edge_pool.py +++ b/torch_geometric/nn/pool/edge_pool.py @@ -58,7 +58,7 @@ def __init__( self, in_channels: int, edge_score_method: Optional[Callable] = None, - dropout: Optional[float] = 0.0, + dropout: float = 0.0, add_to_edge_score: float = 0.5, ): super().__init__() From 4ed02bc66af98efa4903827184053e79fc20b453 Mon Sep 17 00:00:00 2001 From: rusty1s Date: Tue, 10 Sep 2024 02:16:19 +0000 Subject: [PATCH 09/12] update --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 1b44281645fe..0f8bb52a74d3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,7 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Added -- Added the `torch_geometric.nn.pool.cluster_pool` layer ([#9627](https://github.com/pyg-team/pytorch_geometric/pull/9627)) +- Added the `ClusterPooling` layer ([#9627](https://github.com/pyg-team/pytorch_geometric/pull/9627)) - Added the `LinkPredMRR` metric ([#9632](https://github.com/pyg-team/pytorch_geometric/pull/9632)) - Added PyTorch 2.4 support ([#9594](https://github.com/pyg-team/pytorch_geometric/pull/9594)) - Added `utils.normalize_edge_index` for symmetric/asymmetric normalization of graph edges ([#9554](https://github.com/pyg-team/pytorch_geometric/pull/9554)) From d74377e8334318361641fc3a543bfc651759b2f0 Mon Sep 17 00:00:00 2001 From: rusty1s Date: Tue, 10 Sep 2024 02:34:41 +0000 Subject: [PATCH 10/12] update --- test/nn/pool/test_cluster_pool.py | 2 ++ torch_geometric/nn/pool/cluster_pool.py | 17 ++++++++--------- 2 files changed, 10 insertions(+), 9 deletions(-) diff --git a/test/nn/pool/test_cluster_pool.py b/test/nn/pool/test_cluster_pool.py index 7adac6100e54..4ebfa5c58b10 100644 --- a/test/nn/pool/test_cluster_pool.py +++ b/test/nn/pool/test_cluster_pool.py @@ -2,8 +2,10 @@ import torch from torch_geometric.nn import ClusterPooling +from torch_geometric.testing import withPackage +@withPackage('scipy') @pytest.mark.parametrize('edge_score_method', [ 'tanh', 'sigmoid', diff --git a/torch_geometric/nn/pool/cluster_pool.py b/torch_geometric/nn/pool/cluster_pool.py index 1dbc5840b9f5..83d88004ad8d 100644 --- a/torch_geometric/nn/pool/cluster_pool.py +++ b/torch_geometric/nn/pool/cluster_pool.py @@ -32,7 +32,7 @@ class ClusterPooling(torch.nn.Module): in_channels (int): Size of each input sample. edge_score_method (str, optional): The function to apply to compute the edge score from raw edge scores (:obj:`"tanh"`, - "sigmoid", :obj:`"log_softmax"`). (default: :obj:`"tanh"`) + :obj:`"sigmoid"`, :obj:`"log_softmax"`). (default: :obj:`"tanh"`) dropout (float, optional): The probability with which to drop edge scores during training. (default: :obj:`0.0`) threshold (float, optional): The threshold of edge scores. If set to @@ -72,18 +72,17 @@ def forward( r"""Forward pass. Args: - x (Tensor): The node features. - edge_index (LongTensor): The edge indices. - batch (LongTensor): Batch vector + x (torch.Tensor): The node features. + edge_index (torch.Tensor): The edge indices. + batch (torch.Tensor): Batch vector :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns each node to a specific example. Return types: - * **x** *(Tensor)* - The pooled node features. - * **edge_index** *(LongTensor)* - The coarsened edge indices. - * **batch** *(LongTensor)* - The coarsened batch vector. - * **unpool_info** *(unpool_description)* - Information that is - consumed by :func:`ClusterPooling.unpool` for unpooling. + * **x** *(torch.Tensor)* - The pooled node features. + * **edge_index** *(torch.Tensor)* - The coarsened edge indices. + * **batch** *(torch.Tensor)* - The coarsened batch vector. + * **unpool_info** *(UnpoolInfo)* - Information that for unpooling. """ mask = edge_index[0] != edge_index[1] edge_index = edge_index[:, mask] From 6f03d668373e504b492df99ce8f2cfe17075d5f9 Mon Sep 17 00:00:00 2001 From: rusty1s Date: Tue, 10 Sep 2024 04:37:55 +0200 Subject: [PATCH 11/12] Migrate feature store encoders (#150) --- torch_geometric/nn/pool/cluster_pool.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_geometric/nn/pool/cluster_pool.py b/torch_geometric/nn/pool/cluster_pool.py index 83d88004ad8d..ae414eb81729 100644 --- a/torch_geometric/nn/pool/cluster_pool.py +++ b/torch_geometric/nn/pool/cluster_pool.py @@ -24,7 +24,7 @@ class ClusterPooling(torch.nn.Module): :class:`ClusterPooling` computes a score for each edge. Based on the selected edges, graph clusters are calculated and compressed - to one node using the injective :obj:`"sum" aggregation function. + to one node using the injective :obj:`"sum"` aggregation function. Edges are remapped based on the nodes created by each cluster and the original edges. From 826f97dcc9af5f106754f907e637073dee62ef69 Mon Sep 17 00:00:00 2001 From: rusty1s Date: Tue, 10 Sep 2024 02:45:22 +0000 Subject: [PATCH 12/12] update --- torch_geometric/nn/pool/cluster_pool.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/torch_geometric/nn/pool/cluster_pool.py b/torch_geometric/nn/pool/cluster_pool.py index ae414eb81729..4cee967b7eaf 100644 --- a/torch_geometric/nn/pool/cluster_pool.py +++ b/torch_geometric/nn/pool/cluster_pool.py @@ -82,7 +82,8 @@ def forward( * **x** *(torch.Tensor)* - The pooled node features. * **edge_index** *(torch.Tensor)* - The coarsened edge indices. * **batch** *(torch.Tensor)* - The coarsened batch vector. - * **unpool_info** *(UnpoolInfo)* - Information that for unpooling. + * **unpool_info** *(UnpoolInfo)* - Information that can be consumed + for unpooling. """ mask = edge_index[0] != edge_index[1] edge_index = edge_index[:, mask]