Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

adds check to ensure correct embedding vector dimensions are used #177

Merged
merged 2 commits into from
Jul 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions redisvl/extensions/llmcache/semantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,17 @@ def _search_cache(
)
return cache_hits

def _check_vector_dims(self, vector: List[float]):
    """Validate that ``vector`` has the size the vector field expects.

    The expected size is read from the ``dims`` attribute of the schema
    field named by ``self.vector_field_name`` on ``self._index``.

    Args:
        vector (List[float]): The embedding to validate.

    Raises:
        ValueError: If ``len(vector)`` differs from the schema field's
            ``dims`` value.
    """
    schema_vector_dims = self._index.schema.fields[self.vector_field_name].attrs.dims  # type: ignore
    if schema_vector_dims != len(vector):
        # Build ONE message string. A comma between the fragments would pass
        # several positional args to ValueError, and the error would then
        # render as a tuple instead of a readable sentence.
        raise ValueError(
            "Invalid vector dimensions! "
            f"Vector has dims defined as {len(vector)}. "
            f"Vector field has dims defined as {schema_vector_dims}."
        )

def check(
self,
prompt: Optional[str] = None,
Expand Down Expand Up @@ -266,6 +277,7 @@ def check(

Raises:
ValueError: If neither a `prompt` nor a `vector` is specified.
ValueError: If `vector` has incorrect dimensions.
TypeError: If `return_fields` is not a list when provided.

.. code-block:: python
Expand All @@ -279,6 +291,7 @@ def check(

# Use provided vector or create from prompt
vector = vector or self._vectorize_prompt(prompt)
self._check_vector_dims(vector)

# Check for cache hits by searching the cache
cache_hits = self._search_cache(vector, num_results, return_fields)
Expand Down Expand Up @@ -307,6 +320,7 @@ def store(

Raises:
ValueError: If neither prompt nor vector is specified.
ValueError: If vector has incorrect dimensions.
TypeError: If provided metadata is not a dictionary.

.. code-block:: python
Expand All @@ -319,6 +333,8 @@ def store(
"""
# Vectorize prompt if necessary and create cache payload
vector = vector or self._vectorize_prompt(prompt)
self._check_vector_dims(vector)

# Construct semantic cache payload
id_field = self.entry_id_field_name
payload = {
Expand Down
28 changes: 28 additions & 0 deletions tests/integration/test_llmcache.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,3 +210,31 @@ def test_store_and_check_with_provided_client(cache_with_redis_client, vectorize
def test_delete(cache_no_cleanup):
    """After ``delete()``, the cache's index must report it no longer exists."""
    cache_no_cleanup.delete()
    index_exists = cache_no_cleanup.index.exists()
    assert not index_exists


# Test we can only store and check vectors of correct dimensions
def test_vector_size(cache, vectorizer):
    """Check that ``store`` and ``check`` reject vectors of the wrong length.

    A vector produced by the vectorizer is stored, and a scaled copy of the
    same length must still retrieve the entry. Vectors of any other length
    must raise ``ValueError`` from both ``store`` and ``check``.
    """
    prompt = "This is test prompt."
    response = "This is a test response."

    embedding = vectorizer.embed(prompt)
    cache.store(prompt=prompt, response=response, vector=embedding)

    # Scaling every component keeps the length unchanged, so lookup must work
    scaled = [component * 0.99 for component in embedding]
    hits = cache.check(vector=scaled)
    assert hits[0]["prompt"] == prompt

    # Wrong-size inputs: the embedding minus its last element, and a
    # 3-element vector (assumed not to match the index dims)
    wrong_sized = (embedding[0:-1], [1, 2, 3])

    # Storing a wrong-size vector must be rejected
    for bad_vector in wrong_sized:
        with pytest.raises(ValueError):
            cache.store(prompt=prompt, response=response, vector=bad_vector)

    # Querying with a wrong-size vector must be rejected as well
    for bad_vector in wrong_sized:
        with pytest.raises(ValueError):
            cache.check(vector=bad_vector)
Loading