Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added the ClusterPooling layer #9627

Merged
merged 13 commits into from
Sep 10, 2024
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added `EdgeIndex.sparse_resize_` functionality ([#8983](https://github.com/pyg-team/pytorch_geometric/pull/8983))
- Added approximate `faiss`-based KNN-search ([#8952](https://github.com/pyg-team/pytorch_geometric/pull/8952))
- Added documentation on environment setup on XPU device ([#9407](https://github.com/pyg-team/pytorch_geometric/pull/9407))
- Added the `torch_geometric.nn.pool.cluster_pool` layer ([#9627](https://github.com/pyg-team/pytorch_geometric/pull/9627))

### Changed

Expand Down
198 changes: 198 additions & 0 deletions torch_geometric/nn/pool/cluster_pool.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,198 @@
from typing import Callable, NamedTuple, Optional, Tuple

Check warning on line 1 in torch_geometric/nn/pool/cluster_pool.py

View check run for this annotation

Codecov / codecov/patch

torch_geometric/nn/pool/cluster_pool.py#L1

Added line #L1 was not covered by tests

import torch
import torch.nn.functional as F
from scipy.sparse.csgraph import connected_components
from torch import Tensor

Check warning on line 6 in torch_geometric/nn/pool/cluster_pool.py

View check run for this annotation

Codecov / codecov/patch

torch_geometric/nn/pool/cluster_pool.py#L3-L6

Added lines #L3 - L6 were not covered by tests

from torch_geometric.utils import (

Check warning on line 8 in torch_geometric/nn/pool/cluster_pool.py

View check run for this annotation

Codecov / codecov/patch

torch_geometric/nn/pool/cluster_pool.py#L8

Added line #L8 was not covered by tests
coalesce,
dense_to_sparse,
to_dense_adj,
to_scipy_sparse_matrix,
)


class ClusterPooling(torch.nn.Module):

Check warning on line 16 in torch_geometric/nn/pool/cluster_pool.py

View check run for this annotation

Codecov / codecov/patch

torch_geometric/nn/pool/cluster_pool.py#L16

Added line #L16 was not covered by tests
r"""The cluster pooling operator from the `"Edge-Based Graph Component
Pooling" <paper url>` paper.

In short, a score is computed for each edge.
Based on the selected edges, graph clusters are calculated and compressed
to one node using an injective aggregation function (sum). Edges are
remapped based on the node created by each cluster and the original edges.

Args:
in_channels (int): Size of each input sample.
edge_score_method (function, optional): The function to apply
to compute the edge score from raw edge scores. By default,
this is the tanh over all incoming edges for each node.
This function takes in a :obj:`raw_edge_score` tensor of shape
:obj:`[num_nodes]`, an :obj:`edge_index` tensor and the number of
nodes :obj:`num_nodes`, and produces a new tensor of the same size
as :obj:`raw_edge_score` describing normalized edge scores.
Included functions are
:func:`ClusterPooling.compute_edge_score_tanh`,
:func:`ClusterPooling.compute_edge_score_sigmoid` and
:func:`ClusterPooling.compute_edge_score_logsoftmax`.
(default: :func:`ClusterPooling.compute_edge_score_tanh`)
dropout (float, optional): The probability with
which to drop edge scores during training. (default: :obj:`0`)
"""
unpool_description = NamedTuple("UnpoolDescription",

Check warning on line 42 in torch_geometric/nn/pool/cluster_pool.py

View check run for this annotation

Codecov / codecov/patch

torch_geometric/nn/pool/cluster_pool.py#L42

Added line #L42 was not covered by tests
["edge_index", "batch", "cluster_map"])

def __init__(self, in_channels: int,

Check warning on line 45 in torch_geometric/nn/pool/cluster_pool.py

View check run for this annotation

Codecov / codecov/patch

torch_geometric/nn/pool/cluster_pool.py#L45

Added line #L45 was not covered by tests
edge_score_method: Optional[Callable] = None,
dropout: Optional[float] = 0.0,
threshold: Optional[float] = None, directed: bool = False):
super().__init__()
self.in_channels = in_channels
if edge_score_method is None:
edge_score_method = self.compute_edge_score_tanh
if threshold is None:
if edge_score_method is self.compute_edge_score_sigmoid:
threshold = 0.5

Check warning on line 55 in torch_geometric/nn/pool/cluster_pool.py

View check run for this annotation

Codecov / codecov/patch

torch_geometric/nn/pool/cluster_pool.py#L49-L55

Added lines #L49 - L55 were not covered by tests
else:
threshold = 0.0
self.compute_edge_score = edge_score_method
self.threshhold = threshold
self.dropout = dropout
self.directed = directed
self.lin = torch.nn.Linear(2 * in_channels, 1)

Check warning on line 62 in torch_geometric/nn/pool/cluster_pool.py

View check run for this annotation

Codecov / codecov/patch

torch_geometric/nn/pool/cluster_pool.py#L57-L62

Added lines #L57 - L62 were not covered by tests

self.reset_parameters()

Check warning on line 64 in torch_geometric/nn/pool/cluster_pool.py

View check run for this annotation

Codecov / codecov/patch

torch_geometric/nn/pool/cluster_pool.py#L64

Added line #L64 was not covered by tests

def reset_parameters(self):

Check warning on line 66 in torch_geometric/nn/pool/cluster_pool.py

View check run for this annotation

Codecov / codecov/patch

torch_geometric/nn/pool/cluster_pool.py#L66

Added line #L66 was not covered by tests
r"""Resets all learnable parameters of the module."""
self.lin.reset_parameters()

Check warning on line 68 in torch_geometric/nn/pool/cluster_pool.py

View check run for this annotation

Codecov / codecov/patch

torch_geometric/nn/pool/cluster_pool.py#L68

Added line #L68 was not covered by tests

@staticmethod
def compute_edge_score_tanh(raw_edge_score: Tensor):

Check warning on line 71 in torch_geometric/nn/pool/cluster_pool.py

View check run for this annotation

Codecov / codecov/patch

torch_geometric/nn/pool/cluster_pool.py#L70-L71

Added lines #L70 - L71 were not covered by tests
r"""Normalizes edge scores via hyperbolic tangent application."""
return torch.tanh(raw_edge_score)

Check warning on line 73 in torch_geometric/nn/pool/cluster_pool.py

View check run for this annotation

Codecov / codecov/patch

torch_geometric/nn/pool/cluster_pool.py#L73

Added line #L73 was not covered by tests

@staticmethod
def compute_edge_score_sigmoid(raw_edge_score: Tensor):

Check warning on line 76 in torch_geometric/nn/pool/cluster_pool.py

View check run for this annotation

Codecov / codecov/patch

torch_geometric/nn/pool/cluster_pool.py#L75-L76

Added lines #L75 - L76 were not covered by tests
r"""Normalizes edge scores via sigmoid application."""
return torch.sigmoid(raw_edge_score)

Check warning on line 78 in torch_geometric/nn/pool/cluster_pool.py

View check run for this annotation

Codecov / codecov/patch

torch_geometric/nn/pool/cluster_pool.py#L78

Added line #L78 was not covered by tests

@staticmethod
def compute_edge_score_logsoftmax(raw_edge_score: Tensor):

Check warning on line 81 in torch_geometric/nn/pool/cluster_pool.py

View check run for this annotation

Codecov / codecov/patch

torch_geometric/nn/pool/cluster_pool.py#L80-L81

Added lines #L80 - L81 were not covered by tests
r"""Normalizes edge scores via logsoftmax application."""
return torch.nn.functional.log_softmax(raw_edge_score, dim=0)

Check warning on line 83 in torch_geometric/nn/pool/cluster_pool.py

View check run for this annotation

Codecov / codecov/patch

torch_geometric/nn/pool/cluster_pool.py#L83

Added line #L83 was not covered by tests

def forward(

Check warning on line 85 in torch_geometric/nn/pool/cluster_pool.py

View check run for this annotation

Codecov / codecov/patch

torch_geometric/nn/pool/cluster_pool.py#L85

Added line #L85 was not covered by tests
self,
x: Tensor,
edge_index: Tensor,
batch: Tensor,
) -> Tuple[Tensor, Tensor, Tensor, NamedTuple]:
r"""Forward pass.

Args:
x (Tensor): The node features.
edge_index (LongTensor): The edge indices.
batch (LongTensor): Batch vector
:math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns
each node to a specific example.

Return types:
* **x** *(Tensor)* - The pooled node features.
* **edge_index** *(LongTensor)* - The coarsened edge indices.
* **batch** *(LongTensor)* - The coarsened batch vector.
* **unpool_info** *(unpool_description)* - Information that is
consumed by :func:`ClusterPooling.unpool` for unpooling.
"""
# First we drop the self edges as those cannot be clustered
msk = edge_index[0] != edge_index[1]
edge_index = edge_index[:, msk]
if not self.directed:
edge_index = torch.cat([edge_index, edge_index.flip(0)], dim=-1)

Check warning on line 111 in torch_geometric/nn/pool/cluster_pool.py

View check run for this annotation

Codecov / codecov/patch

torch_geometric/nn/pool/cluster_pool.py#L108-L111

Added lines #L108 - L111 were not covered by tests
# We only evaluate each edge once, remove double edges from the list
edge_index = coalesce(edge_index)

Check warning on line 113 in torch_geometric/nn/pool/cluster_pool.py

View check run for this annotation

Codecov / codecov/patch

torch_geometric/nn/pool/cluster_pool.py#L113

Added line #L113 was not covered by tests

e = torch.cat(

Check warning on line 115 in torch_geometric/nn/pool/cluster_pool.py

View check run for this annotation

Codecov / codecov/patch

torch_geometric/nn/pool/cluster_pool.py#L115

Added line #L115 was not covered by tests
[x[edge_index[0]], x[edge_index[1]]],
dim=-1) # Concatenates source feature with target features
e = self.lin(e).view(

Check warning on line 118 in torch_geometric/nn/pool/cluster_pool.py

View check run for this annotation

Codecov / codecov/patch

torch_geometric/nn/pool/cluster_pool.py#L118

Added line #L118 was not covered by tests
-1) # Apply linear NN on the node pairs (edges) and reshape
e = F.dropout(e, p=self.dropout, training=self.training)

Check warning on line 120 in torch_geometric/nn/pool/cluster_pool.py

View check run for this annotation

Codecov / codecov/patch

torch_geometric/nn/pool/cluster_pool.py#L120

Added line #L120 was not covered by tests

e = self.compute_edge_score(e) # Non linear activation function
x, edge_index, batch, unpool_info = self.__merge_edges__(

Check warning on line 123 in torch_geometric/nn/pool/cluster_pool.py

View check run for this annotation

Codecov / codecov/patch

torch_geometric/nn/pool/cluster_pool.py#L122-L123

Added lines #L122 - L123 were not covered by tests
x, edge_index, batch, e)

return x, edge_index, batch, unpool_info

Check warning on line 126 in torch_geometric/nn/pool/cluster_pool.py

View check run for this annotation

Codecov / codecov/patch

torch_geometric/nn/pool/cluster_pool.py#L126

Added line #L126 was not covered by tests

def __merge_edges__(

Check warning on line 128 in torch_geometric/nn/pool/cluster_pool.py

View check run for this annotation

Codecov / codecov/patch

torch_geometric/nn/pool/cluster_pool.py#L128

Added line #L128 was not covered by tests
self, X: Tensor, edge_index: Tensor, batch: Tensor,
edge_score: Tensor) -> Tuple[Tensor, Tensor, Tensor, NamedTuple]:
edges_contract = edge_index[..., edge_score > self.threshhold]

Check warning on line 131 in torch_geometric/nn/pool/cluster_pool.py

View check run for this annotation

Codecov / codecov/patch

torch_geometric/nn/pool/cluster_pool.py#L131

Added line #L131 was not covered by tests

adj = to_scipy_sparse_matrix(edges_contract, num_nodes=X.size(0))
_, cluster_index = connected_components(adj, directed=True,

Check warning on line 134 in torch_geometric/nn/pool/cluster_pool.py

View check run for this annotation

Codecov / codecov/patch

torch_geometric/nn/pool/cluster_pool.py#L133-L134

Added lines #L133 - L134 were not covered by tests
connection="weak")

cluster_index = torch.tensor(cluster_index, dtype=torch.int64,

Check warning on line 137 in torch_geometric/nn/pool/cluster_pool.py

View check run for this annotation

Codecov / codecov/patch

torch_geometric/nn/pool/cluster_pool.py#L137

Added line #L137 was not covered by tests
device=X.device)
C = F.one_hot(cluster_index).type(torch.float)
A = to_dense_adj(edge_index, max_num_nodes=X.size(0)).squeeze(0)
S = to_dense_adj(edge_index, edge_attr=edge_score,

Check warning on line 141 in torch_geometric/nn/pool/cluster_pool.py

View check run for this annotation

Codecov / codecov/patch

torch_geometric/nn/pool/cluster_pool.py#L139-L141

Added lines #L139 - L141 were not covered by tests
max_num_nodes=X.size(0)).squeeze(0)

A_contract = to_dense_adj(

Check warning on line 144 in torch_geometric/nn/pool/cluster_pool.py

View check run for this annotation

Codecov / codecov/patch

torch_geometric/nn/pool/cluster_pool.py#L144

Added line #L144 was not covered by tests
edges_contract, max_num_nodes=X.size(0)).type(torch.int).squeeze(0)
nodes_single = ((A_contract.sum(-1) +

Check warning on line 146 in torch_geometric/nn/pool/cluster_pool.py

View check run for this annotation

Codecov / codecov/patch

torch_geometric/nn/pool/cluster_pool.py#L146

Added line #L146 was not covered by tests
A_contract.sum(-2)) == 0).nonzero()
S[nodes_single, nodes_single] = 1

Check warning on line 148 in torch_geometric/nn/pool/cluster_pool.py

View check run for this annotation

Codecov / codecov/patch

torch_geometric/nn/pool/cluster_pool.py#L148

Added line #L148 was not covered by tests

X_new = (S @ C).T @ X
edge_index_new, _ = dense_to_sparse((C.T @ A @ C).fill_diagonal_(0))

Check warning on line 151 in torch_geometric/nn/pool/cluster_pool.py

View check run for this annotation

Codecov / codecov/patch

torch_geometric/nn/pool/cluster_pool.py#L150-L151

Added lines #L150 - L151 were not covered by tests

new_batch = X.new_empty(X_new.size(0), dtype=torch.long)
new_batch = new_batch.scatter_(0, cluster_index, batch)

Check warning on line 154 in torch_geometric/nn/pool/cluster_pool.py

View check run for this annotation

Codecov / codecov/patch

torch_geometric/nn/pool/cluster_pool.py#L153-L154

Added lines #L153 - L154 were not covered by tests

unpool_info = self.unpool_description(edge_index=edge_index,

Check warning on line 156 in torch_geometric/nn/pool/cluster_pool.py

View check run for this annotation

Codecov / codecov/patch

torch_geometric/nn/pool/cluster_pool.py#L156

Added line #L156 was not covered by tests
batch=batch,
cluster_map=cluster_index)

return X_new.to(X.device), edge_index_new.to(

Check warning on line 160 in torch_geometric/nn/pool/cluster_pool.py

View check run for this annotation

Codecov / codecov/patch

torch_geometric/nn/pool/cluster_pool.py#L160

Added line #L160 was not covered by tests
X.device), new_batch, unpool_info

def unpool(

Check warning on line 163 in torch_geometric/nn/pool/cluster_pool.py

View check run for this annotation

Codecov / codecov/patch

torch_geometric/nn/pool/cluster_pool.py#L163

Added line #L163 was not covered by tests
self,
x: Tensor,
unpool_info: NamedTuple,
) -> Tuple[Tensor, Tensor, Tensor]:
r"""Unpools a previous cluster pooling step.

For unpooling, :obj:`x` should be of same shape as those produced by
this layer's :func:`forward` function. Then, it will produce an
unpooled :obj:`x` in addition to :obj:`edge_index` and :obj:`batch`.

Args:
x (Tensor): The node features.
unpool_info (unpool_description): Information that has
been produced by :func:`ClusterPooling.forward`.

Return types:
* **x** *(Tensor)* - The unpooled node features.
* **edge_index** *(LongTensor)* - The new edge indices.
* **batch** *(LongTensor)* - The new batch vector.
"""
# We copy the cluster features into every node
node_maps = unpool_info.cluster_map
n_nodes = 0
for c in node_maps:
node_maps += len(c)
import numpy as np
repack = np.array([-1 for _ in range(n_nodes)])
for i, c in enumerate(node_maps):
repack[c] = i
new_x = x[repack]

Check warning on line 193 in torch_geometric/nn/pool/cluster_pool.py

View check run for this annotation

Codecov / codecov/patch

torch_geometric/nn/pool/cluster_pool.py#L185-L193

Added lines #L185 - L193 were not covered by tests

return new_x, unpool_info.edge_index, unpool_info.batch

Check warning on line 195 in torch_geometric/nn/pool/cluster_pool.py

View check run for this annotation

Codecov / codecov/patch

torch_geometric/nn/pool/cluster_pool.py#L195

Added line #L195 was not covered by tests

def __repr__(self) -> str:
return f'{self.__class__.__name__}({self.in_channels})'

Check warning on line 198 in torch_geometric/nn/pool/cluster_pool.py

View check run for this annotation

Codecov / codecov/patch

torch_geometric/nn/pool/cluster_pool.py#L197-L198

Added lines #L197 - L198 were not covered by tests
Loading