Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
rusty1s committed May 8, 2023
1 parent a6f36a4 commit 114b9a4
Show file tree
Hide file tree
Showing 4 changed files with 106 additions and 56 deletions.
57 changes: 57 additions & 0 deletions test/nn/pool/select/test_select_topk.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
import pytest
import torch

from torch_geometric.nn.pool.select import SelectOutput, SelectTopK
from torch_geometric.nn.pool.select.topk import topk
from torch_geometric.testing import is_full_test


def test_topk_ratio():
x = torch.Tensor([2, 4, 5, 6, 2, 9])
batch = torch.tensor([0, 0, 1, 1, 1, 1])

perm1 = topk(x, 0.5, batch)
assert perm1.tolist() == [1, 5, 3]
assert x[perm1].tolist() == [4, 9, 6]
assert batch[perm1].tolist() == [0, 1, 1]

perm2 = topk(x, 2, batch)
assert perm2.tolist() == [1, 0, 5, 3]
assert x[perm2].tolist() == [4, 2, 9, 6]
assert batch[perm2].tolist() == [0, 0, 1, 1]

perm3 = topk(x, 3, batch)
assert perm3.tolist() == [1, 0, 5, 3, 2]
assert x[perm3].tolist() == [4, 2, 9, 6, 5]
assert batch[perm3].tolist() == [0, 0, 1, 1, 1]

if is_full_test():
jit = torch.jit.script(topk)
assert torch.equal(jit(x, 0.5, batch), perm1)
assert torch.equal(jit(x, 2, batch), perm2)
assert torch.equal(jit(x, 3, batch), perm3)


@pytest.mark.parametrize('min_score', [None, 2.0])
def test_select_topk(min_score):
if min_score is not None:
return
x = torch.randn(6, 16)
batch = torch.tensor([0, 0, 1, 1, 1, 1])

pool = SelectTopK(16, min_score=min_score)

if min_score is None:
assert str(pool) == 'SelectTopK(16, ratio=0.5)'
else:
assert str(pool) == 'SelectTopK(16, min_score=2.0)'

out = pool(x, batch)
assert isinstance(out, SelectOutput)

assert out.num_nodes == 6
assert out.num_clusters <= out.num_nodes
assert out.node_index.min() >= 0
assert out.node_index.max() < out.num_nodes
assert out.cluster_index.min() == 0
assert out.cluster_index.max() == out.num_clusters - 1
27 changes: 0 additions & 27 deletions test/nn/pool/test_topk_pool.py
Original file line number Diff line number Diff line change
@@ -1,37 +1,10 @@
import torch

from torch_geometric.nn.pool import TopKPooling
from torch_geometric.nn.pool.select.topk import topk
from torch_geometric.nn.pool.topk_pool import filter_adj
from torch_geometric.testing import is_full_test


def test_topk():
x = torch.Tensor([2, 4, 5, 6, 2, 9])
batch = torch.tensor([0, 0, 1, 1, 1, 1])

perm1 = topk(x, 0.5, batch)
assert perm1.tolist() == [1, 5, 3]
assert x[perm1].tolist() == [4, 9, 6]
assert batch[perm1].tolist() == [0, 1, 1]

perm2 = topk(x, 2, batch)
assert perm2.tolist() == [1, 0, 5, 3]
assert x[perm2].tolist() == [4, 2, 9, 6]
assert batch[perm2].tolist() == [0, 0, 1, 1]

perm3 = topk(x, 3, batch)
assert perm3.tolist() == [1, 0, 5, 3, 2]
assert x[perm3].tolist() == [4, 2, 9, 6, 5]
assert batch[perm3].tolist() == [0, 0, 1, 1, 1]

if is_full_test():
jit = torch.jit.script(topk)
assert torch.equal(jit(x, 0.5, batch), perm1)
assert torch.equal(jit(x, 2, batch), perm2)
assert torch.equal(jit(x, 3, batch), perm3)


def test_filter_adj():
edge_index = torch.tensor([[0, 0, 1, 1, 2, 2, 3, 3],
[1, 3, 0, 2, 1, 3, 0, 2]])
Expand Down
49 changes: 42 additions & 7 deletions torch_geometric/nn/pool/select/topk.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from typing import Optional, Union
from typing import Callable, Optional, Union

import torch
from torch import Tensor

