Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor TopkPooling into SelectTopK #7308

Merged
merged 17 commits into from
May 8, 2023
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Unify graph pooling framework ([#7308](https://github.com/pyg-team/pytorch_geometric/pull/7308))
- Added support for tuples as keys in `ModuleDict`/`ParameterDict` ([#7294](https://github.com/pyg-team/pytorch_geometric/pull/7294))
- Added `NodePropertySplit` transform for creating node-level splits using structural node properties ([#6894](https://github.com/pyg-team/pytorch_geometric/pull/6894))
- Added an option to preserve directed graphs in `CitationFull` datasets ([#7275](https://github.com/pyg-team/pytorch_geometric/pull/7275))
Expand Down
57 changes: 57 additions & 0 deletions test/nn/pool/select/test_select_topk.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
import pytest
import torch

from torch_geometric.nn.pool.select import SelectOutput, SelectTopK
from torch_geometric.nn.pool.select.topk import topk
from torch_geometric.testing import is_full_test


def test_topk_ratio():
x = torch.Tensor([2, 4, 5, 6, 2, 9])
batch = torch.tensor([0, 0, 1, 1, 1, 1])

perm1 = topk(x, 0.5, batch)
assert perm1.tolist() == [1, 5, 3]
assert x[perm1].tolist() == [4, 9, 6]
assert batch[perm1].tolist() == [0, 1, 1]

perm2 = topk(x, 2, batch)
assert perm2.tolist() == [1, 0, 5, 3]
assert x[perm2].tolist() == [4, 2, 9, 6]
assert batch[perm2].tolist() == [0, 0, 1, 1]

perm3 = topk(x, 3, batch)
assert perm3.tolist() == [1, 0, 5, 3, 2]
assert x[perm3].tolist() == [4, 2, 9, 6, 5]
assert batch[perm3].tolist() == [0, 0, 1, 1, 1]

if is_full_test():
jit = torch.jit.script(topk)
assert torch.equal(jit(x, 0.5, batch), perm1)
assert torch.equal(jit(x, 2, batch), perm2)
assert torch.equal(jit(x, 3, batch), perm3)


@pytest.mark.parametrize('min_score', [None, 2.0])
def test_select_topk(min_score):
if min_score is not None:
return
x = torch.randn(6, 16)
batch = torch.tensor([0, 0, 1, 1, 1, 1])

pool = SelectTopK(16, min_score=min_score)

if min_score is None:
assert str(pool) == 'SelectTopK(16, ratio=0.5)'
else:
assert str(pool) == 'SelectTopK(16, min_score=2.0)'

out = pool(x, batch)
assert isinstance(out, SelectOutput)

assert out.num_nodes == 6
assert out.num_clusters <= out.num_nodes
assert out.node_index.min() >= 0
assert out.node_index.max() < out.num_nodes
assert out.cluster_index.min() == 0
assert out.cluster_index.max() == out.num_clusters - 1
29 changes: 2 additions & 27 deletions test/nn/pool/test_topk_pool.py
Original file line number Diff line number Diff line change
@@ -1,35 +1,10 @@
import torch

from torch_geometric.nn.pool.topk_pool import TopKPooling, filter_adj, topk
from torch_geometric.nn.pool import TopKPooling
from torch_geometric.nn.pool.topk_pool import filter_adj
from torch_geometric.testing import is_full_test


def test_topk():
x = torch.Tensor([2, 4, 5, 6, 2, 9])
batch = torch.tensor([0, 0, 1, 1, 1, 1])

perm1 = topk(x, 0.5, batch)
assert perm1.tolist() == [1, 5, 3]
assert x[perm1].tolist() == [4, 9, 6]
assert batch[perm1].tolist() == [0, 1, 1]

perm2 = topk(x, 2, batch)
assert perm2.tolist() == [1, 0, 5, 3]
assert x[perm2].tolist() == [4, 2, 9, 6]
assert batch[perm2].tolist() == [0, 0, 1, 1]

perm3 = topk(x, 3, batch)
assert perm3.tolist() == [1, 0, 5, 3, 2]
assert x[perm3].tolist() == [4, 2, 9, 6, 5]
assert batch[perm3].tolist() == [0, 0, 1, 1, 1]

if is_full_test():
jit = torch.jit.script(topk)
assert torch.equal(jit(x, 0.5, batch), perm1)
assert torch.equal(jit(x, 2, batch), perm2)
assert torch.equal(jit(x, 3, batch), perm3)


def test_filter_adj():
edge_index = torch.tensor([[0, 0, 1, 1, 2, 2, 3, 3],
[1, 3, 0, 2, 1, 3, 0, 2]])
Expand Down
2 changes: 1 addition & 1 deletion torch_geometric/nn/pool/asap.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from torch.nn import Linear

from torch_geometric.nn import LEConv
from torch_geometric.nn.pool.topk_pool import topk
from torch_geometric.nn.pool.select.topk import topk
from torch_geometric.utils import (
add_remaining_self_loops,
remove_self_loops,
Expand Down
3 changes: 2 additions & 1 deletion torch_geometric/nn/pool/pan_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
from torch import Tensor
from torch.nn import Parameter

from torch_geometric.nn.pool.topk_pool import filter_adj, topk
from torch_geometric.nn.pool.select.topk import topk
from torch_geometric.nn.pool.topk_pool import filter_adj
from torch_geometric.typing import OptTensor, SparseTensor
from torch_geometric.utils import scatter, softmax

Expand Down
3 changes: 2 additions & 1 deletion torch_geometric/nn/pool/sag_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
from torch import Tensor

from torch_geometric.nn import GraphConv
from torch_geometric.nn.pool.topk_pool import filter_adj, topk
from torch_geometric.nn.pool.select.topk import topk
from torch_geometric.nn.pool.topk_pool import filter_adj
from torch_geometric.typing import OptTensor
from torch_geometric.utils import softmax

Expand Down
2 changes: 2 additions & 0 deletions torch_geometric/nn/pool/select/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from .base import Select, SelectOutput
from .topk import SelectTopK

__all__ = [
'Select',
'SelectOutput',
'SelectTopK',
]
4 changes: 4 additions & 0 deletions torch_geometric/nn/pool/select/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,20 +14,23 @@ class SelectOutput(CastMixin):

Args:
node_index (torch.Tensor): The indices of the selected nodes.
num_nodes (int): The number of nodes.
cluster_index (torch.Tensor): The indices of the clusters each node in
:obj:`node_index` is assigned to.
num_clusters (int): The number of clusters.
weight (torch.Tensor, optional): A weight vector, denoting the strength
of the assignment of a node to its cluster. (default: :obj:`None`)
"""
node_index: Tensor
num_nodes: int
cluster_index: Tensor
num_clusters: int
weight: Optional[Tensor] = None

def __init__(
self,
node_index: Tensor,
num_nodes: int,
cluster_index: Tensor,
num_clusters: int,
weight: Optional[Tensor] = None,
Expand Down Expand Up @@ -55,6 +58,7 @@ def __init__(
f"values (got {weight.numel()} values)")

self.node_index = node_index
self.num_nodes = num_nodes
self.cluster_index = cluster_index
self.num_clusters = num_clusters
self.weight = weight
Expand Down
135 changes: 135 additions & 0 deletions torch_geometric/nn/pool/select/topk.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
from typing import Callable, Optional, Union

import torch
from torch import Tensor

from torch_geometric.nn.inits import uniform
from torch_geometric.nn.resolver import activation_resolver
from torch_geometric.utils import scatter, softmax

from .base import Select, SelectOutput


# TODO (matthias) Benchmark and document this method.
def topk(
x: Tensor,
ratio: Optional[Union[float, int]],
batch: Tensor,
min_score: Optional[float] = None,
tol: float = 1e-7,
) -> Tensor:
if min_score is not None:
# Make sure that we do not drop all nodes in a graph.
scores_max = scatter(x, batch, reduce='max')[batch] - tol
scores_min = scores_max.clamp(max=min_score)

perm = (x > scores_min).nonzero().view(-1)
return perm

if ratio is not None:
num_nodes = scatter(batch.new_ones(x.size(0)), batch, reduce='sum')
batch_size, max_num_nodes = num_nodes.size(0), int(num_nodes.max())

cum_num_nodes = torch.cat(
[num_nodes.new_zeros(1),
num_nodes.cumsum(dim=0)[:-1]], dim=0)

index = torch.arange(batch.size(0), dtype=torch.long, device=x.device)
index = (index - cum_num_nodes[batch]) + (batch * max_num_nodes)

dense_x = x.new_full((batch_size * max_num_nodes, ), -60000.0)
dense_x[index] = x
dense_x = dense_x.view(batch_size, max_num_nodes)

_, perm = dense_x.sort(dim=-1, descending=True)

perm = perm + cum_num_nodes.view(-1, 1)
perm = perm.view(-1)

if ratio >= 1:
k = num_nodes.new_full((num_nodes.size(0), ), int(ratio))
k = torch.min(k, num_nodes)
else:
k = (float(ratio) * num_nodes.to(x.dtype)).ceil().to(torch.long)

if isinstance(ratio, int) and (k == ratio).all():
# If all graphs have exactly `ratio` or more than `ratio` entries,
# we can just pick the first entries in `perm` batch-wise:
index = torch.arange(batch_size, device=x.device) * max_num_nodes
index = index.view(-1, 1).repeat(1, ratio).view(-1)
index += torch.arange(ratio, device=x.device).repeat(batch_size)
else:
# Otherwise, compute indices per graph:
index = torch.cat([
torch.arange(k[i], device=x.device) + i * max_num_nodes
for i in range(batch_size)
], dim=0)

perm = perm[index]
return perm

raise ValueError("At least one of the 'ratio' and 'min_score' parameters "
"must be specified")


class SelectTopK(Select):
# TODO (matthias) Add documentation.
def __init__(
self,
in_channels: int,
ratio: Union[int, float] = 0.5,
min_score: Optional[float] = None,
act: Union[str, Callable] = 'tanh',
):
super().__init__()

if ratio is None and min_score is None:
raise ValueError(f"At least one of the 'ratio' and 'min_score' "
f"parameters must be specified in "
f"'{self.__class__.__name__}'")

self.in_channels = in_channels
self.ratio = ratio
self.min_score = min_score
self.act = activation_resolver(act)

self.weight = torch.nn.Parameter(torch.Tensor(1, in_channels))

self.reset_parameters()

def reset_parameters(self):
uniform(self.in_channels, self.weight)

def forward(
self,
x: Tensor,
batch: Optional[Tensor] = None,
) -> SelectOutput:
""""""
if batch is None:
batch = x.new_zeros(x.size(0), dtype=torch.long)

x.view(-1, 1) if x.dim() == 1 else x
score = (x * self.weight).sum(dim=-1)

if self.min_score is None:
score = self.act(score / self.weight.norm(p=2, dim=-1))
else:
score = softmax(score, batch)

node_index = topk(score, self.ratio, batch, self.min_score)

return SelectOutput(
node_index=node_index,
num_nodes=x.size(0),
cluster_index=torch.arange(node_index.size(0), device=x.device),
num_clusters=node_index.size(0),
weight=score[node_index],
)

def __repr__(self) -> str:
if self.min_score is None:
arg = f'ratio={self.ratio}'
else:
arg = f'min_score={self.min_score}'
return f'{self.__class__.__name__}({self.in_channels}, {arg})'
Loading