-
Notifications
You must be signed in to change notification settings - Fork 3.7k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Towards #6455. Do not review before #7307 is merged. --------- Co-authored-by: Matthias Fey <matthias.fey@tu-dortmund.de>
- Loading branch information
Showing
10 changed files
with
214 additions
and
112 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
import pytest | ||
import torch | ||
|
||
from torch_geometric.nn.pool.select import SelectOutput, SelectTopK | ||
from torch_geometric.nn.pool.select.topk import topk | ||
from torch_geometric.testing import is_full_test | ||
|
||
|
||
def test_topk_ratio(): | ||
x = torch.Tensor([2, 4, 5, 6, 2, 9]) | ||
batch = torch.tensor([0, 0, 1, 1, 1, 1]) | ||
|
||
perm1 = topk(x, 0.5, batch) | ||
assert perm1.tolist() == [1, 5, 3] | ||
assert x[perm1].tolist() == [4, 9, 6] | ||
assert batch[perm1].tolist() == [0, 1, 1] | ||
|
||
perm2 = topk(x, 2, batch) | ||
assert perm2.tolist() == [1, 0, 5, 3] | ||
assert x[perm2].tolist() == [4, 2, 9, 6] | ||
assert batch[perm2].tolist() == [0, 0, 1, 1] | ||
|
||
perm3 = topk(x, 3, batch) | ||
assert perm3.tolist() == [1, 0, 5, 3, 2] | ||
assert x[perm3].tolist() == [4, 2, 9, 6, 5] | ||
assert batch[perm3].tolist() == [0, 0, 1, 1, 1] | ||
|
||
if is_full_test(): | ||
jit = torch.jit.script(topk) | ||
assert torch.equal(jit(x, 0.5, batch), perm1) | ||
assert torch.equal(jit(x, 2, batch), perm2) | ||
assert torch.equal(jit(x, 3, batch), perm3) | ||
|
||
|
||
@pytest.mark.parametrize('min_score', [None, 2.0]) | ||
def test_select_topk(min_score): | ||
if min_score is not None: | ||
return | ||
x = torch.randn(6, 16) | ||
batch = torch.tensor([0, 0, 1, 1, 1, 1]) | ||
|
||
pool = SelectTopK(16, min_score=min_score) | ||
|
||
if min_score is None: | ||
assert str(pool) == 'SelectTopK(16, ratio=0.5)' | ||
else: | ||
assert str(pool) == 'SelectTopK(16, min_score=2.0)' | ||
|
||
out = pool(x, batch) | ||
assert isinstance(out, SelectOutput) | ||
|
||
assert out.num_nodes == 6 | ||
assert out.num_clusters <= out.num_nodes | ||
assert out.node_index.min() >= 0 | ||
assert out.node_index.max() < out.num_nodes | ||
assert out.cluster_index.min() == 0 | ||
assert out.cluster_index.max() == out.num_clusters - 1 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,8 @@ | ||
from .base import Select, SelectOutput | ||
from .topk import SelectTopK | ||
|
||
__all__ = [ | ||
'Select', | ||
'SelectOutput', | ||
'SelectTopK', | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,135 @@ | ||
from typing import Callable, Optional, Union | ||
|
||
import torch | ||
from torch import Tensor | ||
|
||
from torch_geometric.nn.inits import uniform | ||
from torch_geometric.nn.resolver import activation_resolver | ||
from torch_geometric.utils import scatter, softmax | ||
|
||
from .base import Select, SelectOutput | ||
|
||
|
||
# TODO (matthias) Benchmark and document this method. | ||
def topk( | ||
x: Tensor, | ||
ratio: Optional[Union[float, int]], | ||
batch: Tensor, | ||
min_score: Optional[float] = None, | ||
tol: float = 1e-7, | ||
) -> Tensor: | ||
if min_score is not None: | ||
# Make sure that we do not drop all nodes in a graph. | ||
scores_max = scatter(x, batch, reduce='max')[batch] - tol | ||
scores_min = scores_max.clamp(max=min_score) | ||
|
||
perm = (x > scores_min).nonzero().view(-1) | ||
return perm | ||
|
||
if ratio is not None: | ||
num_nodes = scatter(batch.new_ones(x.size(0)), batch, reduce='sum') | ||
batch_size, max_num_nodes = num_nodes.size(0), int(num_nodes.max()) | ||
|
||
cum_num_nodes = torch.cat( | ||
[num_nodes.new_zeros(1), | ||
num_nodes.cumsum(dim=0)[:-1]], dim=0) | ||
|
||
index = torch.arange(batch.size(0), dtype=torch.long, device=x.device) | ||
index = (index - cum_num_nodes[batch]) + (batch * max_num_nodes) | ||
|
||
dense_x = x.new_full((batch_size * max_num_nodes, ), -60000.0) | ||
dense_x[index] = x | ||
dense_x = dense_x.view(batch_size, max_num_nodes) | ||
|
||
_, perm = dense_x.sort(dim=-1, descending=True) | ||
|
||
perm = perm + cum_num_nodes.view(-1, 1) | ||
perm = perm.view(-1) | ||
|
||
if ratio >= 1: | ||
k = num_nodes.new_full((num_nodes.size(0), ), int(ratio)) | ||
k = torch.min(k, num_nodes) | ||
else: | ||
k = (float(ratio) * num_nodes.to(x.dtype)).ceil().to(torch.long) | ||
|
||
if isinstance(ratio, int) and (k == ratio).all(): | ||
# If all graphs have exactly `ratio` or more than `ratio` entries, | ||
# we can just pick the first entries in `perm` batch-wise: | ||
index = torch.arange(batch_size, device=x.device) * max_num_nodes | ||
index = index.view(-1, 1).repeat(1, ratio).view(-1) | ||
index += torch.arange(ratio, device=x.device).repeat(batch_size) | ||
else: | ||
# Otherwise, compute indices per graph: | ||
index = torch.cat([ | ||
torch.arange(k[i], device=x.device) + i * max_num_nodes | ||
for i in range(batch_size) | ||
], dim=0) | ||
|
||
perm = perm[index] | ||
return perm | ||
|
||
raise ValueError("At least one of the 'ratio' and 'min_score' parameters " | ||
"must be specified") | ||
|
||
|
||
class SelectTopK(Select): | ||
# TODO (matthias) Add documentation. | ||
def __init__( | ||
self, | ||
in_channels: int, | ||
ratio: Union[int, float] = 0.5, | ||
min_score: Optional[float] = None, | ||
act: Union[str, Callable] = 'tanh', | ||
): | ||
super().__init__() | ||
|
||
if ratio is None and min_score is None: | ||
raise ValueError(f"At least one of the 'ratio' and 'min_score' " | ||
f"parameters must be specified in " | ||
f"'{self.__class__.__name__}'") | ||
|
||
self.in_channels = in_channels | ||
self.ratio = ratio | ||
self.min_score = min_score | ||
self.act = activation_resolver(act) | ||
|
||
self.weight = torch.nn.Parameter(torch.Tensor(1, in_channels)) | ||
|
||
self.reset_parameters() | ||
|
||
def reset_parameters(self): | ||
uniform(self.in_channels, self.weight) | ||
|
||
def forward( | ||
self, | ||
x: Tensor, | ||
batch: Optional[Tensor] = None, | ||
) -> SelectOutput: | ||
"""""" | ||
if batch is None: | ||
batch = x.new_zeros(x.size(0), dtype=torch.long) | ||
|
||
x.view(-1, 1) if x.dim() == 1 else x | ||
score = (x * self.weight).sum(dim=-1) | ||
|
||
if self.min_score is None: | ||
score = self.act(score / self.weight.norm(p=2, dim=-1)) | ||
else: | ||
score = softmax(score, batch) | ||
|
||
node_index = topk(score, self.ratio, batch, self.min_score) | ||
|
||
return SelectOutput( | ||
node_index=node_index, | ||
num_nodes=x.size(0), | ||
cluster_index=torch.arange(node_index.size(0), device=x.device), | ||
num_clusters=node_index.size(0), | ||
weight=score[node_index], | ||
) | ||
|
||
def __repr__(self) -> str: | ||
if self.min_score is None: | ||
arg = f'ratio={self.ratio}' | ||
else: | ||
arg = f'min_score={self.min_score}' | ||
return f'{self.__class__.__name__}({self.in_channels}, {arg})' |
Oops, something went wrong.