Merge pull request #25 from provos/persisting-state
Add save() and load() methods to unisim so that embeddings can be reused
ebursztein authored Dec 20, 2024
2 parents 84110d9 + c47b965 commit 4ef52b7
Showing 2 changed files with 128 additions and 0 deletions.
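
For context, here is a minimal round-trip sketch of the new persistence API. It mirrors tests/test_unisim.py below; the DummyEmbedder stand-in, the constructor arguments, and the /tmp path are illustrative assumptions taken from the test, not part of this commit.

from unisim.unisim import UniSim

# Illustrative only: configuration mirrors tests/test_unisim.py
# (DummyEmbedder is the test's mock embedder, not a shipped class).
def make_unisim():
    return UniSim(
        store_data=True,
        index_type="exact",
        return_embeddings=True,
        batch_size=2,
        use_accelerator=False,
        model_id="test",
        embedder=DummyEmbedder(),
    )

unisim = make_unisim()
unisim.add(["test1", "test2"])
unisim.save("/tmp/unisim_state")   # exact index: writes /tmp/unisim_state.embeddings.npy (+ .data)

restored = make_unisim()           # fresh instance with the same configuration
restored.load("/tmp/unisim_state")
results = restored.search(queries=["query1"], similarity_threshold=0.5, k=2)
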
69 changes: 69 additions & 0 deletions tests/test_unisim.py
@@ -0,0 +1,69 @@
import numpy as np
import pytest

from unisim.embedder import Embedder
from unisim.unisim import UniSim


class DummyEmbedder(Embedder):
def __init__(self):
# Skip parent class init to avoid loading model
self.batch_size = 2
self.model_id = "dummy"
self.verbose = 0

@property
def embedding_size(self) -> int:
return 3

def embed(self, inputs):
# Simple mock embedder that returns fixed embeddings
return np.array([[1, 1, 3], [3, 1, 2]], dtype="float32")[: len(inputs)]

def predict(self, data):
# Override predict to avoid using model
return self.embed(data)


index_type = ["exact", "approx"]


def set_up_test_unisim(index_type):
unisim = UniSim(
store_data=True,
index_type=index_type,
return_embeddings=True,
batch_size=2,
use_accelerator=False,
model_id="test",
embedder=DummyEmbedder(),
)
# Add some test data - needs to be two items to match the mock embedder
inputs = ["test1", "test2"]
unisim.add(inputs)
return unisim


@pytest.mark.parametrize("index_type", index_type, ids=index_type)
def test_unisim_save_load(index_type, tmp_path):
# Set up original unisim instance
unisim = set_up_test_unisim(index_type)

# Save state to temporary directory
prefix = str(tmp_path / "unisim_test")
unisim.save(prefix)

# Create new instance and restore from saved files
new_unisim = set_up_test_unisim(index_type)
new_unisim.load(prefix)

# Verify search works correctly after restoration
queries = ["query1"]
results = new_unisim.search(queries=queries, similarity_threshold=0.5, k=2)

# Verify results
assert results.total_matches > 0
result = results.results[0]
assert result.query_data == "query1"
assert len(result.matches) == 2
assert result.matches[0].data in ["test1", "test2"]
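
The test above is parametrized over both index types; assuming pytest is available in the environment, the round-trip tests can be run on their own, for example:

import pytest

# Run only the save/load round-trip tests (both index_type parametrizations).
pytest.main(["tests/test_unisim.py", "-k", "save_load"])
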
59 changes: 59 additions & 0 deletions unisim/unisim.py
@@ -4,6 +4,7 @@
# license that can be found in the LICENSE file or at
# https://opensource.org/licenses/MIT.

import json
import logging
from abc import ABC
from typing import Any, Dict, List, Sequence
@@ -321,3 +322,61 @@ def info(self):
print(f"|-use_accelerator: {self.use_accelerator}")
print(f"|-store index data: {self.store_data}")
print(f"|-return embeddings: {self.return_embeddings}")

def save(self, prefix: str) -> None:
"""Save UniSim state to disk using the given filename prefix.
For exact indexing:
- Saves embeddings to {prefix}.embeddings.npy as numpy array (np.save appends the .npy suffix)
For approx indexing:
- Saves index to {prefix}.usearch using USearch format
If store_data=True, saves data to {prefix}.data as JSON
Args:
prefix: Filename prefix for saved state files
"""
# Save embeddings/index
if self.index_type == IndexerType.exact:
embeddings = np.array(self.indexer.embeddings)
np.save(f"{prefix}.embeddings", embeddings)
elif self.index_type == IndexerType.approx:
self.indexer.index.save(f"{prefix}.usearch")

# Save data if requested
if self.store_data:
with open(f"{prefix}.data", "w", encoding="utf-8") as f:
json.dump(self.indexed_data, f)

def load(self, prefix: str) -> None:
"""Load UniSim state from disk using the given filename prefix.
For exact indexing:
- Loads embeddings from {prefix}.embeddings.npy as numpy array
For approx indexing:
- Loads index from {prefix}.usearch using USearch format
If store_data=True, loads data from {prefix}.data as JSON
Args:
prefix: Filename prefix for saved state files
"""
self.reset_index()

# Load embeddings/index
if self.index_type == IndexerType.exact:
embeddings = np.load(f"{prefix}.embeddings.npy")
for i in range(0, len(embeddings), self.batch_size):
batch = embeddings[i : i + self.batch_size]
self.indexer.add(batch, list(range(i, i + len(batch))))
self.index_size = len(embeddings)
elif self.index_type == IndexerType.approx:
self.indexer.index.load(f"{prefix}.usearch")
self.index_size = self.indexer.index.size

# Load data if requested
if self.store_data:
with open(f"{prefix}.data", "r", encoding="utf-8") as f:
self.indexed_data = json.load(f)
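
As a quick sanity check of what save() leaves on disk, a small sketch that inspects the artifacts directly; the prefix path is hypothetical, and only the exact-index embeddings and the optional data file are shown because they use plain numpy/JSON formats:

import json

import numpy as np

prefix = "/tmp/unisim_state"  # hypothetical prefix previously passed to unisim.save()

# Exact indexing: np.save() appended the .npy suffix to the embeddings file.
embeddings = np.load(f"{prefix}.embeddings.npy")
print(embeddings.shape)  # (number of indexed items, embedding_size)

# With store_data=True, the indexed inputs were dumped as JSON.
with open(f"{prefix}.data", "r", encoding="utf-8") as f:
    indexed_data = json.load(f)
print(indexed_data)  # e.g. ["test1", "test2"]

# Approx indexing instead writes f"{prefix}.usearch" in USearch's binary index format.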