from torch_geometric.utils import scatter
from torch_geometric.nn.inits import uniform
from torch_geometric.nn.resolver import activation_resolver
from torch_geometric.utils import scatter, softmax

from .base import Select, SelectOutput

Expand Down Expand Up @@ -72,12 +74,12 @@ def topk(

class SelectTopK(Select):
# TODO (matthias) Add documentation.
# TODO (matthias) Add separate `SelectTopK` tests.
def __init__(
self,
ratio: Optional[float] = None,
in_channels: int,
ratio: Union[int, float] = 0.5,
min_score: Optional[float] = None,
tol: float = 1e-7,
act: Union[str, Callable] = 'tanh',
):
super().__init__()

Expand All @@ -86,15 +88,48 @@ def __init__(
f"parameters must be specified in "
f"'{self.__class__.__name__}'")

self.in_channels = in_channels
self.ratio = ratio
self.min_score = min_score
self.act = activation_resolver(act)

def forward(self, x: Tensor, batch: Tensor) -> SelectOutput:
node_index = topk(x, self.ratio, batch, self.min_score)
self.weight = torch.nn.Parameter(torch.Tensor(1, in_channels))

self.reset_parameters()

def reset_parameters(self):
uniform(self.in_channels, self.weight)

def forward(
self,
x: Tensor,
batch: Optional[Tensor] = None,
) -> SelectOutput:
""""""
if batch is None:
batch = x.new_zeros(x.size(0), dtype=torch.long)

x.view(-1, 1) if x.dim() == 1 else x
score = (x * self.weight).sum(dim=-1)

if self.min_score is None:
score = self.act(score / self.weight.norm(p=2, dim=-1))
else:
score = softmax(score, batch)

node_index = topk(score, self.ratio, batch, self.min_score)

return SelectOutput(
node_index=node_index,
num_nodes=x.size(0),
cluster_index=torch.arange(node_index.size(0), device=x.device),
num_clusters=node_index.size(0),
weight=score[node_index],
)

def __repr__(self) -> str:
if self.min_score is None:
arg = f'ratio={self.ratio}'
else:
arg = f'min_score={self.min_score}'
return f'{self.__class__.__name__}({self.in_channels}, {arg})'
29 changes: 7 additions & 22 deletions torch_geometric/nn/pool/topk_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,8 @@

import torch
from torch import Tensor
from torch.nn import Parameter

from torch_geometric.nn.inits import uniform
from torch_geometric.nn.pool.select import SelectTopK
from torch_geometric.utils import softmax
from torch_geometric.utils.num_nodes import maybe_num_nodes


Expand Down Expand Up @@ -96,22 +93,18 @@ def __init__(
):
super().__init__()

if isinstance(nonlinearity, str):
nonlinearity = getattr(torch, nonlinearity)

self.in_channels = in_channels
self.ratio = ratio
self.min_score = min_score
self.multiplier = multiplier
self.nonlinearity = nonlinearity
self.select = SelectTopK(ratio, min_score)
self.weight = Parameter(torch.Tensor(1, in_channels))

self.select = SelectTopK(in_channels, ratio, min_score, nonlinearity)

self.reset_parameters()

def reset_parameters(self):
r"""Resets all learnable parameters of the module."""
uniform(self.in_channels, self.weight)
self.select.reset_parameters()

def forward(
self,
Expand All @@ -138,25 +131,17 @@ def forward(
batch = edge_index.new_zeros(x.size(0))

attn = x if attn is None else attn
attn = attn.unsqueeze(-1) if attn.dim() == 1 else attn
score = (attn * self.weight).sum(dim=-1)

if self.min_score is None:
score = self.nonlinearity(score / self.weight.norm(p=2, dim=-1))
else:
score = softmax(score, batch)

select_output = self.select(score, batch)
select_output = self.select(attn, batch)

perm = select_output.node_index
x = x[perm] * score[perm].view(-1, 1)
x = x[perm] * select_output.weight.view(-1, 1)
x = self.multiplier * x if self.multiplier != 1 else x

batch = batch[perm]
edge_index, edge_attr = filter_adj(edge_index, edge_attr, perm,
num_nodes=score.size(0))
num_nodes=select_output.num_nodes)

return x, edge_index, edge_attr, batch, perm, score[perm]
return x, edge_index, edge_attr, batch, perm, select_output.weight

def __repr__(self) -> str:
if self.min_score is None:
Expand Down

0 comments on commit 114b9a4

Please sign in to comment.