From 1fe877e9d6d39e46165556a037f27c0d383a8e91 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jonas=20H=C3=BCbotter?= Date: Sat, 5 Oct 2024 11:40:19 +0200 Subject: [PATCH] fix querying opposite vectors (#79) --- activeft/sift.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/activeft/sift.py b/activeft/sift.py index 55ed5d6..fa40375 100644 --- a/activeft/sift.py +++ b/activeft/sift.py @@ -160,12 +160,16 @@ def batch_search( ), "`also_query_opposite` should only be used with inner product indexes." D_, I_, V_ = self.index.search_and_reconstruct(-mean_queries, k) # type: ignore D__, I__, V__ = ( - np.concatenate([D, D_]), - np.concatenate([I, I_]), - np.concatenate([V, V_]), + np.concatenate([D, D_], axis=1), + np.concatenate([I, I_], axis=1), + np.concatenate([V, V_], axis=1), + ) + sorted_indices = np.argsort(-D__)[:, :k] + D, I, V = ( + np.take_along_axis(D__, sorted_indices, axis=1), + np.take_along_axis(I__, sorted_indices, axis=1), + np.take_along_axis(V__, sorted_indices[:, :, np.newaxis], axis=1), ) - sorted_indices = np.argsort(-D__)[:k] - D, I, V = D__[sorted_indices], I__[sorted_indices], V__[sorted_indices] t_faiss = time.time() - t_start if self.only_faiss: