Merge pull request #2 from mixedbread-ai/feature/loading-enhancement
Feature/loading enhancement
huangrpablo authored Aug 14, 2024
2 parents 2d351d5 + c209c30 commit de1de89
Showing 14 changed files with 468 additions and 88 deletions.
18 changes: 7 additions & 11 deletions baguetter/indices/dense/base.py
@@ -12,13 +12,10 @@

 import numpy as np

-_INDEX_PREFIX = "usearch_index_"
-_STATE_PREFIX = "usearch_state_"
+_INDEX_PREFIX = "index_"


 class BaseDenseIndex(BaseIndex, abc.ABC):
-    NAME_PREFIX: str = "dense_"
-
     def __init__(
         self,
         index_name: str = "new-index",
@@ -29,13 +26,13 @@ def __init__(
         super().__init__()

         self.index_name: str = index_name
-        self.n_workers: int = n_workers or max(1, (os.cpu_count() or 1) - 1)
+        self.n_workers: int = n_workers if n_workers is not None else max(1, (os.cpu_count() or 1) - 1)
         self._embed_fn: Callable[[list[str], bool], np.ndarray] | None = embed_fn

     @property
     def name(self) -> str:
         """Get the full name of the index."""
-        return f"{self.NAME_PREFIX}{self.index_name}"
+        return self.index_name

     def _embed(self, query: list[str], *, is_query: bool = False, show_progress: bool = False) -> np.ndarray:
         """Embed text queries into vectors.
@@ -61,22 +58,21 @@ def _embed(self, query: list[str], *, is_query: bool = False, show_progress: bool = False) -> np.ndarray:
         return self._embed_fn(query, is_query=is_query, show_progress=show_progress)

     @staticmethod
-    def build_index_file_paths(name_or_path: str) -> tuple[str, str]:
+    def build_index_file_paths(path: str) -> tuple[str, str]:
         """Build the file paths for the index and state files.

         Args:
-            name_or_path (str): Path to the index.
+            path (str): Path to the index.

         Returns:
             Tuple[str, str]: File paths for the state and index files.

         """
-        path = Path(name_or_path)
-        state_file_name = f"{_STATE_PREFIX}{path.name}"
+        path = Path(path)
         index_file_name = f"{_INDEX_PREFIX}{path.name}"

         dir_name = path.parent if path.parent != Path() else Path()
-        state_file_path = dir_name / state_file_name
+        state_file_path = dir_name / path.name
         index_file_path = dir_name / index_file_name

         return str(state_file_path), str(index_file_path)
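
The `n_workers` fallback above swaps a truthiness check for an explicit `None` test, a pattern this commit repeats in `faiss.py`, `usearch.py`, and `sparse/base.py`. A minimal standalone sketch (not from the commit) of why the distinction matters:

```python
# `or` treats any falsy value, including 0, as "use the default";
# the explicit None check only falls back when the caller passed nothing.
import os

def workers_old(n_workers=None):
    return n_workers or max(1, (os.cpu_count() or 1) - 1)

def workers_new(n_workers=None):
    return n_workers if n_workers is not None else max(1, (os.cpu_count() or 1) - 1)

print(workers_old(0))  # default count; the explicit 0 is silently dropped
print(workers_new(0))  # 0; the caller's value is respected
```

Under the new `build_index_file_paths`, a call such as `build_index_file_paths("indices/my-index")` returns `("indices/my-index", "indices/index_my-index")`: the state file keeps the caller's name unchanged and only the index file gets the `index_` prefix.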
22 changes: 8 additions & 14 deletions baguetter/indices/dense/faiss.py
@@ -5,13 +5,11 @@

 import faiss
 import numpy as np
-from tqdm import tqdm

 from baguetter.indices.base import SearchResults
 from baguetter.indices.dense.base import BaseDenseIndex
 from baguetter.indices.dense.config import FaissDenseIndexConfig
 from baguetter.logger import LOGGER
-from baguetter.utils.common import batch_iter

 if TYPE_CHECKING:
     from collections.abc import Callable
@@ -28,8 +26,6 @@ def _support_nprobe(index: Any) -> bool:
 class FaissDenseIndex(BaseDenseIndex):
     """A dense index implementation using Faiss."""

-    NAME_PREFIX = "faiss_"
-
     def __init__(
         self,
         index_name: str = "new-index",
@@ -100,9 +96,9 @@ def from_config(cls, config: FaissDenseIndexConfig) -> FaissDenseIndex:

     def _save(
         self,
+        path: str,
         repository: AbstractFileRepository,
-        path: str | None = None,
-    ) -> None:
+    ) -> str:
         """Save the index state and data."""
         state = {
             "key_mapping": self.key_mapping,
@@ -120,19 +116,20 @@ def _save(
         with repository.open(index_file_path, "wb") as file:
             index = faiss.serialize_index(self.faiss_index)
             np.savez_compressed(file, index=index)
+        return state_file_path

     @classmethod
     def _load(
         cls,
-        name_or_path: str,
+        path: str,
         *,
         repository: AbstractFileRepository,
         mmap: bool = False,
     ) -> FaissDenseIndex:
         """Load the index from saved state.

         Args:
-            name_or_path (str): Name or path of the index to load.
+            path (str): Name or path of the index to load.
             repository (AbstractFileRepository): File repository to use for loading.
             mmap (bool): Whether to use memory mapping. Defaults to False.
@@ -142,9 +139,7 @@ def _load(
         Raises:
             FileNotFoundError: If the index files are not found in the repository.

         """
-        state_file_path, index_file_path = BaseDenseIndex.build_index_file_paths(
-            name_or_path
-        )
+        state_file_path, index_file_path = BaseDenseIndex.build_index_file_paths(path)

         if not repository.exists(state_file_path):
             msg = f"Index.state {state_file_path} not found in repository."
@@ -243,14 +238,13 @@ def search_many(
         """
         if not queries:
             return []
-
         if isinstance(queries[0], str):
            queries = self._embed(queries, is_query=True, show_progress=show_progress)

         if _support_nprobe(self.faiss_index) and n_probe is not None:
             self.faiss_index.nprobe = n_probe

-        faiss.omp_set_num_threads(n_workers or self.n_workers)
+        n_workers = n_workers if n_workers is not None else self.n_workers
+        faiss.omp_set_num_threads(n_workers)

         query_vectors = np.array(queries, dtype=np.float32)
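
`_save` above serializes the Faiss index into a NumPy array and writes it through `np.savez_compressed`; `_load` reverses the process. That round trip is what lets a Faiss index flow through the abstract file repository. A self-contained sketch of the same mechanism (an in-memory buffer stands in for the repository file):

```python
import io

import faiss
import numpy as np

dim = 8
index = faiss.IndexFlatL2(dim)
index.add(np.random.rand(100, dim).astype(np.float32))

buf = io.BytesIO()  # stand-in for repository.open(index_file_path, "wb")
np.savez_compressed(buf, index=faiss.serialize_index(index))

buf.seek(0)
restored = faiss.deserialize_index(np.load(buf)["index"])
print(restored.ntotal)  # 100
```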
20 changes: 11 additions & 9 deletions baguetter/indices/dense/usearch.py
@@ -51,8 +51,6 @@ class USearchDenseIndex(BaseDenseIndex, Index):
     providing methods for adding, searching, and managing dense vector data.
     """

-    NAME_PREFIX = "usearch_"
-
     def __init__(
         self,
         index_name: str = "new-index",
@@ -142,14 +140,14 @@ def from_config(cls, config: UsearchDenseIndexConfig) -> USearchDenseIndex:

     def _save(
         self,
+        path: str,
         repository: AbstractFileRepository,
-        path: str | None = None,
-    ) -> None:
+    ) -> str:
         """Save the index state and data.

         Args:
+            path (str): Path to save the index. If None, uses the index name.
             repository (AbstractFileRepository): File repository to use for saving.
-            path (Optional[str]): Path to save the index. If None, uses the index name.

         """
         state = {
@@ -173,18 +171,20 @@ def _save(
                 Index.save(self, temp_file.name)
                 file.write(temp_file.read())

+        return state_file_path
+
     @classmethod
     def _load(
         cls,
-        name_or_path: str,
+        path: str,
         *,
         repository: AbstractFileRepository,
         mmap: bool = False,
     ) -> USearchDenseIndex:
         """Load the index from saved state.

         Args:
-            name_or_path (str): Name or path of the index to load.
+            path (str): Name or path of the index to load.
             repository (AbstractFileRepository): File repository to use for loading.
             mmap (bool): Whether to use memory mapping. Defaults to False.
@@ -198,7 +198,7 @@ def _load(
         (
             state_file_path,
             index_file_path,
-        ) = BaseDenseIndex.build_index_file_paths(name_or_path)
+        ) = BaseDenseIndex.build_index_file_paths(path)

         if not repository.exists(state_file_path):
             msg = f"Index.state {state_file_path} not found in repository."
@@ -303,11 +303,13 @@ def search_many(
         if not isinstance(queries, np.ndarray):
             queries = np.array(queries)

+        n_workers = n_workers if n_workers is not None else self.n_workers
+
         results = Index.search(
             self,
             vectors=queries,
             count=top_k,
-            threads=n_workers or self.n_workers,
+            threads=n_workers,
             log=show_progress,
             radius=radius,
             exact=self.config.exact_search if exact_search is None else exact_search,
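
usearch's native `Index.save` writes only to a local filesystem path, so `_save` above round-trips through a `NamedTemporaryFile` before copying the bytes into the repository, then returns `state_file_path` just as the Faiss version now does. A generic sketch of that bridge pattern (`native_index` and `repository_open` are illustrative stand-ins, not the project's API):

```python
import shutil
import tempfile

def save_via_temp_file(native_index, repository_open, dest_path: str) -> str:
    """Bridge a library that can only save to a real path into an
    abstract file repository by way of a temporary file."""
    with repository_open(dest_path, "wb") as out_file, tempfile.NamedTemporaryFile() as tmp:
        native_index.save(tmp.name)        # the library writes to the temp path
        tmp.seek(0)                        # rewind our handle before reading back
        shutil.copyfileobj(tmp, out_file)  # stream the bytes into the repository
    return dest_path
```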
29 changes: 14 additions & 15 deletions baguetter/indices/sparse/base.py
@@ -25,8 +25,6 @@
 class BaseSparseIndex(BaseIndex, abc.ABC):
     """Base class for sparse indices. This class should not be used directly."""

-    NAME_PREFIX: str = "sparse_"
-
     def __init__(
         self,
         index_name: str = "new-index",
@@ -91,7 +89,7 @@ def __init__(
         self.index: object | None = None
         self.key_mapping: dict[int, str] = {}
         self.corpus_tokens: dict[str, list[str]] = {}
-        self.n_workers: int = n_workers or max(1, (os.cpu_count() or 1) - 1)
+        self.n_workers: int = n_workers if n_workers is not None else max(1, (os.cpu_count() or 1) - 1)

     @abc.abstractmethod
     def normalize_scores(self, n_tokens: int, scores: ndarray) -> ndarray:
@@ -149,7 +147,7 @@ def name(self) -> str:
             str: The name of the index

         """
-        return f"{self.NAME_PREFIX}{self.config.index_name}"
+        return self.config.index_name

     @property
     def vocabulary(self) -> dict[str, int]:
@@ -163,14 +161,14 @@ def vocabulary(self) -> dict[str, int]:

     def _save(
         self,
+        path: str,
         repository: AbstractFileRepository,
-        path: str | None = None,
-    ) -> None:
+    ) -> str:
         """Save the index to the given path.

         Args:
+            path (str): Path to save the index to.
             repository (AbstractFileRepository): File repository to save to.
-            path (str | None): Path to save the index to.

         """
         state = {
@@ -179,22 +177,22 @@ def _save(
             "corpus_tokens": self.corpus_tokens,
             "config": dataclasses.asdict(self.config),
         }
-        path = path or self.name
         with repository.open(path, "wb") as f:
             np.savez_compressed(f, state=state)
+        return path

     @classmethod
     def _load(
         cls,
-        name_or_path: str,
-        *,
+        path: str,
         repository: AbstractFileRepository,
+        *,
         mmap: bool = False,
     ) -> BaseSparseIndex:
         """Load an index from the given path or name.

         Args:
-            name_or_path (str): Name or path of the index.
+            path (str): Name or path of the index.
             repository (AbstractFileRepository): File repository to load from.
             mmap (bool): Whether to memory-map the file.
@@ -205,12 +203,12 @@ def _load(
             FileNotFoundError: If the index file is not found.

         """
-        if not repository.exists(name_or_path):
-            msg = f"Index {name_or_path} not found."
+        if not repository.exists(path):
+            msg = f"Index {path} not found."
             raise FileNotFoundError(msg)

         mmap_mode = "r" if mmap else None
-        with repository.open(name_or_path, "rb") as f:
+        with repository.open(path, "rb") as f:
             stored = np.load(f, allow_pickle=True, mmap_mode=mmap_mode)
             state = stored["state"][()]
             retriever = cls.from_config(SparseIndexConfig(**state["config"]))
@@ -474,8 +472,9 @@ def search_many(
         if not queries:
             return []

-        n_workers = n_workers or self.n_workers
+        n_workers = n_workers if n_workers is not None else self.n_workers
         k_search = partial(self.search, top_k=top_k)
+
         results = tqdm(
             map_in_thread(
                 k_search,
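
The sparse `_save`/`_load` pair persists the whole state dict via `np.savez_compressed`, which pickles the dict into a zero-dimensional object array; that is why `_load` passes `allow_pickle=True` and unwraps with `[()]`. A minimal runnable sketch of the round trip:

```python
import io

import numpy as np

state = {"key_mapping": {0: "doc-a"}, "corpus_tokens": {"doc-a": ["hello", "world"]}}

buf = io.BytesIO()  # stand-in for the repository file
np.savez_compressed(buf, state=state)

buf.seek(0)
stored = np.load(buf, allow_pickle=True)
restored = stored["state"][()]  # unwrap the 0-d object array back into a dict
print(restored["key_mapping"])  # {0: 'doc-a'}
```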
2 changes: 0 additions & 2 deletions baguetter/indices/sparse/bm25.py
@@ -27,8 +27,6 @@ class BM25SparseIndex(BaseSparseIndex):
     for indexing and searching documents.
     """

-    NAME_PREFIX: str = "bm25_"
-
     def normalize_scores(self, n_tokens: int, scores: ndarray) -> ndarray:
         """Normalize BM25 scores by the number of tokens in the query.
2 changes: 0 additions & 2 deletions baguetter/indices/sparse/bmx.py
@@ -14,8 +14,6 @@ class BMXSparseIndex(BaseSparseIndex):
     for indexing and searching documents.
     """

-    NAME_PREFIX: str = "bmx_"
-
     def normalize_scores(self, n_tokens: int, scores: ndarray) -> ndarray:
         """Normalize BMX scores by the number of tokens in the query.
12 changes: 6 additions & 6 deletions baguetter/indices/sparse/models/bmx/index.py
@@ -10,8 +10,8 @@
 from collections import defaultdict
 from typing import TYPE_CHECKING

-import numpy as np
 import numba as nb
+import numpy as np
 from numpy import ndarray

 from baguetter.utils.common import tqdm
@@ -218,31 +218,31 @@ def build_index(
     )

     # [doc_count x n_terms]
-    df_matrix = vectorizer.fit_transform(
+    dt_matrix = vectorizer.fit_transform(
         tqdm(
             corpus_tokens,
             total=n_docs,
             disable=not show_progress,
-            desc="Building TDF matrix",
+            desc="Building doc-term matrix",
             dynamic_ncols=True,
             mininterval=0.5,
         ),
     )

     # [n_terms x doc_count]
-    df_matrix = df_matrix.transpose().tocsr()
+    dt_matrix = dt_matrix.transpose().tocsr()

     unique_tokens = vectorizer.get_feature_names_out()
     unique_token_ids = np.arange(len(unique_tokens))

     inverted_index = convert_df_matrix_into_inverted_index(
-        df_matrix=df_matrix,
+        df_matrix=dt_matrix,
         unique_token_ids=unique_token_ids,
         n_docs=n_docs,
         int_dtype=int_dtype,
         show_progress=show_progress,
     )
-    doc_lens = np.squeeze(np.asarray(df_matrix.sum(axis=0), dtype=dtype))
+    doc_lens = np.squeeze(np.asarray(dt_matrix.sum(axis=0), dtype=dtype))
     avg_doc_len = float(np.mean(doc_lens))
     relative_doc_lens = doc_lens / avg_doc_len
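
The `df_matrix` to `dt_matrix` rename matches what the code actually builds: a document-term count matrix, transposed into term-document CSR form before the inverted index and document lengths are derived. A sketch of that pipeline, using scikit-learn's `CountVectorizer` as a stand-in vectorizer (an assumption; the commit does not show the project's own):

```python
import numpy as np
from sklearn.feature_extraction.text import CountVectorizer

corpus_tokens = [["red", "fox"], ["red", "red", "dog"], ["fox"]]

# identity analyzer: the documents are already tokenized
vectorizer = CountVectorizer(analyzer=lambda tokens: tokens)
dt_matrix = vectorizer.fit_transform(corpus_tokens)  # [n_docs x n_terms]
dt_matrix = dt_matrix.transpose().tocsr()            # [n_terms x n_docs]

doc_lens = np.squeeze(np.asarray(dt_matrix.sum(axis=0), dtype=np.float32))
relative_doc_lens = doc_lens / float(np.mean(doc_lens))

print(vectorizer.get_feature_names_out())  # ['dog' 'fox' 'red']
print(doc_lens)                            # [2. 3. 1.]
print(relative_doc_lens)                   # [1.  1.5 0.5]
```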
@@ -239,7 +239,7 @@ def process_many(
         Returns:
             Generator[list[str], None, None] | list[list[str]]: Processed text items as lists of tokens.

         """
-        if n_workers <= 1:
+        if n_workers <= 0:
             processor = map(self._call_steps, items)
         else:
             processor = map_in_process(
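
Lowering the threshold from `<= 1` to `<= 0` changes the meaning of `n_workers=1`: it now gets a real worker process instead of silently running serially, and only `n_workers <= 0` stays in-process. A simplified sketch of the dispatch (hypothetical `tokenize` step; `multiprocessing.Pool` stands in for the project's `map_in_process` helper):

```python
import multiprocessing

def tokenize(text: str) -> list[str]:
    return text.lower().split()

def process_many(items: list[str], n_workers: int = 0) -> list[list[str]]:
    if n_workers <= 0:
        return [tokenize(item) for item in items]  # serial, in-process
    with multiprocessing.Pool(n_workers) as pool:  # one or more worker processes
        return pool.map(tokenize, items)

if __name__ == "__main__":
    print(process_many(["Red Fox", "Lazy Dog"], n_workers=2))
    # [['red', 'fox'], ['lazy', 'dog']]
```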
2 changes: 1 addition & 1 deletion baguetter/utils/file_repository.py
@@ -103,7 +103,7 @@ def __init__(self, path: str | None = None, **kwargs) -> None:
         """
         super().__init__(**kwargs)
-        self._base_path = str(Path(path or settings.base_path).resolve())
+        self._base_path = str(Path(path or f"{settings.base_path}/repository").resolve())
         if not self.isdir(self._base_path):
             if self.exists(self._base_path):
                 msg = f"Path '{self._base_path}' exists but is not a directory."