Skip to content

Commit

Permalink
Add KNNIndex exclusion logic (#8573)
Browse files Browse the repository at this point in the history
  • Loading branch information
rusty1s authored Dec 8, 2023
1 parent 9dea80a commit 64fc4c1
Show file tree
Hide file tree
Showing 3 changed files with 117 additions and 4 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Added `KNNIndex` exclusion logic ([#8573](https://github.com/pyg-team/pytorch_geometric/pull/8573))
- Added warning when calling `dataset.num_classes` on regression problems ([#8550](https://github.com/pyg-team/pytorch_geometric/pull/8550))
- Added relabel node functionality to `dropout_node` ([#8524](https://github.com/pyg-team/pytorch_geometric/pull/8524))
- Added support for type checking via `mypy` ([#8254](https://github.com/pyg-team/pytorch_geometric/pull/8254))
Expand Down
37 changes: 35 additions & 2 deletions test/nn/pool/test_knn.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,11 @@ def test_L2_knn(device, k):
assert index.get_emb().device == device
assert torch.equal(index.get_emb(), rhs)

out = index.search(lhs, k=k)
out = index.search(lhs, k)
assert out.score.device == device
assert out.index.device == device
assert out.score.size() == (10, k)
assert out.index.size() == (10, k)

mat = torch.linalg.norm(lhs.unsqueeze(1) - rhs.unsqueeze(0), dim=-1).pow(2)
score, index = mat.sort(dim=-1)
Expand All @@ -38,12 +40,43 @@ def test_MIPS_knn(device, k):
assert index.get_emb().device == device
assert torch.equal(index.get_emb(), rhs)

out = index.search(lhs, k=k)
out = index.search(lhs, k)
assert out.score.device == device
assert out.index.device == device
assert out.score.size() == (10, k)
assert out.index.size() == (10, k)

mat = lhs @ rhs.t()
score, index = mat.sort(dim=-1, descending=True)

assert torch.allclose(out.score, score[:, :k])
assert torch.equal(out.index, index[:, :k])


@withCUDA
@withPackage('faiss')
@pytest.mark.parametrize('k', [50])
def test_MIPS_exclude(device, k):
lhs = torch.randn(10, 16, device=device)
rhs = torch.randn(100, 16, device=device)

exclude_lhs = torch.randint(0, 10, (500, ), device=device)
exclude_rhs = torch.randint(0, 100, (500, ), device=device)
exclude_links = torch.stack([exclude_lhs, exclude_rhs], dim=0)
exclude_links = exclude_links.unique(dim=1)

index = MIPSKNNIndex(rhs)

out = index.search(lhs, k, exclude_links)
assert out.score.device == device
assert out.index.device == device
assert out.score.size() == (10, k)
assert out.index.size() == (10, k)

# Ensure that excluded links are not present in `out.index`:
batch = torch.arange(lhs.size(0), device=device).repeat_interleave(k)
knn_links = torch.stack([batch, out.index.view(-1)], dim=0)
knn_links = knn_links[:, knn_links[1] >= 0]

unique_links = torch.cat([knn_links, exclude_links], dim=1).unique(dim=1)
assert unique_links.size(1) == knn_links.size(1) + exclude_links.size(1)
83 changes: 81 additions & 2 deletions torch_geometric/nn/pool/knn.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
import torch
from torch import Tensor

from torch_geometric.utils import cumsum, degree, to_dense_batch


class KNNOutput(NamedTuple):
score: Tensor
Expand Down Expand Up @@ -74,14 +76,24 @@ def add(self, emb: Tensor):
self.numel += emb.size(0)
self.index.add(emb.detach())

def search(self, emb: Tensor, k: int) -> KNNOutput:
def search(
self,
emb: Tensor,
k: int,
exclude_links: Optional[Tensor] = None,
) -> KNNOutput:
r"""Search for the :math:`k` nearest neighbors of the given data
points. Returns the distance/similarity score of the nearest neighbors
and their indices.
Args:
emb (torch.Tensor): The data points to add.
k (int): The number of nearest neighbors to return.
exclude_links (torch.Tensor): The links to exclude from searching.
Needs to be a COO tensor of shape :obj:`[2, num_links]`, where
:obj:`exclude_links[0]` refers to indices in :obj:`emb`, and
:obj:`exclude_links[1]` refers to the data points in the
:class:`KNNIndex`. (default: :obj:`None`)
"""
if self.index is None:
raise RuntimeError(f"'{self.__class__.__name__}' is not yet "
Expand All @@ -91,7 +103,74 @@ def search(self, emb: Tensor, k: int) -> KNNOutput:
raise ValueError(f"'emb' needs to be two-dimensional "
f"(got {emb.dim()} dimensions)")

return KNNOutput(*self.index.search(emb.detach(), k))
query_k = k

if exclude_links is not None:
deg = degree(exclude_links[0], num_nodes=emb.size(0)).max()
query_k = k + int(deg.max() if deg.numel() > 0 else 0)

query_k = min(query_k, self.numel)

if query_k > 2048: # `faiss` supports up-to `k=2048`:
warnings.warn(f"Capping 'k' to faiss' upper limit of 2048 "
f"(got {k}). This may cause some relevant items to "
f"not be retrieved.")
query_k = 2048

score, index = self.index.search(emb.detach(), query_k)

if exclude_links is not None:
# Drop indices to exclude by converting to flat vector:
flat_exclude = self.numel * exclude_links[0] + exclude_links[1]

offset = torch.arange(
start=0,
end=self.numel * index.size(0),
step=self.numel,
device=index.device,
).view(-1, 1)
flat_index = (index + offset).view(-1)

notin = torch.isin(flat_index, flat_exclude).logical_not_()

score = score.view(-1)[notin]
index = index.view(-1)[notin]

# Only maintain top-k scores:
count = notin.view(-1, query_k).sum(dim=1)
cum_count = cumsum(count)

batch = torch.arange(count.numel(), device=count.device)
batch = batch.repeat_interleave(count, output_size=cum_count[-1])

batch_arange = torch.arange(count.sum(), device=count.device)
batch_arange = batch_arange - cum_count[batch]

mask = batch_arange < k
score = score[mask]
index = index[mask]

if count.min() < k: # Fill with dummy scores:
batch = batch[mask]
score, _ = to_dense_batch(
score,
batch,
fill_value=float('-inf'),
max_num_nodes=k,
batch_size=emb.size(0),
)
index, _ = to_dense_batch(
index,
batch,
fill_value=-1,
max_num_nodes=k,
batch_size=emb.size(0),
)

score = score.view(-1, k)
index = index.view(-1, k)

return KNNOutput(score, index)

def get_emb(self) -> Tensor:
r"""Returns the data points stored in the :class:`KNNIndex`."""
Expand Down

0 comments on commit 64fc4c1

Please sign in to comment.