Standardize extensions redis init #188

Merged: 3 commits, merged Jul 30, 2024
redisvl/extensions/llmcache/semantic.py (17 changes: 8 additions & 9 deletions)

@@ -28,7 +28,7 @@ def __init__(
         vectorizer: Optional[BaseVectorizer] = None,
         redis_client: Optional[Redis] = None,
         redis_url: str = "redis://localhost:6379",
-        connection_args: Dict[str, Any] = {},
+        connection_kwargs: Dict[str, Any] = {},
         **kwargs,
     ):
         """Semantic Cache for Large Language Models.
@@ -43,14 +43,13 @@ def __init__(
                 cache. Defaults to 0.1.
             ttl (Optional[int], optional): The time-to-live for records cached
                 in Redis. Defaults to None.
-            vectorizer (BaseVectorizer, optional): The vectorizer for the cache.
+            vectorizer (Optional[BaseVectorizer], optional): The vectorizer for the cache.
                 Defaults to HFTextVectorizer.
-            redis_client(Redis, optional): A redis client connection instance.
+            redis_client(Optional[Redis], optional): A redis client connection instance.
                 Defaults to None.
-            redis_url (str, optional): The redis url. Defaults to
-                "redis://localhost:6379".
-            connection_args (Dict[str, Any], optional): The connection arguments
-                for the redis client. Defaults to None.
+            redis_url (str, optional): The redis url. Defaults to redis://localhost:6379.
+            connection_kwargs (Dict[str, Any]): The connection arguments
+                for the redis client. Defaults to empty {}.

         Raises:
             TypeError: If an invalid vectorizer is provided.
@@ -96,8 +95,8 @@ def __init__(
         # handle redis connection
         if redis_client:
             self._index.set_client(redis_client)
-        else:
-            self._index.connect(redis_url=redis_url, **connection_args)
+        elif redis_url:
+            self._index.connect(redis_url=redis_url, **connection_kwargs)

         # initialize other components
         self.default_return_fields = [
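
For context on the renamed parameter, here is a minimal usage sketch of the standardized constructor. It assumes only what the diff shows (the connection_kwargs parameter and the client/URL branch); the socket_timeout value is an illustrative connection option, not something specified in this PR.

from redis import Redis

from redisvl.extensions.llmcache import SemanticCache

# Connect from a URL; extra keyword arguments are forwarded to the Redis client
# when the underlying index connects.
cache = SemanticCache(
    redis_url="redis://localhost:6379",
    connection_kwargs={"socket_timeout": 5},  # illustrative option, not from this PR
)

# Or hand over a pre-built client; redis_url and connection_kwargs are then unused.
client = Redis.from_url("redis://localhost:6379")
cache_from_client = SemanticCache(redis_client=client)
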
redisvl/extensions/router/semantic.py (34 changes: 7 additions & 27 deletions)

@@ -86,8 +86,9 @@ def __init__(
         vectorizer: Optional[BaseVectorizer] = None,
         routing_config: Optional[RoutingConfig] = None,
         redis_client: Optional[Redis] = None,
-        redis_url: Optional[str] = None,
+        redis_url: str = "redis://localhost:6379",
         overwrite: bool = False,
+        connection_kwargs: Dict[str, Any] = {},
         **kwargs,
     ):
         """Initialize the SemanticRouter.
@@ -98,9 +99,10 @@
             vectorizer (BaseVectorizer, optional): The vectorizer used to embed route references. Defaults to default HFTextVectorizer.
             routing_config (RoutingConfig, optional): Configuration for routing behavior. Defaults to the default RoutingConfig.
             redis_client (Optional[Redis], optional): Redis client for connection. Defaults to None.
-            redis_url (Optional[str], optional): Redis URL for connection. Defaults to None.
+            redis_url (str, optional): The redis url. Defaults to redis://localhost:6379.
             overwrite (bool, optional): Whether to overwrite existing index. Defaults to False.
-            **kwargs: Additional arguments.
+            connection_kwargs (Dict[str, Any]): The connection arguments
+                for the redis client. Defaults to empty {}.
         """
         # Set vectorizer default
         if vectorizer is None:
@@ -115,12 +117,12 @@ def __init__(
             vectorizer=vectorizer,
             routing_config=routing_config,
         )
-        self._initialize_index(redis_client, redis_url, overwrite)
+        self._initialize_index(redis_client, redis_url, overwrite, **connection_kwargs)

     def _initialize_index(
         self,
         redis_client: Optional[Redis] = None,
-        redis_url: Optional[str] = None,
+        redis_url: str = "redis://localhost:6379",
         overwrite: bool = False,
         **connection_kwargs,
     ):
@@ -132,8 +134,6 @@ def _initialize_index(
             self._index.set_client(redis_client)
         elif redis_url:
             self._index.connect(redis_url=redis_url, **connection_kwargs)
-        else:
-            raise ValueError("Must provide either a redis client or redis url string.")

         existed = self._index.exists()
         self._index.create(overwrite=overwrite)
@@ -479,19 +479,12 @@ def clear(self) -> None:
     def from_dict(
         cls,
         data: Dict[str, Any],
-        redis_client: Optional[Redis] = None,
-        redis_url: Optional[str] = None,
-        overwrite: bool = False,
         **kwargs,
     ) -> "SemanticRouter":
         """Create a SemanticRouter from a dictionary.

         Args:
             data (Dict[str, Any]): The dictionary containing the semantic router data.
-            redis_client (Optional[Redis]): Redis client for connection.
-            redis_url (Optional[str]): Redis URL for connection.
-            overwrite (bool): Whether to overwrite existing index.
-            **kwargs: Additional arguments.

         Returns:
             SemanticRouter: The semantic router instance.
@@ -533,9 +526,6 @@ def from_dict(
             routes=routes,
             vectorizer=vectorizer,
             routing_config=routing_config,
-            redis_client=redis_client,
-            redis_url=redis_url,
-            overwrite=overwrite,
             **kwargs,
         )

@@ -565,19 +555,12 @@ def to_dict(self) -> Dict[str, Any]:
     def from_yaml(
         cls,
         file_path: str,
-        redis_client: Optional[Redis] = None,
-        redis_url: Optional[str] = None,
-        overwrite: bool = False,
         **kwargs,
     ) -> "SemanticRouter":
         """Create a SemanticRouter from a YAML file.

         Args:
             file_path (str): The path to the YAML file.
-            redis_client (Optional[Redis]): Redis client for connection.
-            redis_url (Optional[str]): Redis URL for connection.
-            overwrite (bool): Whether to overwrite existing index.
-            **kwargs: Additional arguments.

         Returns:
             SemanticRouter: The semantic router instance.
@@ -603,9 +586,6 @@ def from_yaml(
             yaml_data = yaml.safe_load(f)
         return cls.from_dict(
             yaml_data,
-            redis_client=redis_client,
-            redis_url=redis_url,
-            overwrite=overwrite,
             **kwargs,
         )
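
Since from_dict and from_yaml no longer take connection parameters explicitly, callers forward them through **kwargs to the constructor. A hedged sketch of the updated entry points, assuming SemanticRouter is importable from redisvl.extensions.router; the file path and the decode_responses option are hypothetical.

from redisvl.extensions.router import SemanticRouter

# Connection settings ride along in **kwargs and reach __init__ unchanged.
router = SemanticRouter.from_yaml(
    "router.yaml",                                  # hypothetical path
    redis_url="redis://localhost:6379",
    connection_kwargs={"decode_responses": True},   # illustrative option
    overwrite=True,
)

Note that the else branch raising ValueError was dropped from _initialize_index because redis_url now has a concrete default, so the router always has a connection path unless the caller explicitly passes an empty redis_url and no client.
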
redisvl/extensions/session_manager/semantic_session.py (15 changes: 10 additions & 5 deletions)

@@ -1,5 +1,5 @@
 from time import time
-from typing import Dict, List, Optional, Union
+from typing import Any, Dict, List, Optional, Union

 from redis import Redis

@@ -27,6 +27,8 @@ def __init__(
         distance_threshold: float = 0.3,
         redis_client: Optional[Redis] = None,
         redis_url: str = "redis://localhost:6379",
+        connection_kwargs: Dict[str, Any] = {},
+        **kwargs,
     ):
         """Initialize session memory with index

@@ -43,12 +45,14 @@ def __init__(
             user_tag (str): Tag to be added to entries to link to a specific user.
             prefix (Optional[str]): Prefix for the keys for this session data.
                 Defaults to None and will be replaced with the index name.
-            vectorizer (Vectorizer): The vectorizer to create embeddings with.
+            vectorizer (Optional[BaseVectorizer]): The vectorizer used to create embeddings.
             distance_threshold (float): The maximum semantic distance to be
                 included in the context. Defaults to 0.3.
             redis_client (Optional[Redis]): A Redis client instance. Defaults to
                 None.
-            redis_url (str): The URL of the Redis instance. Defaults to 'redis://localhost:6379'.
+            redis_url (str, optional): The redis url. Defaults to redis://localhost:6379.
+            connection_kwargs (Dict[str, Any]): The connection arguments
+                for the redis client. Defaults to empty {}.

         The proposed schema will support a single vector embedding constructed
         from either the prompt or response in a single string.
@@ -89,10 +93,11 @@ def __init__(

         self._index = SearchIndex(schema=schema)

+        # handle redis connection
         if redis_client:
             self._index.set_client(redis_client)
-        else:
-            self._index.connect(redis_url=redis_url)
+        elif redis_url:
+            self._index.connect(redis_url=redis_url, **connection_kwargs)

         self._index.create(overwrite=False)
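
The semantic session manager gains the same pass-through. A brief sketch, assuming the class is exported as SemanticSessionManager from the redisvl.extensions.session_manager package (the same package the standard session diff imports BaseSessionManager from); the name and tags are illustrative.

from redisvl.extensions.session_manager import SemanticSessionManager

# Session memory built from a URL; connection_kwargs reach the SearchIndex client.
session = SemanticSessionManager(
    name="chat-session",        # illustrative
    session_tag="session-123",  # illustrative
    user_tag="user-abc",        # illustrative
    redis_url="redis://localhost:6379",
    connection_kwargs={"socket_timeout": 5},  # illustrative option
)
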
redisvl/extensions/session_manager/standard_session.py (19 changes: 14 additions & 5 deletions)

@@ -1,10 +1,11 @@
 import json
 from time import time
-from typing import Dict, List, Optional, Union
+from typing import Any, Dict, List, Optional, Union

 from redis import Redis

 from redisvl.extensions.session_manager import BaseSessionManager
+from redisvl.redis.connection import RedisConnectionFactory


 class StandardSessionManager(BaseSessionManager):
@@ -16,6 +17,8 @@ def __init__(
         user_tag: str,
         redis_client: Optional[Redis] = None,
         redis_url: str = "redis://localhost:6379",
+        connection_kwargs: Dict[str, Any] = {},
+        **kwargs,
     ):
         """Initialize session memory

@@ -31,18 +34,24 @@ def __init__(
             user_tag (str): Tag to be added to entries to link to a specific user.
             redis_client (Optional[Redis]): A Redis client instance. Defaults to
                 None.
-            redis_url (str): The URL of the Redis instance. Defaults to 'redis://localhost:6379'.
+            redis_url (str, optional): The redis url. Defaults to redis://localhost:6379.
+            connection_kwargs (Dict[str, Any]): The connection arguments
+                for the redis client. Defaults to empty {}.

         The proposed schema will support a single combined vector embedding
         constructed from the prompt & response in a single string.

         """
         super().__init__(name, session_tag, user_tag)

+        # handle redis connection
         if redis_client:
             self._client = redis_client
-        else:
-            self._client = Redis.from_url(redis_url)
+        elif redis_url:
+            self._client = RedisConnectionFactory.get_redis_connection(
+                redis_url, **connection_kwargs
+            )
+        RedisConnectionFactory.validate_redis(self._client)

         self.set_scope(session_tag, user_tag)

@@ -51,7 +60,7 @@ def set_scope(
         session_tag: Optional[str] = None,
         user_tag: Optional[str] = None,
     ) -> None:
-        """Set the filter to apply to querries based on the desired scope.
+        """Set the filter to apply to queries based on the desired scope.

         This new scope persists until another call to set_scope is made, or if
         scope is specified in calls to get_recent.
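
Beyond the new parameter, the behavioral change here is that URL-based connections now go through RedisConnectionFactory.get_redis_connection, and every client, whether provided or constructed, is checked with RedisConnectionFactory.validate_redis at construction time. A sketch under the same export assumption as above; the name and tags are illustrative.

from redis import Redis

from redisvl.extensions.session_manager import StandardSessionManager

# A user-provided client is now validated too, so an unreachable or
# incompatible client fails when the session manager is created.
client = Redis.from_url("redis://localhost:6379")
session = StandardSessionManager(
    name="chat-session",        # illustrative
    session_tag="session-123",  # illustrative
    user_tag="user-abc",        # illustrative
    redis_client=client,
)
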
tests/integration/test_llmcache.py (26 changes: 11 additions & 15 deletions)

@@ -2,6 +2,7 @@
 from time import sleep

 import pytest
+from redis.exceptions import ConnectionError

 from redisvl.extensions.llmcache import SemanticCache
 from redisvl.index.index import SearchIndex
@@ -40,19 +41,17 @@ def cache_with_ttl(vectorizer, redis_url):


 @pytest.fixture
-def cache_with_redis_client(vectorizer, client, redis_url):
+def cache_with_redis_client(vectorizer, client):
     cache_instance = SemanticCache(
         vectorizer=vectorizer,
         redis_client=client,
         distance_threshold=0.2,
-        redis_url=redis_url,
     )
     yield cache_instance
     cache_instance.clear()  # Clear cache after each test
     cache_instance._index.delete(True)  # Clean up index


-# # Test handling invalid input for check method
 def test_bad_ttl(cache):
     with pytest.raises(ValueError):
         cache.set_ttl(2.5)
@@ -76,7 +75,6 @@ def test_reset_ttl(cache):
     assert cache.ttl is None


-# Test basic store and check functionality
 def test_store_and_check(cache, vectorizer):
     prompt = "This is a test prompt."
     response = "This is a test response."
@@ -91,7 +89,6 @@ def test_store_and_check(cache, vectorizer):
     assert "metadata" not in check_result[0]


-# Test clearing the cache
 def test_clear(cache, vectorizer):
     prompt = "This is a test prompt."
     response = "This is a test response."
@@ -139,7 +136,6 @@ def test_check_no_match(cache, vectorizer):
     assert len(check_result) == 0


-# Test handling invalid input for check method
 def test_check_invalid_input(cache):
     with pytest.raises(ValueError):
         cache.check()
@@ -148,7 +144,15 @@ def test_check_invalid_input(cache):
         cache.check(prompt="test", return_fields="bad value")


-# Test storing with metadata
+def test_bad_connection_info(vectorizer):
+    with pytest.raises(ConnectionError):
+        SemanticCache(
+            vectorizer=vectorizer,
+            distance_threshold=0.2,
+            redis_url="redis://localhost:6389",
+        )
+
+
 def test_store_with_metadata(cache, vectorizer):
     prompt = "This is another test prompt."
     response = "This is another test response."
@@ -165,7 +169,6 @@ def test_store_with_metadata(cache, vectorizer):
     assert check_result[0]["prompt"] == prompt


-# Test storing with invalid metadata
 def test_store_with_invalid_metadata(cache, vectorizer):
     prompt = "This is another test prompt."
     response = "This is another test response."
@@ -179,7 +182,6 @@ def test_store_with_invalid_metadata(cache, vectorizer):
         cache.store(prompt, response, vector=vector, metadata=metadata)


-# Test setting and getting the distance threshold
 def test_distance_threshold(cache):
     initial_threshold = cache.distance_threshold
     new_threshold = 0.1
@@ -189,14 +191,12 @@ def test_distance_threshold(cache):
     assert cache.distance_threshold != initial_threshold


-# Test out of range distance threshold
 def test_distance_threshold_out_of_range(cache):
     out_of_range_threshold = -1
     with pytest.raises(ValueError):
         cache.set_threshold(out_of_range_threshold)


-# Test storing and retrieving multiple items
 def test_multiple_items(cache, vectorizer):
     prompts_responses = {
         "prompt1": "response1",
@@ -217,12 +217,10 @@ def test_multiple_items(cache, vectorizer):
     assert "metadata" not in check_result[0]


-# Test retrieving underlying SearchIndex for the cache.
 def test_get_index(cache):
     assert isinstance(cache.index, SearchIndex)


-# Test basic functionality with cache created with user-provided Redis client
 def test_store_and_check_with_provided_client(cache_with_redis_client, vectorizer):
     prompt = "This is a test prompt."
     response = "This is a test response."
@@ -237,13 +235,11 @@ def test_store_and_check_with_provided_client(cache_with_redis_client, vectorizer):
     assert "metadata" not in check_result[0]


-# Test deleting the cache
 def test_delete(cache_no_cleanup):
     cache_no_cleanup.delete()
     assert not cache_no_cleanup.index.exists()


-# Test we can only store and check vectors of correct dimensions
 def test_vector_size(cache, vectorizer):
     prompt = "This is test prompt."
     response = "This is a test response."
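
The new test_bad_connection_info pins down the failure mode: constructing a cache against a URL with nothing listening should raise redis.exceptions.ConnectionError immediately. A standalone sketch of the same behavior outside pytest, assuming port 6389 has no Redis instance and that the default vectorizer's dependencies are installed.

from redis.exceptions import ConnectionError

from redisvl.extensions.llmcache import SemanticCache

try:
    # 6389 is assumed to be a closed port, mirroring the new integration test.
    SemanticCache(distance_threshold=0.2, redis_url="redis://localhost:6389")
except ConnectionError:
    print("connection rejected at construction time, as the test expects")
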