diff --git a/redisvl/extensions/llmcache/semantic.py b/redisvl/extensions/llmcache/semantic.py
index 07602ac7..3956e7d6 100644
--- a/redisvl/extensions/llmcache/semantic.py
+++ b/redisvl/extensions/llmcache/semantic.py
@@ -234,6 +234,17 @@ def _search_cache(
         )
         return cache_hits
 
+    def _check_vector_dims(self, vector: List[float]):
+        """Checks the size of the provided vector and raises an error if it
+        doesn't match the search index vector dimensions."""
+        schema_vector_dims = self._index.schema.fields[self.vector_field_name].attrs.dims  # type: ignore
+        if schema_vector_dims != len(vector):
+            raise ValueError(
+                "Invalid vector dimensions! "
+                f"Vector has dims defined as {len(vector)}. "
+                f"Vector field has dims defined as {schema_vector_dims}."
+            )
+
     def check(
         self,
         prompt: Optional[str] = None,
@@ -266,6 +277,7 @@ def check(
 
         Raises:
             ValueError: If neither a `prompt` nor a `vector` is specified.
+            ValueError: If `vector` has incorrect dimensions.
             TypeError: If `return_fields` is not a list when provided.
 
         .. code-block:: python
@@ -279,6 +291,7 @@ def check(
 
         # Use provided vector or create from prompt
         vector = vector or self._vectorize_prompt(prompt)
+        self._check_vector_dims(vector)
 
         # Check for cache hits by searching the cache
         cache_hits = self._search_cache(vector, num_results, return_fields)
@@ -307,6 +320,7 @@ def store(
 
         Raises:
             ValueError: If neither prompt nor vector is specified.
+            ValueError: If vector has incorrect dimensions.
             TypeError: If provided metadata is not a dictionary.
 
         .. code-block:: python
@@ -319,6 +333,8 @@ def store(
         """
         # Vectorize prompt if necessary and create cache payload
         vector = vector or self._vectorize_prompt(prompt)
+        self._check_vector_dims(vector)
+
         # Construct semantic cache payload
         id_field = self.entry_id_field_name
         payload = {
diff --git a/tests/integration/test_llmcache.py b/tests/integration/test_llmcache.py
index 373ca8da..2bb107fd 100644
--- a/tests/integration/test_llmcache.py
+++ b/tests/integration/test_llmcache.py
@@ -210,3 +210,31 @@ def test_store_and_check_with_provided_client(cache_with_redis_client, vectorize
 def test_delete(cache_no_cleanup):
     cache_no_cleanup.delete()
     assert not cache_no_cleanup.index.exists()
+
+
+# Test we can only store and check vectors of correct dimensions
+def test_vector_size(cache, vectorizer):
+    prompt = "This is test prompt."
+    response = "This is a test response."
+
+    vector = vectorizer.embed(prompt)
+    cache.store(prompt=prompt, response=response, vector=vector)
+
+    # Test we can query with modified embeddings of correct size
+    vector_2 = [v * 0.99 for v in vector]  # same dimensions
+    check_result = cache.check(vector=vector_2)
+    assert check_result[0]["prompt"] == prompt
+
+    # Test that error is raised when we try to load wrong size vectors
+    with pytest.raises(ValueError):
+        cache.store(prompt=prompt, response=response, vector=vector[0:-1])
+
+    with pytest.raises(ValueError):
+        cache.store(prompt=prompt, response=response, vector=[1, 2, 3])
+
+    # Test that error is raised when we try to query with wrong size vector
+    with pytest.raises(ValueError):
+        cache.check(vector=vector[0:-1])
+
+    with pytest.raises(ValueError):
+        cache.check(vector=[1, 2, 3])