From a261101660183746ca6a95733819cc6d90ab5148 Mon Sep 17 00:00:00 2001 From: Julius Lipp Date: Tue, 13 Aug 2024 13:49:05 -0700 Subject: [PATCH 1/2] Enhance loading + add notebooks --- baguetter/indices/dense/base.py | 16 +- baguetter/indices/dense/faiss.py | 19 +- baguetter/indices/dense/usearch.py | 18 +- baguetter/indices/sparse/base.py | 29 ++- baguetter/indices/sparse/bm25.py | 2 - baguetter/indices/sparse/bmx.py | 2 - baguetter/indices/sparse/models/bmx/index.py | 10 +- baguetter/utils/file_repository.py | 2 +- baguetter/utils/persistable.py | 49 +++-- examples/hf_save_idx.ipynb | 195 +++++++++++++++++++ examples/save_idx.ipynb | 188 ++++++++++++++++++ 11 files changed, 458 insertions(+), 72 deletions(-) create mode 100644 examples/hf_save_idx.ipynb create mode 100644 examples/save_idx.ipynb diff --git a/baguetter/indices/dense/base.py b/baguetter/indices/dense/base.py index d441f91..e07cb14 100644 --- a/baguetter/indices/dense/base.py +++ b/baguetter/indices/dense/base.py @@ -12,13 +12,11 @@ import numpy as np -_INDEX_PREFIX = "usearch_index_" -_STATE_PREFIX = "usearch_state_" +_STATE_PREFIX = "state_" +_INDEX_PREFIX = "index_" class BaseDenseIndex(BaseIndex, abc.ABC): - NAME_PREFIX: str = "dense_" - def __init__( self, index_name: str = "new-index", @@ -29,13 +27,13 @@ def __init__( super().__init__() self.index_name: str = index_name - self.n_workers: int = n_workers or max(1, (os.cpu_count() or 1) - 1) + self.n_workers: int = n_workers if n_workers is not None else max(1, (os.cpu_count() or 1) - 1) self._embed_fn: Callable[[list[str], bool], np.ndarray] | None = embed_fn @property def name(self) -> str: """Get the full name of the index.""" - return f"{self.NAME_PREFIX}{self.index_name}" + return self.index_name def _embed(self, query: list[str], *, is_query: bool = False, show_progress: bool = False) -> np.ndarray: """Embed text queries into vectors. @@ -61,17 +59,17 @@ def _embed(self, query: list[str], *, is_query: bool = False, show_progress: boo return self._embed_fn(query, is_query=is_query, show_progress=show_progress) @staticmethod - def build_index_file_paths(name_or_path: str) -> tuple[str, str]: + def build_index_file_paths(path: str) -> tuple[str, str]: """Build the file paths for the index and state files. Args: - name_or_path (str): Path to the index. + path (str): Path to the index. Returns: Tuple[str, str]: File paths for the state and index files. """ - path = Path(name_or_path) + path = Path(path) state_file_name = f"{_STATE_PREFIX}{path.name}" index_file_name = f"{_INDEX_PREFIX}{path.name}" diff --git a/baguetter/indices/dense/faiss.py b/baguetter/indices/dense/faiss.py index c094bed..58cd664 100644 --- a/baguetter/indices/dense/faiss.py +++ b/baguetter/indices/dense/faiss.py @@ -11,7 +11,6 @@ from baguetter.indices.dense.base import BaseDenseIndex from baguetter.indices.dense.config import FaissDenseIndexConfig from baguetter.logger import LOGGER -from baguetter.utils.common import batch_iter if TYPE_CHECKING: from collections.abc import Callable @@ -28,8 +27,6 @@ def _support_nprobe(index: Any) -> bool: class FaissDenseIndex(BaseDenseIndex): """A dense index implementation using Faiss.""" - NAME_PREFIX = "faiss_" - def __init__( self, index_name: str = "new-index", @@ -96,9 +93,9 @@ def from_config(cls, config: FaissDenseIndexConfig) -> FaissDenseIndex: def _save( self, + path: str, repository: AbstractFileRepository, - path: str | None = None, - ) -> None: + ) -> str: """Save the index state and data.""" state = { "key_mapping": self.key_mapping, @@ -116,11 +113,12 @@ def _save( with repository.open(index_file_path, "wb") as file: index = faiss.serialize_index(self.faiss_index) np.savez_compressed(file, index=index) + return path @classmethod def _load( cls, - name_or_path: str, + path: str, *, repository: AbstractFileRepository, mmap: bool = False, @@ -128,7 +126,7 @@ def _load( """Load the index from saved state. Args: - name_or_path (str): Name or path of the index to load. + path (str): Name or path of the index to load. repository (AbstractFileRepository): File repository to use for loading. mmap (bool): Whether to use memory mapping. Defaults to False. @@ -138,7 +136,7 @@ def _load( Raises: FileNotFoundError: If the index files are not found in the repository. """ - state_file_path, index_file_path = BaseDenseIndex.build_index_file_paths(name_or_path) + state_file_path, index_file_path = BaseDenseIndex.build_index_file_paths(path) if not repository.exists(state_file_path): msg = f"Index.state {state_file_path} not found in repository." @@ -233,14 +231,13 @@ def search_many( """ if not queries: return [] - if isinstance(queries[0], str): queries = self._embed(queries, is_query=True, show_progress=show_progress) - if _support_nprobe(self.faiss_index) and n_probe is not None: self.faiss_index.nprobe = n_probe - faiss.omp_set_num_threads(n_workers or self.n_workers) + n_workers = n_workers if n_workers is not None else self.n_workers + faiss.omp_set_num_threads(n_workers) query_vectors = np.array(queries, dtype=np.float32) diff --git a/baguetter/indices/dense/usearch.py b/baguetter/indices/dense/usearch.py index 7c3af32..8f6b9b1 100644 --- a/baguetter/indices/dense/usearch.py +++ b/baguetter/indices/dense/usearch.py @@ -51,8 +51,6 @@ class USearchDenseIndex(BaseDenseIndex, Index): providing methods for adding, searching, and managing dense vector data. """ - NAME_PREFIX = "usearch_" - def __init__( self, index_name: str = "new-index", @@ -142,14 +140,14 @@ def from_config(cls, config: UsearchDenseIndexConfig) -> USearchDenseIndex: def _save( self, + path: str, repository: AbstractFileRepository, - path: str | None = None, ) -> None: """Save the index state and data. Args: + path (str): Path to save the index. If None, uses the index name. repository (AbstractFileRepository): File repository to use for saving. - path (Optional[str]): Path to save the index. If None, uses the index name. """ state = { @@ -173,10 +171,12 @@ def _save( Index.save(self, temp_file.name) file.write(temp_file.read()) + return path + @classmethod def _load( cls, - name_or_path: str, + path: str, *, repository: AbstractFileRepository, mmap: bool = False, @@ -184,7 +184,7 @@ def _load( """Load the index from saved state. Args: - name_or_path (str): Name or path of the index to load. + path (str): Name or path of the index to load. repository (AbstractFileRepository): File repository to use for loading. mmap (bool): Whether to use memory mapping. Defaults to False. @@ -198,7 +198,7 @@ def _load( ( state_file_path, index_file_path, - ) = BaseDenseIndex.build_index_file_paths(name_or_path) + ) = BaseDenseIndex.build_index_file_paths(path) if not repository.exists(state_file_path): msg = f"Index.state {state_file_path} not found in repository." @@ -303,11 +303,13 @@ def search_many( if not isinstance(queries, np.ndarray): queries = np.array(queries) + n_workers = n_workers if n_workers is not None else self.n_workers + results = Index.search( self, vectors=queries, count=top_k, - threads=n_workers or self.n_workers, + threads=n_workers, log=show_progress, radius=radius, exact=self.config.exact_search if exact_search is None else exact_search, diff --git a/baguetter/indices/sparse/base.py b/baguetter/indices/sparse/base.py index 9c12925..f5aee23 100644 --- a/baguetter/indices/sparse/base.py +++ b/baguetter/indices/sparse/base.py @@ -25,8 +25,6 @@ class BaseSparseIndex(BaseIndex, abc.ABC): """Base class for sparse indices. This class should not be used directly.""" - NAME_PREFIX: str = "sparse_" - def __init__( self, index_name: str = "new-index", @@ -91,7 +89,7 @@ def __init__( self.index: object | None = None self.key_mapping: dict[int, str] = {} self.corpus_tokens: dict[str, list[str]] = {} - self.n_workers: int = n_workers or max(1, (os.cpu_count() or 1) - 1) + self.n_workers: int = n_workers if n_workers is not None else max(1, (os.cpu_count() or 1) - 1) @abc.abstractmethod def normalize_scores(self, n_tokens: int, scores: ndarray) -> ndarray: @@ -149,7 +147,7 @@ def name(self) -> str: str: The name of the index """ - return f"{self.NAME_PREFIX}{self.config.index_name}" + return self.config.index_name @property def vocabulary(self) -> dict[str, int]: @@ -163,14 +161,14 @@ def vocabulary(self) -> dict[str, int]: def _save( self, + path: str, repository: AbstractFileRepository, - path: str | None = None, - ) -> None: + ) -> str: """Save the index to the given path. Args: + path (str): Path to save the index to. repository (AbstractFileRepository): File repository to save to. - path (str | None): Path to save the index to. """ state = { @@ -179,22 +177,22 @@ def _save( "corpus_tokens": self.corpus_tokens, "config": dataclasses.asdict(self.config), } - path = path or self.name with repository.open(path, "wb") as f: np.savez_compressed(f, state=state) + return path @classmethod def _load( cls, - name_or_path: str, - *, + path: str, repository: AbstractFileRepository, + *, mmap: bool = False, ) -> BaseSparseIndex: """Load an index from the given path or name. Args: - name_or_path (str): Name or path of the index. + path (str): Name or path of the index. repository (AbstractFileRepository): File repository to load from. mmap (bool): Whether to memory-map the file. @@ -205,12 +203,12 @@ def _load( FileNotFoundError: If the index file is not found. """ - if not repository.exists(name_or_path): - msg = f"Index {name_or_path} not found." + if not repository.exists(path): + msg = f"Index {path} not found." raise FileNotFoundError(msg) mmap_mode = "r" if mmap else None - with repository.open(name_or_path, "rb") as f: + with repository.open(path, "rb") as f: stored = np.load(f, allow_pickle=True, mmap_mode=mmap_mode) state = stored["state"][()] retriever = cls.from_config(SparseIndexConfig(**state["config"])) @@ -474,8 +472,9 @@ def search_many( if not queries: return [] - n_workers = n_workers or self.n_workers + n_workers = n_workers if n_workers is not None else self.n_workers k_search = partial(self.search, top_k=top_k) + results = tqdm( map_in_thread( k_search, diff --git a/baguetter/indices/sparse/bm25.py b/baguetter/indices/sparse/bm25.py index 40a7262..1145460 100644 --- a/baguetter/indices/sparse/bm25.py +++ b/baguetter/indices/sparse/bm25.py @@ -27,8 +27,6 @@ class BM25SparseIndex(BaseSparseIndex): for indexing and searching documents. """ - NAME_PREFIX: str = "bm25_" - def normalize_scores(self, n_tokens: int, scores: ndarray) -> ndarray: """Normalize BM25 scores by the number of tokens in the query. diff --git a/baguetter/indices/sparse/bmx.py b/baguetter/indices/sparse/bmx.py index 362c32f..9cc6c33 100644 --- a/baguetter/indices/sparse/bmx.py +++ b/baguetter/indices/sparse/bmx.py @@ -14,8 +14,6 @@ class BMXSparseIndex(BaseSparseIndex): for indexing and searching documents. """ - NAME_PREFIX: str = "bmx_" - def normalize_scores(self, n_tokens: int, scores: ndarray) -> ndarray: """Normalize BMX scores by the number of tokens in the query. diff --git a/baguetter/indices/sparse/models/bmx/index.py b/baguetter/indices/sparse/models/bmx/index.py index bceedb0..0522421 100644 --- a/baguetter/indices/sparse/models/bmx/index.py +++ b/baguetter/indices/sparse/models/bmx/index.py @@ -218,31 +218,31 @@ def build_index( ) # [doc_count x n_terms] - df_matrix = vectorizer.fit_transform( + dt_matrix = vectorizer.fit_transform( tqdm( corpus_tokens, total=n_docs, disable=not show_progress, - desc="Building TDF matrix", + desc="Building doc-term matrix", dynamic_ncols=True, mininterval=0.5, ), ) # [n_terms x doc_count] - df_matrix = df_matrix.transpose().tocsr() + dt_matrix = dt_matrix.transpose().tocsr() unique_tokens = vectorizer.get_feature_names_out() unique_token_ids = np.arange(len(unique_tokens)) inverted_index = convert_df_matrix_into_inverted_index( - df_matrix=df_matrix, + df_matrix=dt_matrix, unique_token_ids=unique_token_ids, n_docs=n_docs, int_dtype=int_dtype, show_progress=show_progress, ) - doc_lens = np.squeeze(np.asarray(df_matrix.sum(axis=0), dtype=dtype)) + doc_lens = np.squeeze(np.asarray(dt_matrix.sum(axis=0), dtype=dtype)) avg_doc_len = float(np.mean(doc_lens)) relative_doc_lens = doc_lens / avg_doc_len diff --git a/baguetter/utils/file_repository.py b/baguetter/utils/file_repository.py index 208aad7..228ace9 100644 --- a/baguetter/utils/file_repository.py +++ b/baguetter/utils/file_repository.py @@ -103,7 +103,7 @@ def __init__(self, path: str | None = None, **kwargs) -> None: """ super().__init__(**kwargs) - self._base_path = str(Path(path or settings.base_path).resolve()) + self._base_path = str(Path(path or f"{settings.base_path}/repository").resolve()) if not self.isdir(self._base_path): if self.exists(self._base_path): msg = f"Path '{self._base_path}' exists but is not a directory." diff --git a/baguetter/utils/persistable.py b/baguetter/utils/persistable.py index 38133fd..db8b41c 100644 --- a/baguetter/utils/persistable.py +++ b/baguetter/utils/persistable.py @@ -17,7 +17,7 @@ class Persistable(ABC): @abstractmethod def _load( cls, - name_or_path: str, + path: str, repository: AbstractFileRepository, *, allow_pickle: bool = True, @@ -26,7 +26,7 @@ def _load( """Load an object from storage. Args: - name_or_path (str): Name or path of the object to load. + path (str): Path of the object to load. repository (AbstractFileRepository): File repository to load from. allow_pickle (bool, optional): Whether to allow loading pickled objects. Defaults to True. mmap (bool, optional): Whether to memory-map the file. Defaults to False. @@ -37,34 +37,41 @@ def _load( """ @abstractmethod - def _save(self, repository: AbstractFileRepository, path: str | None) -> None: + def _save(self, path: str, repository: AbstractFileRepository) -> str: """Save the object to storage. Args: + path (str): Path to save the object to. repository (AbstractFileRepository): File repository to save to. - path (str | None): Path to save the object to. + + Returns: + str: Path to the saved object. """ - def save(self, path: str | None = None) -> None: + def save(self, path: str) -> str: """Save the object to a local file repository. Args: path (str, optional): Path to save the object to. + Returns: + str: Path to the saved object. + """ repository = LocalFileRepository() - if path: - directory = path.rsplit("/", 1)[0] - repository.mkdirs(directory, exist_ok=True) - self._save(repository=repository, path=path) + directory = path.rsplit("/", 1) + if len(directory) > 1: + repository.mkdirs(directory[0], exist_ok=True) + path = self._save(path=path, repository=repository) + return repository.info(path)["name"] @classmethod - def load(cls, name_or_path: str, *, mmap: bool = False) -> Any: + def load(cls, path: str, *, mmap: bool = False) -> Any: """Load an object from a local file repository. Args: - name_or_path (str): Name or path of the object to load. + path (str): Path of the object to load. mmap (bool, optional): Whether to memory-map the file. Defaults to False. Returns: @@ -73,7 +80,7 @@ def load(cls, name_or_path: str, *, mmap: bool = False) -> Any: """ repository = LocalFileRepository() return cls._load( - name_or_path=name_or_path, + path=path, repository=repository, mmap=mmap, ) @@ -90,7 +97,7 @@ class HuggingFacePersistable(Persistable, ABC): def load_from_hub( cls, repo_id: str, - name_or_path: str, + path: str, *, repo_type: str | None = None, token: str | None = None, @@ -101,7 +108,7 @@ def load_from_hub( Args: repo_id (str): Repository ID. - name_or_path (str): Name or path of the object. + path (str): Path of the object. repo_type (str, optional): Repository type. Defaults to None. token (str, optional): Hugging Face API token. Defaults to None. mmap (bool, optional): Whether to memory-map the file. Defaults to False. @@ -119,28 +126,31 @@ def load_from_hub( **kwargs, ) - return cls._load(name_or_path, repository, mmap=mmap) + return cls._load(path=path, repository=repository, mmap=mmap) def push_to_hub( self, repo_id: str, + path: str, *, - path_in_repo: str | None = None, private: bool = True, repo_type: str | None = None, token: str | None = None, **kwargs, - ) -> None: + ) -> str: """Save an object to the Hugging Face Hub. Args: repo_id (str): Repository ID. - path_in_repo (str, optional): Custom path within the repository. Defaults to None. + path (str): Path of the object. private (bool, optional): Whether the repository is private. Defaults to True. repo_type (str, optional): Repository type. Defaults to None. token (str, optional): Hugging Face API token. Defaults to None. **kwargs: Additional arguments for HuggingFaceFileRepository. + Returns: + str: Path to the saved object. + """ repository = HuggingFaceFileRepository( repo_id=repo_id, @@ -151,4 +161,5 @@ def push_to_hub( **kwargs, ) - self._save(repository, path_in_repo) + path = self._save(path=path, repository=repository) + return repository.info(path)["name"] diff --git a/examples/hf_save_idx.ipynb b/examples/hf_save_idx.ipynb new file mode 100644 index 0000000..ccf4ea0 --- /dev/null +++ b/examples/hf_save_idx.ipynb @@ -0,0 +1,195 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Uploading indices to Hugging Face" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "from baguetter.indices import BMXSparseIndex\n", + "from baguetter.evaluation import HFDataset" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 1. Create index and load dataset" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "index = BMXSparseIndex()\n", + "\n", + "ds = HFDataset(\"mteb/scidocs\", \"corpus\")\n", + "doc_ids, docs = ds.get_corpus()\n", + "_, queries = ds.get_queries()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 2. Add documents to index" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Tokenization: 100%|██████████| 25657/25657 [00:04<00:00, 5282.20it/s]\n", + "Building doc-term matrix: 100%|██████████| 25657/25657 [00:00<00:00, 62623.85it/s]\n", + "Building inverted index: 100%|██████████| 61627/61627 [00:04<00:00, 14888.04it/s]\n" + ] + }, + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "index.add_many(doc_ids, docs, show_progress=True)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 3. Save index to Hugging Face" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "No files have been modified since last commit. Skipping to prevent empty commit.\n" + ] + }, + { + "data": { + "text/plain": [ + "'datasets/mixedbread-ai/baguetter/bmx_scidocs'" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "index.push_to_hub(\"mixedbread-ai/baguetter\", \"bmx_scidocs\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 4. Load index from Hugging Face" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "idx = index.load_from_hub(\"mixedbread-ai/baguetter\", \"bmx_scidocs\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 5. Use index" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "SearchResults(keys=['86e87db2dab958f1bd5877dc7d5b8105d6e31e46', 'cd31ecb3b58d1ec0d8b6e196bddb71dd6a921b6d', '2a43d3905699927ace64e880fe9ba8a730e14be1', 'eef39364df06eb9933d2fc41a0f13eea17113c58', '19c90b3c0c0d94e8235731a057cc6377c46482ee', '768b18d745639fcfb157fe16cbd957ca60ebfc2e', 'f2ab0a2aa4177dd267c3c6cc37c7ad0e33c2cdbf', 'd504a72e40ecee5c2e721629e7368a959b18c681', 'd1d120bc98e536dd33e37c876aaba57e584d252e', 'e2890afe42e64b910609e7554130d6a81427e02a', '829033fd070c6ed30d28a21187e0def25a3e809f', '0948365ef39ef153e61e9569ade541cf881c7c2a', '4a4cea4421ff0be7bcc06e92179cd2d5f1102ff8', '745b88eb437eb59e2a58fe378d287702a6b0d985', '1f009366a901c403a8aad65c94ec2fecf3428081', '26880494f79ae1e35ffee7f055cb0ad5693060c2', '432143ab67c05f42c918c4ed6fd9412d26e659be', '53f3edfeb22de82c7a4b4a02209d296526eee38c', 'a16dc6af67ef9746068c63a56a580cb3b2a83e9c', '2eafdb47aa9b5b510f7fcb113b22e6ab7c79d143', '0a202f1dfc6991a6a204eaa5e6b46d6223a4d98a', '6307f94aefdc7268c27e3af8fc04f090bc1b18bb', 'e90dd4a2750df4d52918a610ba9fb2b013153508', '1e7efea26cfbbcd2905d63451e77a02f1031ea12', '3978e9f794174c7a2700b20193c071a7b1532b22', '1e8f46aeed1a96554a2d759d7ca194e1f9c22de1', '41d205fd36883f506bccf56db442bac92a854ec3', '2766913aabb151107b28279645b915a3aa86c816', '120f1a81fd4abd089f47a335d0771b4162e851e8', '91de962e115bcf65eaf8579471a818ba8c5b0ea6', '7a58abc92dbe41c9e5b3c7b0a358ab9096880f25', '2fdee22266d58ae4e891711208106ca46c8e2778', 'd5bed3840cbb7e29a89847c21b697609607d12d2', '943d17f36d320ad9fcc3ae82c78914c0111cef1d', 'd79dd895912a36670b3477645f361e2fdd73185b', '02599a02d46ea2f1c00e14cac2a76dcb156df8ee', '3c55c334d34b611a565683ea42a06d4e1f01db47', '9283e274236af381cfb20e7dda79f249936b02ab', '3fd46ca896d023df8c8af2b3951730d8c38defdd', '0760b3baa196cd449d2c81604883815a6fc73b6a', '576296d739f61fcd4f4433f09a91f350a0c9598d', 'fda1e13a2eaeaa0b4434833d3ee0eb8e79b0ba94', 'ea5b1f3c719cd4ddd4c78a0da1501e36d87d9782', '4ac639f092b870ebe72e4b366afb90f3073d6223', 'ab93fe25985409d99b15a49ae9d8987561749a32', 'f3381a72a5ed288d54a93d92a85e96f7ba2ab36c', '62c3daf1899f6841f7092961193f062cc4fe1103', 'bfb88f34328be56dc7917a59c2aee7a8c22795e1', '166f42f66c5e6dd959548acfb97dc77a36013639', 'ad0323b075146e7d7a3ef3aacb9892201da69492', '84ca84dad742749a827291e103cde8185cea1bcf', '9c8e7655dd233df3d5b3249f416ec992cebe7a10', '3fe910b1360a77f50f73c2e82e654b6028072826', '00b202871ec41b8049e8393e463660525ecb61b5', '55ca165fa6091973674b12ea8fa3f1a3a1e50a6d', '51bb6450e617986d1bd8566878f7693ffd03132d', '9981e27f01960526ea68227c7f8120e0c3ffe87f', '5092a67406d823a6f6fd3dac555b9d022ad20bdf', '2ec3a0d6c71face777138f7cdc2e44d6762d23f5', '77ccf604ca460ac65d2bd14792c901879c4a0153', 'dca4eaacddb18ad44786c008b73296831502d27c', '35875600a30f89ea133ac06afeefc8cacec9fb3d', 'c1e9c4c5637c2d67863ee53eef3aa2df20a6e56d', 'b323c4d8f284dd27b9bc8c8be5bee3cd30e2c8ca', 'a60791316f5d749d9248c755112653bd527db2fe', '03f98c175b4230960ac347b1100fbfc10c100d0c', '6b6fa87688f1e0ddb676a9ce5d18a7185f98d0c5', '61736617ae1eb5483a3b8b182815ab6c59bf4939', '9b1b350dc58def7b7d7b147b779aa0b534b5b335', '35e846afa7e247ed7ff5acc2448d4e766d9183dc', 'eee9d92794872fd3eecf38f86cd26d605f3eede7', '0b584f82ec87f068416c553b8c94778eecf9f7d6', '58ca5ac14af2765ce1d25c3a82d6f9312437ded0', '3c9ac1876a69b4e35b5f0690ea817de6ac26295d', 'd641503d4551dc3a3f9eabefd27045996ed16887', '3baddc440617ce202fd190b32b1d73f1bb14561d', 'b6b53d8c8790d668e799802444e31e90ac177479', '2cf9714cb82974c85c99a5f3bfe5cd79de52bd69', '26da3190bbe181dac7a0ced5cef7745358a5346c', '2d93e7af2e38d9479114006704b836533026279f', '1a3470626b24ccd510047925f80d21affde3c3b8', '8a7acaf6469c06ae5876d92f013184db5897bb13', 'dc53c638f58bf3982c5a6ed82002d56c955763c2', '222d8b2803f9cedf0da0b454c061c0bb46384722', '450d6ef1acfe802ae0cfeca71a8b355d103b2865', 'af777f8b1c694e353a57d81c3c1b4620e2ae61b1', '2d6d056ca33bb20e7bec33b49093cc4a907bf1a0', 'dd0b5dd2d15ebc6a5658c75ec102b64e359c674d', 'c1f8a3a1b4df9b7856d4fbcfa91ef2752bcc7070', '35a9c2fad935a2389a7b6e3a53d88ea476db611e', '071a6cd442706e424ea09bc8852eaa2e901c72f3', 'b743dafa3dcb8924244c14f0a719cde5e93d9155', '831d2fa6af688ef2d6b754bb315ef6cb20085763', '6a640438a4e50fa31943462eeca716413891a773', 'e9b7367c63ba970cc9a0360116b160dbe1eb1bb4', '755050d838b9b27d715c4bf1e8317294011fa5fc', '8663945d5090fe409e42af217ac19f77f69eee28', '1b9de2d1e74fbe49bf852fa495f63c31bb038a31', '595d0fe1c259c02069075d8c687210211908c3ed', 'abc7254b751b124ff98cbf522526cf2ce5376e95'], scores=array([25.882957 , 23.048092 , 21.487345 , 15.326306 , 15.008339 ,\n", + " 14.661501 , 14.5524025, 14.49628 , 14.228694 , 13.837463 ,\n", + " 13.796093 , 13.57072 , 13.536688 , 13.438 , 13.33793 ,\n", + " 13.2009535, 13.157523 , 13.150376 , 12.943943 , 12.894042 ,\n", + " 12.818537 , 12.801729 , 12.74225 , 12.502438 , 12.294656 ,\n", + " 12.225283 , 12.219966 , 12.075517 , 11.996811 , 11.959464 ,\n", + " 11.944234 , 11.833931 , 11.832078 , 11.817529 , 11.794199 ,\n", + " 11.759341 , 11.749658 , 11.721568 , 11.675982 , 11.661502 ,\n", + " 11.644913 , 11.623637 , 11.617721 , 11.418661 , 11.39984 ,\n", + " 11.393229 , 11.373691 , 11.372035 , 11.354653 , 11.342207 ,\n", + " 11.271814 , 11.204544 , 11.131927 , 11.116472 , 11.089201 ,\n", + " 11.083519 , 11.082279 , 11.05196 , 10.975056 , 10.940313 ,\n", + " 10.93409 , 10.8945465, 10.879114 , 10.864438 , 10.847451 ,\n", + " 10.832269 , 10.772312 , 10.77147 , 10.71236 , 10.706525 ,\n", + " 10.698383 , 10.68005 , 10.6775875, 10.643851 , 10.63102 ,\n", + " 10.624665 , 10.601247 , 10.597593 , 10.581311 , 10.578881 ,\n", + " 10.537812 , 10.488211 , 10.486176 , 10.47819 , 10.457276 ,\n", + " 10.447503 , 10.400571 , 10.382628 , 10.371334 , 10.359387 ,\n", + " 10.354765 , 10.342985 , 10.341277 , 10.3036785, 10.256121 ,\n", + " 10.221265 , 10.217081 , 10.216501 , 10.209621 , 10.163714 ],\n", + " dtype=float32), normalized=False)" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "idx.search(queries[0])" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.14" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/examples/save_idx.ipynb b/examples/save_idx.ipynb new file mode 100644 index 0000000..fb94d32 --- /dev/null +++ b/examples/save_idx.ipynb @@ -0,0 +1,188 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Storing index locally" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "from baguetter.indices import BMXSparseIndex\n", + "from baguetter.evaluation import HFDataset" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 1. Create index and load dataset" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "index = BMXSparseIndex()\n", + "\n", + "ds = HFDataset(\"mteb/scidocs\", \"corpus\")\n", + "doc_ids, docs = ds.get_corpus()\n", + "_, queries = ds.get_queries()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 2. Add documents to index" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Tokenization: 100%|██████████| 25657/25657 [00:04<00:00, 5193.93it/s]\n", + "Building doc-term matrix: 100%|██████████| 25657/25657 [00:00<00:00, 54218.01it/s]\n", + "Building inverted index: 100%|██████████| 61627/61627 [00:04<00:00, 14276.33it/s]\n" + ] + }, + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "index.add_many(doc_ids, docs, show_progress=True)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 3. Save index" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'/Users/juliuslipp/.cache/baguetter/repository/super-cool-idx'" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "index.save(\"super-cool-idx\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 4. Load index" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "idx = index.load(\"super-cool-idx\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 5. Use index" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "SearchResults(keys=['86e87db2dab958f1bd5877dc7d5b8105d6e31e46', 'cd31ecb3b58d1ec0d8b6e196bddb71dd6a921b6d', '2a43d3905699927ace64e880fe9ba8a730e14be1', 'eef39364df06eb9933d2fc41a0f13eea17113c58', '19c90b3c0c0d94e8235731a057cc6377c46482ee', '768b18d745639fcfb157fe16cbd957ca60ebfc2e', 'f2ab0a2aa4177dd267c3c6cc37c7ad0e33c2cdbf', 'd504a72e40ecee5c2e721629e7368a959b18c681', 'd1d120bc98e536dd33e37c876aaba57e584d252e', 'e2890afe42e64b910609e7554130d6a81427e02a', '829033fd070c6ed30d28a21187e0def25a3e809f', '0948365ef39ef153e61e9569ade541cf881c7c2a', '4a4cea4421ff0be7bcc06e92179cd2d5f1102ff8', '745b88eb437eb59e2a58fe378d287702a6b0d985', '1f009366a901c403a8aad65c94ec2fecf3428081', '26880494f79ae1e35ffee7f055cb0ad5693060c2', '432143ab67c05f42c918c4ed6fd9412d26e659be', '53f3edfeb22de82c7a4b4a02209d296526eee38c', 'a16dc6af67ef9746068c63a56a580cb3b2a83e9c', '2eafdb47aa9b5b510f7fcb113b22e6ab7c79d143', '0a202f1dfc6991a6a204eaa5e6b46d6223a4d98a', '6307f94aefdc7268c27e3af8fc04f090bc1b18bb', 'e90dd4a2750df4d52918a610ba9fb2b013153508', '1e7efea26cfbbcd2905d63451e77a02f1031ea12', '3978e9f794174c7a2700b20193c071a7b1532b22', '1e8f46aeed1a96554a2d759d7ca194e1f9c22de1', '41d205fd36883f506bccf56db442bac92a854ec3', '2766913aabb151107b28279645b915a3aa86c816', '120f1a81fd4abd089f47a335d0771b4162e851e8', '91de962e115bcf65eaf8579471a818ba8c5b0ea6', '7a58abc92dbe41c9e5b3c7b0a358ab9096880f25', '2fdee22266d58ae4e891711208106ca46c8e2778', 'd5bed3840cbb7e29a89847c21b697609607d12d2', '943d17f36d320ad9fcc3ae82c78914c0111cef1d', 'd79dd895912a36670b3477645f361e2fdd73185b', '02599a02d46ea2f1c00e14cac2a76dcb156df8ee', '3c55c334d34b611a565683ea42a06d4e1f01db47', '9283e274236af381cfb20e7dda79f249936b02ab', '3fd46ca896d023df8c8af2b3951730d8c38defdd', '0760b3baa196cd449d2c81604883815a6fc73b6a', '576296d739f61fcd4f4433f09a91f350a0c9598d', 'fda1e13a2eaeaa0b4434833d3ee0eb8e79b0ba94', 'ea5b1f3c719cd4ddd4c78a0da1501e36d87d9782', '4ac639f092b870ebe72e4b366afb90f3073d6223', 'ab93fe25985409d99b15a49ae9d8987561749a32', 'f3381a72a5ed288d54a93d92a85e96f7ba2ab36c', '62c3daf1899f6841f7092961193f062cc4fe1103', 'bfb88f34328be56dc7917a59c2aee7a8c22795e1', '166f42f66c5e6dd959548acfb97dc77a36013639', 'ad0323b075146e7d7a3ef3aacb9892201da69492', '84ca84dad742749a827291e103cde8185cea1bcf', '9c8e7655dd233df3d5b3249f416ec992cebe7a10', '3fe910b1360a77f50f73c2e82e654b6028072826', '00b202871ec41b8049e8393e463660525ecb61b5', '55ca165fa6091973674b12ea8fa3f1a3a1e50a6d', '51bb6450e617986d1bd8566878f7693ffd03132d', '9981e27f01960526ea68227c7f8120e0c3ffe87f', '5092a67406d823a6f6fd3dac555b9d022ad20bdf', '2ec3a0d6c71face777138f7cdc2e44d6762d23f5', '77ccf604ca460ac65d2bd14792c901879c4a0153', 'dca4eaacddb18ad44786c008b73296831502d27c', '35875600a30f89ea133ac06afeefc8cacec9fb3d', 'c1e9c4c5637c2d67863ee53eef3aa2df20a6e56d', 'b323c4d8f284dd27b9bc8c8be5bee3cd30e2c8ca', 'a60791316f5d749d9248c755112653bd527db2fe', '03f98c175b4230960ac347b1100fbfc10c100d0c', '6b6fa87688f1e0ddb676a9ce5d18a7185f98d0c5', '61736617ae1eb5483a3b8b182815ab6c59bf4939', '9b1b350dc58def7b7d7b147b779aa0b534b5b335', '35e846afa7e247ed7ff5acc2448d4e766d9183dc', 'eee9d92794872fd3eecf38f86cd26d605f3eede7', '0b584f82ec87f068416c553b8c94778eecf9f7d6', '58ca5ac14af2765ce1d25c3a82d6f9312437ded0', '3c9ac1876a69b4e35b5f0690ea817de6ac26295d', 'd641503d4551dc3a3f9eabefd27045996ed16887', '3baddc440617ce202fd190b32b1d73f1bb14561d', 'b6b53d8c8790d668e799802444e31e90ac177479', '2cf9714cb82974c85c99a5f3bfe5cd79de52bd69', '26da3190bbe181dac7a0ced5cef7745358a5346c', '2d93e7af2e38d9479114006704b836533026279f', '1a3470626b24ccd510047925f80d21affde3c3b8', '8a7acaf6469c06ae5876d92f013184db5897bb13', 'dc53c638f58bf3982c5a6ed82002d56c955763c2', '222d8b2803f9cedf0da0b454c061c0bb46384722', '450d6ef1acfe802ae0cfeca71a8b355d103b2865', 'af777f8b1c694e353a57d81c3c1b4620e2ae61b1', '2d6d056ca33bb20e7bec33b49093cc4a907bf1a0', 'dd0b5dd2d15ebc6a5658c75ec102b64e359c674d', 'c1f8a3a1b4df9b7856d4fbcfa91ef2752bcc7070', '35a9c2fad935a2389a7b6e3a53d88ea476db611e', '071a6cd442706e424ea09bc8852eaa2e901c72f3', 'b743dafa3dcb8924244c14f0a719cde5e93d9155', '831d2fa6af688ef2d6b754bb315ef6cb20085763', '6a640438a4e50fa31943462eeca716413891a773', 'e9b7367c63ba970cc9a0360116b160dbe1eb1bb4', '755050d838b9b27d715c4bf1e8317294011fa5fc', '8663945d5090fe409e42af217ac19f77f69eee28', '1b9de2d1e74fbe49bf852fa495f63c31bb038a31', '595d0fe1c259c02069075d8c687210211908c3ed', 'abc7254b751b124ff98cbf522526cf2ce5376e95'], scores=array([25.882957 , 23.048092 , 21.487345 , 15.326306 , 15.008339 ,\n", + " 14.661501 , 14.5524025, 14.49628 , 14.228694 , 13.837463 ,\n", + " 13.796093 , 13.57072 , 13.536688 , 13.438 , 13.33793 ,\n", + " 13.2009535, 13.157523 , 13.150376 , 12.943943 , 12.894042 ,\n", + " 12.818537 , 12.801729 , 12.74225 , 12.502438 , 12.294656 ,\n", + " 12.225283 , 12.219966 , 12.075517 , 11.996811 , 11.959464 ,\n", + " 11.944234 , 11.833931 , 11.832078 , 11.817529 , 11.794199 ,\n", + " 11.759341 , 11.749658 , 11.721568 , 11.675982 , 11.661502 ,\n", + " 11.644913 , 11.623637 , 11.617721 , 11.418661 , 11.39984 ,\n", + " 11.393229 , 11.373691 , 11.372035 , 11.354653 , 11.342207 ,\n", + " 11.271814 , 11.204544 , 11.131927 , 11.116472 , 11.089201 ,\n", + " 11.083519 , 11.082279 , 11.05196 , 10.975056 , 10.940313 ,\n", + " 10.93409 , 10.8945465, 10.879114 , 10.864438 , 10.847451 ,\n", + " 10.832269 , 10.772312 , 10.77147 , 10.71236 , 10.706525 ,\n", + " 10.698383 , 10.68005 , 10.6775875, 10.643851 , 10.63102 ,\n", + " 10.624665 , 10.601247 , 10.597593 , 10.581311 , 10.578881 ,\n", + " 10.537812 , 10.488211 , 10.486176 , 10.47819 , 10.457276 ,\n", + " 10.447503 , 10.400571 , 10.382628 , 10.371334 , 10.359387 ,\n", + " 10.354765 , 10.342985 , 10.341277 , 10.3036785, 10.256121 ,\n", + " 10.221265 , 10.217081 , 10.216501 , 10.209621 , 10.163714 ],\n", + " dtype=float32), normalized=False)" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "idx.search(queries[0])" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.14" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} From 297ed6828aaf5dafe069c90a37425b61cdae8613 Mon Sep 17 00:00:00 2001 From: Julius Lipp Date: Tue, 13 Aug 2024 13:50:07 -0700 Subject: [PATCH 2/2] Enhance loading + add notebooks --- baguetter/indices/dense/base.py | 4 +--- baguetter/indices/dense/faiss.py | 3 +-- baguetter/indices/dense/usearch.py | 4 ++-- baguetter/indices/sparse/models/bmx/index.py | 2 +- .../indices/sparse/text_preprocessor/text_processor.py | 2 +- baguetter/utils/persistable.py | 2 +- examples/hf_save_idx.ipynb | 8 ++++---- examples/save_idx.ipynb | 8 ++++---- tests/indices/dense/usearch_test.py | 9 ++++----- tests/indices/sparse/base_test.py | 4 ++-- 10 files changed, 21 insertions(+), 25 deletions(-) diff --git a/baguetter/indices/dense/base.py b/baguetter/indices/dense/base.py index e07cb14..4a0b3c3 100644 --- a/baguetter/indices/dense/base.py +++ b/baguetter/indices/dense/base.py @@ -12,7 +12,6 @@ import numpy as np -_STATE_PREFIX = "state_" _INDEX_PREFIX = "index_" @@ -70,11 +69,10 @@ def build_index_file_paths(path: str) -> tuple[str, str]: """ path = Path(path) - state_file_name = f"{_STATE_PREFIX}{path.name}" index_file_name = f"{_INDEX_PREFIX}{path.name}" dir_name = path.parent if path.parent != Path() else Path() - state_file_path = dir_name / state_file_name + state_file_path = dir_name / path.name index_file_path = dir_name / index_file_name return str(state_file_path), str(index_file_path) diff --git a/baguetter/indices/dense/faiss.py b/baguetter/indices/dense/faiss.py index 58cd664..e339ad7 100644 --- a/baguetter/indices/dense/faiss.py +++ b/baguetter/indices/dense/faiss.py @@ -5,7 +5,6 @@ import faiss import numpy as np -from tqdm import tqdm from baguetter.indices.base import SearchResults from baguetter.indices.dense.base import BaseDenseIndex @@ -113,7 +112,7 @@ def _save( with repository.open(index_file_path, "wb") as file: index = faiss.serialize_index(self.faiss_index) np.savez_compressed(file, index=index) - return path + return state_file_path @classmethod def _load( diff --git a/baguetter/indices/dense/usearch.py b/baguetter/indices/dense/usearch.py index 8f6b9b1..9ea6370 100644 --- a/baguetter/indices/dense/usearch.py +++ b/baguetter/indices/dense/usearch.py @@ -142,7 +142,7 @@ def _save( self, path: str, repository: AbstractFileRepository, - ) -> None: + ) -> str: """Save the index state and data. Args: @@ -171,7 +171,7 @@ def _save( Index.save(self, temp_file.name) file.write(temp_file.read()) - return path + return state_file_path @classmethod def _load( diff --git a/baguetter/indices/sparse/models/bmx/index.py b/baguetter/indices/sparse/models/bmx/index.py index 0522421..fd49a32 100644 --- a/baguetter/indices/sparse/models/bmx/index.py +++ b/baguetter/indices/sparse/models/bmx/index.py @@ -10,8 +10,8 @@ from collections import defaultdict from typing import TYPE_CHECKING -import numpy as np import numba as nb +import numpy as np from numpy import ndarray from baguetter.utils.common import tqdm diff --git a/baguetter/indices/sparse/text_preprocessor/text_processor.py b/baguetter/indices/sparse/text_preprocessor/text_processor.py index b5e55c2..d9add44 100644 --- a/baguetter/indices/sparse/text_preprocessor/text_processor.py +++ b/baguetter/indices/sparse/text_preprocessor/text_processor.py @@ -239,7 +239,7 @@ def process_many( Returns: Generator[list[str], None, None] | list[list[str]]: Processed text items as lists of tokens. """ - if n_workers <= 1: + if n_workers <= 0: processor = map(self._call_steps, items) else: processor = map_in_process( diff --git a/baguetter/utils/persistable.py b/baguetter/utils/persistable.py index db8b41c..d6baf55 100644 --- a/baguetter/utils/persistable.py +++ b/baguetter/utils/persistable.py @@ -53,7 +53,7 @@ def save(self, path: str) -> str: """Save the object to a local file repository. Args: - path (str, optional): Path to save the object to. + path (str): Path to save the object to. Returns: str: Path to the saved object. diff --git a/examples/hf_save_idx.ipynb b/examples/hf_save_idx.ipynb index ccf4ea0..66e42b6 100644 --- a/examples/hf_save_idx.ipynb +++ b/examples/hf_save_idx.ipynb @@ -53,15 +53,15 @@ "name": "stderr", "output_type": "stream", "text": [ - "Tokenization: 100%|██████████| 25657/25657 [00:04<00:00, 5282.20it/s]\n", - "Building doc-term matrix: 100%|██████████| 25657/25657 [00:00<00:00, 62623.85it/s]\n", - "Building inverted index: 100%|██████████| 61627/61627 [00:04<00:00, 14888.04it/s]\n" + "Tokenization: 100%|██████████| 25657/25657 [00:08<00:00, 2902.54it/s]\n", + "Building doc-term matrix: 100%|██████████| 25657/25657 [00:00<00:00, 47972.01it/s]\n", + "Building inverted index: 100%|██████████| 61627/61627 [00:04<00:00, 14397.29it/s]\n" ] }, { "data": { "text/plain": [ - "" + "" ] }, "execution_count": 3, diff --git a/examples/save_idx.ipynb b/examples/save_idx.ipynb index fb94d32..469f6bd 100644 --- a/examples/save_idx.ipynb +++ b/examples/save_idx.ipynb @@ -53,15 +53,15 @@ "name": "stderr", "output_type": "stream", "text": [ - "Tokenization: 100%|██████████| 25657/25657 [00:04<00:00, 5193.93it/s]\n", - "Building doc-term matrix: 100%|██████████| 25657/25657 [00:00<00:00, 54218.01it/s]\n", - "Building inverted index: 100%|██████████| 61627/61627 [00:04<00:00, 14276.33it/s]\n" + "Tokenization: 100%|██████████| 25657/25657 [00:08<00:00, 3161.79it/s]\n", + "Building doc-term matrix: 100%|██████████| 25657/25657 [00:00<00:00, 61635.90it/s]\n", + "Building inverted index: 100%|██████████| 61627/61627 [00:04<00:00, 14483.47it/s]\n" ] }, { "data": { "text/plain": [ - "" + "" ] }, "execution_count": 3, diff --git a/tests/indices/dense/usearch_test.py b/tests/indices/dense/usearch_test.py index 53360c0..b544e97 100644 --- a/tests/indices/dense/usearch_test.py +++ b/tests/indices/dense/usearch_test.py @@ -5,7 +5,7 @@ import numpy as np import pytest -from baguetter.indices.dense.base import _INDEX_PREFIX, _STATE_PREFIX +from baguetter.indices.dense.base import _INDEX_PREFIX from baguetter.indices.dense.config import DenseIndexConfig from baguetter.indices.dense.usearch import USearchDenseIndex from baguetter.utils.file_repository import LocalFileRepository @@ -20,7 +20,6 @@ def sample_data(): def test_usearch_index_creation(): index = USearchDenseIndex(embedding_dim=128) - assert index.name.startswith("usearch_") assert index.config.embedding_dim == 128 @@ -98,10 +97,10 @@ def test_usearch_save_load(sample_data): repository = LocalFileRepository(tmp_dir) # Save the index - index._save(repository, save_path) + index._save(path=save_path, repository=repository) # Load the index - loaded_index = USearchDenseIndex._load(save_path, repository=repository) + loaded_index = USearchDenseIndex._load(path=save_path, repository=repository) assert loaded_index.config.embedding_dim == 128 assert loaded_index.key_counter == 10 @@ -117,7 +116,7 @@ def test_usearch_save_load(sample_data): ) # Verify file names - assert repository.exists(f"{_STATE_PREFIX}{save_path}") + assert repository.exists(save_path) assert repository.exists(f"{_INDEX_PREFIX}{save_path}") diff --git a/tests/indices/sparse/base_test.py b/tests/indices/sparse/base_test.py index 8e2e944..1566694 100644 --- a/tests/indices/sparse/base_test.py +++ b/tests/indices/sparse/base_test.py @@ -57,7 +57,7 @@ def mock_docs(): def test_constructor(mock_sparse_index): - assert mock_sparse_index.name == "sparse_test-index" + assert mock_sparse_index.name == "test-index" assert isinstance(mock_sparse_index._pre_processor, TextPreprocessor) assert isinstance(mock_sparse_index.config, SparseIndexConfig) assert mock_sparse_index.index is None @@ -67,7 +67,7 @@ def test_constructor(mock_sparse_index): def test_constructor_from_config(mock_sparse_index_from_config): - assert mock_sparse_index_from_config.name == "sparse_test-index" + assert mock_sparse_index_from_config.name == "test-index" assert isinstance(mock_sparse_index_from_config._pre_processor, TextPreprocessor) assert isinstance(mock_sparse_index_from_config.config, SparseIndexConfig) assert mock_sparse_index_from_config.index is None