Skip to content

Commit

Permalink
Time (#60)
Browse files Browse the repository at this point in the history
* Added time measurement to the faiss adapter.

* Formatted with black.

* fixes

---------

Co-authored-by: Jonas Hübotter <jonas.huebotter@gmail.com>
  • Loading branch information
Bongni and jonhue authored Aug 29, 2024
1 parent 57821e9 commit 01ff44c
Show file tree
Hide file tree
Showing 5 changed files with 25 additions and 12 deletions.
31 changes: 23 additions & 8 deletions afsl/adapters/faiss.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from typing import Tuple
from typing import NamedTuple, Tuple
from afsl.acquisition_functions import AcquisitionFunction, Targeted
import faiss # type: ignore
import torch
import time
import concurrent.futures
import numpy as np
from afsl import ActiveDataLoader
Expand All @@ -19,6 +20,13 @@ def __getitem__(self, index) -> torch.Tensor:
return self.data[index]


class RetrievalTime(NamedTuple):
    """
    Wall-clock timing breakdown of a single retrieval call, in seconds.

    Both fields are differences of `time.time()` readings, so they measure
    elapsed wall-clock time rather than CPU time.
    """

    faiss: float
    """Time spent with Faiss retrieval."""
    afsl: float
    """Additional time spent with AFSL."""


class Retriever:
"""
Adapter for the [Faiss](https://github.com/facebookresearch/faiss) library.
Expand Down Expand Up @@ -61,24 +69,24 @@ def search(
k: int | None,
mean_pooling: bool = False,
threads: int = 1,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, RetrievalTime]:
r"""
:param query: Query embedding (of shape $m \times d$), comprised of $m$ individual embeddings.
:param N: Number of results to return.
:param k: Number of results to pre-sample with Faiss. Does not pre-sample if set to `None`.
:param mean_pooling: Whether to use the mean of the query embeddings.
:param threads: Number of threads to use.
:return: Array of acquisition values (of length $N$), array of selected indices (of length $N$), and array of corresponding embeddings (of shape $N \times d$).
:return: Array of acquisition values (of length $N$), array of selected indices (of length $N$), array of corresponding embeddings (of shape $N \times d$), retrieval time.
"""
D, I, V = self.batch_search(
D, I, V, retrieval_time = self.batch_search(
queries=np.array([query]),
N=N,
k=k,
mean_pooling=mean_pooling,
threads=threads,
)
return D[0], I[0], V[0]
return D[0], I[0], V[0], retrieval_time

def batch_search(
self,
Expand All @@ -87,15 +95,15 @@ def batch_search(
k: int | None = None,
mean_pooling: bool = False,
threads: int = 1,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, RetrievalTime]:
r"""
:param queries: $n$ query embeddings (of combined shape $n \times m \times d$), each comprised of $m$ individual embeddings.
:param N: Number of results to return.
:param k: Number of results to pre-sample with Faiss. Does not pre-sample if set to `None`.
:param mean_pooling: Whether to use the mean of the query embeddings.
:param threads: Number of threads to use.
:return: Array of acquisition values (of shape $n \times N$), array of selected indices (of shape $n \times N$), and array of corresponding embeddings (of shape $n \times N \times d$).
:return: Array of acquisition values (of shape $n \times N$), array of selected indices (of shape $n \times N$), array of corresponding embeddings (of shape $n \times N \times d$), retrieval time.
"""
assert k is None or k >= N

Expand All @@ -104,11 +112,14 @@ def batch_search(
assert d == self.index.d
mean_queries = np.mean(queries, axis=1)

t_start = time.time()
faiss.omp_set_num_threads(threads) # type: ignore
D, I, V = self.index.search_and_reconstruct(mean_queries, k or self.index.ntotal) # type: ignore
t_faiss = time.time() - t_start

if self.only_faiss:
return D[:, :N], I[:, :N], V[:, :N]
retrieval_time = RetrievalTime(faiss=t_faiss, afsl=0)
return D[:, :N], I[:, :N], V[:, :N], retrieval_time

def engine(i: int) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
dataset = Dataset(torch.tensor(V[i]))
Expand All @@ -131,6 +142,7 @@ def engine(i: int) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
np.array(V[i][sub_indexes]),
)

t_start = time.time()
resulting_values = []
resulting_indices = []
resulting_embeddings = []
Expand All @@ -143,8 +155,11 @@ def engine(i: int) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
resulting_values.append(values)
resulting_indices.append(indices)
resulting_embeddings.append(embeddings)
t_afsl = time.time() - t_start
retrieval_time = RetrievalTime(faiss=t_faiss, afsl=t_afsl)
return (
np.array(resulting_values),
np.array(resulting_indices),
np.array(resulting_embeddings),
retrieval_time,
)
1 change: 0 additions & 1 deletion examples/fine_tuning/cifar_100/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,6 @@ def experiment(
faiss_index_path=faiss_index_path,
target_embeddings=target_embeddings,
)
wandb.finish()


def main(args):
Expand Down
1 change: 0 additions & 1 deletion examples/fine_tuning/mnist/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,6 @@ def experiment(
reset_parameters=RESET_PARAMS,
use_best_model=USE_BEST_MODEL,
)
wandb.finish()


def main(args):
Expand Down
2 changes: 1 addition & 1 deletion examples/fine_tuning/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def train_loop(
target_embeddings
) # ensure target set is reset to correct length
query = acquisition_function.get_target().cpu().numpy()
_, _batch_indices, _ = retriever.search(
_, _batch_indices, _, _ = retriever.search(
query=query, N=query_batch_size, k=100 * query_batch_size
)
batch_indices = torch.tensor(_batch_indices)
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,4 @@ pytest~=8.3
torch~=2.4
torchvision~=0.19
tqdm~=4.66
wandb
wandb~=0.17

0 comments on commit 01ff44c

Please sign in to comment.