diff --git a/tests/test_unisim.py b/tests/test_unisim.py
new file mode 100644
index 0000000..026046c
--- /dev/null
+++ b/tests/test_unisim.py
@@ -0,0 +1,69 @@
+import numpy as np
+import pytest
+
+from unisim.embedder import Embedder
+from unisim.unisim import UniSim
+
+
+class DummyEmbedder(Embedder):
+    def __init__(self):
+        # Skip parent class init to avoid loading model
+        self.batch_size = 2
+        self.model_id = "dummy"
+        self.verbose = 0
+
+    @property
+    def embedding_size(self) -> int:
+        return 3
+
+    def embed(self, inputs):
+        # Simple mock embedder that returns fixed embeddings
+        return np.array([[1, 1, 3], [3, 1, 2]], dtype="float32")[: len(inputs)]
+
+    def predict(self, data):
+        # Override predict to avoid using model
+        return self.embed(data)
+
+
+index_type = ["exact", "approx"]
+
+
+def set_up_test_unisim(index_type):
+    unisim = UniSim(
+        store_data=True,
+        index_type=index_type,
+        return_embeddings=True,
+        batch_size=2,
+        use_accelerator=False,
+        model_id="test",
+        embedder=DummyEmbedder(),
+    )
+    # Add some test data - needs to be two items to match the mock embedder
+    inputs = ["test1", "test2"]
+    unisim.add(inputs)
+    return unisim
+
+
+@pytest.mark.parametrize("index_type", index_type, ids=index_type)
+def test_unisim_save_load(index_type, tmp_path):
+    # Set up original unisim instance
+    unisim = set_up_test_unisim(index_type)
+
+    # Save state to temporary directory
+    prefix = str(tmp_path / "unisim_test")
+    unisim.save(prefix)
+
+    # Create new instance and restore from saved files
+    new_unisim = set_up_test_unisim(index_type)
+    new_unisim.load(prefix)
+
+    # Verify search works correctly after restoration
+    queries = ["query1"]
+    results = new_unisim.search(queries=queries, similarity_threshold=0.5, k=2)
+
+    # Verify results
+    assert results.total_matches > 0
+    result = results.results[0]
+    assert result.query_data == "query1"
+    assert len(result.matches) == 2
+    assert result.matches[0].data in ["test1", "test2"]
diff --git a/unisim/unisim.py b/unisim/unisim.py
index 8f35dbe..77522d3 100644
--- a/unisim/unisim.py
+++ b/unisim/unisim.py
@@ -4,6 +4,7 @@
 # license that can be found in the LICENSE file or at
 # https://opensource.org/licenses/MIT.
 
+import json
 import logging
 from abc import ABC
 from typing import Any, Dict, List, Sequence
@@ -321,3 +322,61 @@ def info(self):
         print(f"|-use_accelerator: {self.use_accelerator}")
         print(f"|-store index data: {self.store_data}")
         print(f"|-return embeddings: {self.return_embeddings}")
+
+    def save(self, prefix: str) -> None:
+        """Save UniSim state to disk using the given filename prefix.
+
+        For exact indexing:
+        - Saves embeddings to {prefix}.embeddings as numpy array
+
+        For approx indexing:
+        - Saves index to {prefix}.usearch using USearch format
+
+        If store_data=True, saves data to {prefix}.data as JSON
+
+        Args:
+            prefix: Filename prefix for saved state files
+        """
+        # Save embeddings/index
+        if self.index_type == IndexerType.exact:
+            embeddings = np.array(self.indexer.embeddings)
+            np.save(f"{prefix}.embeddings", embeddings)
+        elif self.index_type == IndexerType.approx:
+            self.indexer.index.save(f"{prefix}.usearch")
+
+        # Save data if requested
+        if self.store_data:
+            with open(f"{prefix}.data", "w", encoding="utf-8") as f:
+                json.dump(self.indexed_data, f)
+
+    def load(self, prefix: str) -> None:
+        """Load UniSim state from disk using the given filename prefix.
+
+        For exact indexing:
+        - Loads embeddings from {prefix}.embeddings as numpy array
+
+        For approx indexing:
+        - Loads index from {prefix}.usearch using USearch format
+
+        If store_data=True, loads data from {prefix}.data as JSON
+
+        Args:
+            prefix: Filename prefix for saved state files
+        """
+        self.reset_index()
+
+        # Load embeddings/index
+        if self.index_type == IndexerType.exact:
+            embeddings = np.load(f"{prefix}.embeddings.npy")
+            for i in range(0, len(embeddings), self.batch_size):
+                batch = embeddings[i : i + self.batch_size]
+                self.indexer.add(batch, list(range(i, i + len(batch))))
+            self.index_size = len(embeddings)
+        elif self.index_type == IndexerType.approx:
+            self.indexer.index.load(f"{prefix}.usearch")
+            self.index_size = self.indexer.index.size
+
+        # Load data if requested
+        if self.store_data:
+            with open(f"{prefix}.data", "r", encoding="utf-8") as f:
+                self.indexed_data = json.load(f)
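For reference, below is a minimal usage sketch of the save/load round-trip this diff adds. It is not part of the change itself: it reuses the stand-in `DummyEmbedder` from `tests/test_unisim.py` above, and in real use you would construct `UniSim` with an actual embedder/model. The file names follow the conventions stated in the docstrings (`<prefix>.embeddings.npy` for exact indexing or `<prefix>.usearch` for approx, plus `<prefix>.data` when `store_data=True`).

```python
# Usage sketch for the new save()/load() API (assumes the DummyEmbedder
# defined in tests/test_unisim.py above; swap in a real embedder in practice).
from unisim.unisim import UniSim
from tests.test_unisim import DummyEmbedder


def make_unisim() -> UniSim:
    return UniSim(
        store_data=True,
        index_type="exact",  # or "approx" for the USearch-backed index
        return_embeddings=True,
        batch_size=2,
        use_accelerator=False,
        model_id="test",
        embedder=DummyEmbedder(),
    )


sim = make_unisim()
sim.add(["test1", "test2"])

# Writes unisim_state.embeddings.npy (exact) or unisim_state.usearch (approx),
# plus unisim_state.data because store_data=True.
sim.save("unisim_state")

# Restore into a fresh instance and query it.
restored = make_unisim()
restored.load("unisim_state")
results = restored.search(queries=["test1"], similarity_threshold=0.5, k=1)
print(results.total_matches)
```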