Skip to content

Commit

Permalink
Check to ensure correct embedding vector dimensions are used (#177)
Browse files Browse the repository at this point in the history
Currently our semantic cache allows for specifying the vector in calls
to store() and check(), but if the vector dimension does not match the
schema dimensions this fails silently. This PR adds a check to verify
correct vector dimensions and raises an error if they do not match.
  • Loading branch information
justin-cechmanek authored Jul 3, 2024
1 parent ccc039f commit aa05797
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 0 deletions.
16 changes: 16 additions & 0 deletions redisvl/extensions/llmcache/semantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,17 @@ def _search_cache(
)
return cache_hits

def _check_vector_dims(self, vector: List[float]):
    """Validate that a caller-provided embedding vector matches the index schema.

    Args:
        vector (List[float]): The embedding vector to validate.

    Raises:
        ValueError: If the vector's length differs from the dimensions
            declared for the vector field in the search index schema.
    """
    # Dimensions declared in the index schema for the cache's vector field.
    schema_vector_dims = self._index.schema.fields[self.vector_field_name].attrs.dims  # type: ignore
    if schema_vector_dims != len(vector):
        # NOTE: adjacent string literals are concatenated into ONE message.
        # (Separating them with commas would make ValueError carry a tuple
        # of args, rendering str(e) as a tuple repr instead of a sentence.)
        raise ValueError(
            "Invalid vector dimensions! "
            f"Vector has dims defined as {len(vector)}, "
            f"vector field has dims defined as {schema_vector_dims}."
        )

def check(
self,
prompt: Optional[str] = None,
Expand Down Expand Up @@ -266,6 +277,7 @@ def check(
Raises:
ValueError: If neither a `prompt` nor a `vector` is specified.
ValueError: if 'vector' has incorrect dimensions.
TypeError: If `return_fields` is not a list when provided.
.. code-block:: python
Expand All @@ -279,6 +291,7 @@ def check(

# Use provided vector or create from prompt
vector = vector or self._vectorize_prompt(prompt)
self._check_vector_dims(vector)

# Check for cache hits by searching the cache
cache_hits = self._search_cache(vector, num_results, return_fields)
Expand Down Expand Up @@ -307,6 +320,7 @@ def store(
Raises:
ValueError: If neither prompt nor vector is specified.
ValueError: if vector has incorrect dimensions.
TypeError: If provided metadata is not a dictionary.
.. code-block:: python
Expand All @@ -319,6 +333,8 @@ def store(
"""
# Vectorize prompt if necessary and create cache payload
vector = vector or self._vectorize_prompt(prompt)
self._check_vector_dims(vector)

# Construct semantic cache payload
id_field = self.entry_id_field_name
payload = {
Expand Down
28 changes: 28 additions & 0 deletions tests/integration/test_llmcache.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,3 +210,31 @@ def test_store_and_check_with_provided_client(cache_with_redis_client, vectorize
def test_delete(cache_no_cleanup):
cache_no_cleanup.delete()
assert not cache_no_cleanup.index.exists()


# Test we can only store and check vectors of correct dimensions
# Vectors passed to store() and check() must match the index's embedding
# dimensions; anything else should raise ValueError instead of failing silently.
def test_vector_size(cache, vectorizer):
    prompt = "This is test prompt."
    response = "This is a test response."

    embedding = vectorizer.embed(prompt)
    cache.store(prompt=prompt, response=response, vector=embedding)

    # A slightly perturbed vector of the SAME size is still accepted and
    # retrieves the stored entry.
    perturbed = [value * 0.99 for value in embedding]
    hits = cache.check(vector=perturbed)
    assert hits[0]["prompt"] == prompt

    # Wrong-sized vectors are rejected by store() ...
    for bad_vector in (embedding[0:-1], [1, 2, 3]):
        with pytest.raises(ValueError):
            cache.store(prompt=prompt, response=response, vector=bad_vector)

    # ... and by check() alike.
    for bad_vector in (embedding[0:-1], [1, 2, 3]):
        with pytest.raises(ValueError):
            cache.check(vector=bad_vector)

0 comments on commit aa05797

Please sign in to comment.