Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
rusty1s committed Dec 3, 2022
1 parent 16bfb85 commit a45a193
Showing 1 changed file with 8 additions and 9 deletions.
17 changes: 8 additions & 9 deletions torch_geometric/nn/pool/topk_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,18 +54,17 @@ def topk(
if isinstance(ratio, int) and (k == ratio).all():
# If all graphs have exactly `ratio` or more than `ratio` entries,
# we can just pick the first entries in `perm` batch-wise:
mask = torch.arange(batch_size, device=x.device) * max_num_nodes
mask = mask.view(-1, 1).repeat(1, ratio).view(-1)
mask += torch.arange(ratio, device=x.device).repeat(batch_size)
index = torch.arange(batch_size, device=x.device) * max_num_nodes
index = index.view(-1, 1).repeat(1, ratio).view(-1)
index += torch.arange(ratio, device=x.device).repeat(batch_size)
else:
# Otherwise, compute indices per graph:
mask = [
torch.arange(k[i], dtype=torch.long, device=x.device) +
i * max_num_nodes for i in range(batch_size)
]
mask = torch.cat(mask, dim=0)
index = torch.cat([
torch.arange(k[i], device=x.device) + i * max_num_nodes
for i in range(batch_size)
], dim=0)

perm = perm[mask]
perm = perm[index]

else:
raise ValueError("At least one of 'min_score' and 'ratio' parameters "
Expand Down

0 comments on commit a45a193

Please sign in to comment.