Skip to content

Commit

Permalink
Refactor TopkPooling into SelectTopK (#7308)
Browse files Browse the repository at this point in the history
Towards #6455.
Do not review before #7307 is merged.

---------

Co-authored-by: Matthias Fey <matthias.fey@tu-dortmund.de>
  • Loading branch information
wsad1 and rusty1s authored May 8, 2023
1 parent c566f5c commit a43fc56
Show file tree
Hide file tree
Showing 10 changed files with 214 additions and 112 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Unify graph pooling framework ([#7308](https://github.com/pyg-team/pytorch_geometric/pull/7308))
- Added support for tuples as keys in `ModuleDict`/`ParameterDict` ([#7294](https://github.com/pyg-team/pytorch_geometric/pull/7294))
- Added `NodePropertySplit` transform for creating node-level splits using structural node properties ([#6894](https://github.com/pyg-team/pytorch_geometric/pull/6894))
- Added an option to preserve directed graphs in `CitationFull` datasets ([#7275](https://github.com/pyg-team/pytorch_geometric/pull/7275))
Expand Down
57 changes: 57 additions & 0 deletions test/nn/pool/select/test_select_topk.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
import pytest
import torch

from torch_geometric.nn.pool.select import SelectOutput, SelectTopK
from torch_geometric.nn.pool.select.topk import topk
from torch_geometric.testing import is_full_test


def test_topk_ratio():
    """Check `topk` node selection for a fractional and two integer ratios.

    The batch holds two graphs with 2 and 4 nodes. For every ratio the
    returned permutation, the selected scores and the graph assignment of
    the selected nodes are compared against hand-computed values.
    """
    scores = torch.Tensor([2, 4, 5, 6, 2, 9])
    graph_ids = torch.tensor([0, 0, 1, 1, 1, 1])

    # (ratio, expected permutation, expected scores, expected graph ids)
    cases = [
        (0.5, [1, 5, 3], [4, 9, 6], [0, 1, 1]),
        (2, [1, 0, 5, 3], [4, 2, 9, 6], [0, 0, 1, 1]),
        (3, [1, 0, 5, 3, 2], [4, 2, 9, 6, 5], [0, 0, 1, 1, 1]),
    ]

    eager_perms = []
    for ratio, want_perm, want_scores, want_graphs in cases:
        perm = topk(scores, ratio, graph_ids)
        assert perm.tolist() == want_perm
        assert scores[perm].tolist() == want_scores
        assert graph_ids[perm].tolist() == want_graphs
        eager_perms.append(perm)

    if is_full_test():
        # The scripted function has to reproduce the eager results exactly.
        scripted = torch.jit.script(topk)
        for (ratio, _, _, _), perm in zip(cases, eager_perms):
            assert torch.equal(scripted(scores, ratio, graph_ids), perm)


@pytest.mark.parametrize('min_score', [None, 2.0])
def test_select_topk(min_score):
    """Check the repr of `SelectTopK` and the invariants of its output.

    The `min_score` parametrization is currently disabled; it is reported
    as skipped so that the missing coverage stays visible.
    """
    if min_score is not None:
        # A bare `return` here would be counted as a passing test although
        # nothing has been checked; an explicit skip shows up in the report.
        pytest.skip("'min_score' case is not exercised by this test")

    x = torch.randn(6, 16)
    batch = torch.tensor([0, 0, 1, 1, 1, 1])

    pool = SelectTopK(16, min_score=min_score)

    if min_score is None:
        assert str(pool) == 'SelectTopK(16, ratio=0.5)'
    else:
        assert str(pool) == 'SelectTopK(16, min_score=2.0)'

    out = pool(x, batch)
    assert isinstance(out, SelectOutput)

    # Selected node indices must be valid, and cluster indices must cover
    # the range [0, num_clusters - 1].
    assert out.num_nodes == 6
    assert out.num_clusters <= out.num_nodes
    assert out.node_index.min() >= 0
    assert out.node_index.max() < out.num_nodes
    assert out.cluster_index.min() == 0
    assert out.cluster_index.max() == out.num_clusters - 1
29 changes: 2 additions & 27 deletions test/nn/pool/test_topk_pool.py
Original file line number Diff line number Diff line change
@@ -1,35 +1,10 @@
import torch

from torch_geometric.nn.pool.topk_pool import TopKPooling, filter_adj, topk
from torch_geometric.nn.pool import TopKPooling
from torch_geometric.nn.pool.topk_pool import filter_adj
from torch_geometric.testing import is_full_test


def test_topk():
    """Check `topk` node selection for ratios 0.5, 2 and 3.

    Two graphs with 2 and 4 nodes are scored; for each ratio the selected
    indices, their scores and their graph assignment are verified.
    """
    x = torch.Tensor([2, 4, 5, 6, 2, 9])
    batch = torch.tensor([0, 0, 1, 1, 1, 1])

    def select_and_check(ratio, want_perm, want_x, want_batch):
        # Run `topk` once and compare all three views of the result.
        perm = topk(x, ratio, batch)
        assert perm.tolist() == want_perm
        assert x[perm].tolist() == want_x
        assert batch[perm].tolist() == want_batch
        return perm

    perm_half = select_and_check(0.5, [1, 5, 3], [4, 9, 6], [0, 1, 1])
    perm_two = select_and_check(2, [1, 0, 5, 3], [4, 2, 9, 6], [0, 0, 1, 1])
    perm_three = select_and_check(3, [1, 0, 5, 3, 2], [4, 2, 9, 6, 5],
                                  [0, 0, 1, 1, 1])

    if not is_full_test():
        return

    # The TorchScript-compiled function must match the eager results.
    scripted = torch.jit.script(topk)
    assert torch.equal(scripted(x, 0.5, batch), perm_half)
    assert torch.equal(scripted(x, 2, batch), perm_two)
    assert torch.equal(scripted(x, 3, batch), perm_three)


def test_filter_adj():
edge_index = torch.tensor([[0, 0, 1, 1, 2, 2, 3, 3],
[1, 3, 0, 2, 1, 3, 0, 2]])
Expand Down
2 changes: 1 addition & 1 deletion torch_geometric/nn/pool/asap.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from torch.nn import Linear

from torch_geometric.nn import LEConv
from torch_geometric.nn.pool.topk_pool import topk
from torch_geometric.nn.pool.select.topk import topk
from torch_geometric.utils import (
add_remaining_self_loops,
remove_self_loops,
Expand Down
3 changes: 2 additions & 1 deletion torch_geometric/nn/pool/pan_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
from torch import Tensor
from torch.nn import Parameter

from torch_geometric.nn.pool.topk_pool import filter_adj, topk
from torch_geometric.nn.pool.select.topk import topk
from torch_geometric.nn.pool.topk_pool import filter_adj
from torch_geometric.typing import OptTensor, SparseTensor
from torch_geometric.utils import scatter, softmax

Expand Down
3 changes: 2 additions & 1 deletion torch_geometric/nn/pool/sag_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
from torch import Tensor

from torch_geometric.nn import GraphConv
from torch_geometric.nn.pool.topk_pool import filter_adj, topk
from torch_geometric.nn.pool.select.topk import topk
from torch_geometric.nn.pool.topk_pool import filter_adj
from torch_geometric.typing import OptTensor
from torch_geometric.utils import softmax

Expand Down
2 changes: 2 additions & 0 deletions torch_geometric/nn/pool/select/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from .base import Select, SelectOutput
from .topk import SelectTopK

__all__ = [
'Select',
'SelectOutput',
'SelectTopK',
]
4 changes: 4 additions & 0 deletions torch_geometric/nn/pool/select/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,20 +14,23 @@ class SelectOutput(CastMixin):
Args:
node_index (torch.Tensor): The indices of the selected nodes.
num_nodes (int): The number of nodes.
cluster_index (torch.Tensor): The indices of the clusters each node in
:obj:`node_index` is assigned to.
num_clusters (int): The number of clusters.
weight (torch.Tensor, optional): A weight vector, denoting the strength
of the assignment of a node to its cluster. (default: :obj:`None`)
"""
node_index: Tensor
num_nodes: int
cluster_index: Tensor
num_clusters: int
weight: Optional[Tensor] = None

def __init__(
self,
node_index: Tensor,
num_nodes: int,
cluster_index: Tensor,
num_clusters: int,
weight: Optional[Tensor] = None,
Expand Down Expand Up @@ -55,6 +58,7 @@ def __init__(
f"values (got {weight.numel()} values)")

self.node_index = node_index
self.num_nodes = num_nodes
self.cluster_index = cluster_index
self.num_clusters = num_clusters
self.weight = weight
Expand Down
135 changes: 135 additions & 0 deletions torch_geometric/nn/pool/select/topk.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
from typing import Callable, Optional, Union

import torch
from torch import Tensor

from torch_geometric.nn.inits import uniform
from torch_geometric.nn.resolver import activation_resolver
from torch_geometric.utils import scatter, softmax

from .base import Select, SelectOutput


# TODO (matthias) Benchmark and document this method.
def topk(
    x: Tensor,
    ratio: Optional[Union[float, int]],
    batch: Tensor,
    min_score: Optional[float] = None,
    tol: float = 1e-7,
) -> Tensor:
    r"""Returns the indices of the nodes to keep in each graph of a batch.

    Two selection modes exist; :obj:`min_score` takes precedence over
    :obj:`ratio` when both are given:

    * :obj:`min_score` is set: keeps every node whose score is strictly
      greater than :obj:`min(max_score_of_its_graph - tol, min_score)`.
    * otherwise :obj:`ratio` is set: keeps the highest-scoring nodes of
      each graph, :obj:`ceil(ratio * N)` of them if :obj:`ratio < 1` and
      :obj:`min(int(ratio), N)` of them if :obj:`ratio >= 1`, where
      :obj:`N` is the number of nodes of the graph.

    Args:
        x (torch.Tensor): The score of each node (indexed along its first
            dimension).
        ratio (float or int, optional): The fraction or number of nodes to
            keep per graph, see above.
        batch (torch.Tensor): The graph index of each node.
            NOTE(review): the offset arithmetic below looks like it expects
            nodes of the same graph to be stored contiguously and ordered
            by graph index -- confirm against callers.
        min_score (float, optional): The score threshold, see above.
            (default: :obj:`None`)
        tol (float, optional): The margin subtracted from the per-graph
            maximum score in :obj:`min_score` mode. (default: :obj:`1e-7`)

    Returns:
        torch.Tensor: The indices of the selected nodes. In :obj:`ratio`
        mode they are grouped by graph and sorted by descending score
        within each graph.

    Raises:
        ValueError: If both :obj:`ratio` and :obj:`min_score` are
            :obj:`None`.
    """
    if min_score is not None:
        # Make sure that we do not drop all nodes in a graph.
        scores_max = scatter(x, batch, reduce='max')[batch] - tol
        scores_min = scores_max.clamp(max=min_score)

        perm = (x > scores_min).nonzero().view(-1)
        return perm

    if ratio is not None:
        # Number of nodes per graph.
        num_nodes = scatter(batch.new_ones(x.size(0)), batch, reduce='sum')
        batch_size, max_num_nodes = num_nodes.size(0), int(num_nodes.max())

        # Offset of the first node of each graph in the flat node ordering.
        cum_num_nodes = torch.cat(
            [num_nodes.new_zeros(1),
             num_nodes.cumsum(dim=0)[:-1]], dim=0)

        # Position of every node inside a flattened
        # `[batch_size, max_num_nodes]` layout.
        index = torch.arange(batch.size(0), dtype=torch.long, device=x.device)
        index = (index - cum_num_nodes[batch]) + (batch * max_num_nodes)

        # Padding slots are filled with -60000.0 so that they sort behind
        # real scores (this relies on all scores being larger than that).
        dense_x = x.new_full((batch_size * max_num_nodes, ), -60000.0)
        dense_x[index] = x
        dense_x = dense_x.view(batch_size, max_num_nodes)

        # Rank the nodes of each graph by descending score.
        _, perm = dense_x.sort(dim=-1, descending=True)

        # Convert per-graph positions back into flat node indices.
        perm = perm + cum_num_nodes.view(-1, 1)
        perm = perm.view(-1)

        # Number of nodes to keep per graph.
        if ratio >= 1:
            k = num_nodes.new_full((num_nodes.size(0), ), int(ratio))
            k = torch.min(k, num_nodes)
        else:
            k = (float(ratio) * num_nodes.to(x.dtype)).ceil().to(torch.long)

        if isinstance(ratio, int) and (k == ratio).all():
            # If all graphs have exactly `ratio` or more than `ratio` entries,
            # we can just pick the first entries in `perm` batch-wise:
            index = torch.arange(batch_size, device=x.device) * max_num_nodes
            index = index.view(-1, 1).repeat(1, ratio).view(-1)
            index += torch.arange(ratio, device=x.device).repeat(batch_size)
        else:
            # Otherwise, compute indices per graph:
            index = torch.cat([
                torch.arange(k[i], device=x.device) + i * max_num_nodes
                for i in range(batch_size)
            ], dim=0)

        perm = perm[index]
        return perm

    raise ValueError("At least one of the 'ratio' and 'min_score' parameters "
                     "must be specified")


class SelectTopK(Select):
    r"""Selects nodes based on a learnable projection score.

    Every node feature vector is projected onto a learnable weight vector
    to obtain a scalar score, and :func:`topk` then decides which nodes of
    each graph are kept. Every kept node forms its own cluster.

    Args:
        in_channels (int): The size of each input node feature vector.
        ratio (int or float, optional): Passed on to :func:`topk`. Values
            below :obj:`1` keep :obj:`ceil(ratio * N)` nodes of a graph
            with :obj:`N` nodes, values of :obj:`1` or above keep
            :obj:`min(int(ratio), N)` nodes. Has no effect if
            :obj:`min_score` is set. (default: :obj:`0.5`)
        min_score (float, optional): If set, scores are normalized by a
            per-graph softmax and nodes are selected by thresholding their
            score in :func:`topk` instead of by :obj:`ratio`.
            (default: :obj:`None`)
        act (str or callable, optional): The non-linearity applied to the
            normalized score when :obj:`min_score` is :obj:`None`; resolved
            through :obj:`activation_resolver`. (default: :obj:`"tanh"`)

    Raises:
        ValueError: If both :obj:`ratio` and :obj:`min_score` are
            :obj:`None`.
    """
    def __init__(
        self,
        in_channels: int,
        ratio: Union[int, float] = 0.5,
        min_score: Optional[float] = None,
        act: Union[str, Callable] = 'tanh',
    ):
        super().__init__()

        if ratio is None and min_score is None:
            raise ValueError(f"At least one of the 'ratio' and 'min_score' "
                             f"parameters must be specified in "
                             f"'{self.__class__.__name__}'")

        self.in_channels = in_channels
        self.ratio = ratio
        self.min_score = min_score
        self.act = activation_resolver(act)

        # Allocated uninitialized; `reset_parameters` fills in the values.
        self.weight = torch.nn.Parameter(torch.empty(1, in_channels))

        self.reset_parameters()

    def reset_parameters(self):
        r"""Re-initializes the projection weight via :obj:`uniform`."""
        uniform(self.in_channels, self.weight)

    def forward(
        self,
        x: Tensor,
        batch: Optional[Tensor] = None,
    ) -> SelectOutput:
        r"""Scores all nodes and selects the ones to keep.

        Args:
            x (torch.Tensor): The node features. A one-dimensional tensor
                is interpreted as a single feature per node.
            batch (torch.Tensor, optional): The graph index of each node.
                If :obj:`None`, all nodes are assigned to graph :obj:`0`.
                (default: :obj:`None`)

        Returns:
            SelectOutput: The indices of the selected nodes, one cluster
            per selected node, and the scores of the selected nodes as
            weights.
        """
        if batch is None:
            batch = x.new_zeros(x.size(0), dtype=torch.long)

        # The reshaped tensor has to be assigned back: otherwise 1-D input
        # broadcasts against the `[1, in_channels]` weight along the wrong
        # axis and the sum below collapses all nodes into a single score.
        x = x.view(-1, 1) if x.dim() == 1 else x
        score = (x * self.weight).sum(dim=-1)

        if self.min_score is None:
            # Normalize by the weight norm before applying the activation.
            score = self.act(score / self.weight.norm(p=2, dim=-1))
        else:
            # Thresholding operates on per-graph softmax scores.
            score = softmax(score, batch)

        node_index = topk(score, self.ratio, batch, self.min_score)

        return SelectOutput(
            node_index=node_index,
            num_nodes=x.size(0),
            cluster_index=torch.arange(node_index.size(0), device=x.device),
            num_clusters=node_index.size(0),
            weight=score[node_index],
        )

    def __repr__(self) -> str:
        if self.min_score is None:
            arg = f'ratio={self.ratio}'
        else:
            arg = f'min_score={self.min_score}'
        return f'{self.__class__.__name__}({self.in_channels}, {arg})'
Loading

0 comments on commit a43fc56

Please sign in to comment.