diff --git a/redisvl/extensions/router/schema.py b/redisvl/extensions/router/schema.py index c01f9254..11b88dc6 100644 --- a/redisvl/extensions/router/schema.py +++ b/redisvl/extensions/router/schema.py @@ -3,6 +3,8 @@ from pydantic.v1 import BaseModel, Field, validator +from redisvl.schema import IndexInfo, IndexSchema + class Route(BaseModel): """Model representing a routing path with associated metadata and thresholds.""" @@ -80,3 +82,36 @@ def distance_threshold_must_be_valid(cls, v): if v <= 0 or v > 1: raise ValueError("distance_threshold must be between 0 and 1") return v + + +class SemanticRouterIndexSchema(IndexSchema): + """Customized index schema for SemanticRouter.""" + + @classmethod + def from_params(cls, name: str, vector_dims: int) -> "SemanticRouterIndexSchema": + """Create an index schema based on router name and vector dimensions. + + Args: + name (str): The name of the index. + vector_dims (int): The dimensions of the vectors. + + Returns: + SemanticRouterIndexSchema: The constructed index schema. + """ + return cls( + index=IndexInfo(name=name, prefix=name), + fields=[ # type: ignore + {"name": "route_name", "type": "tag"}, + {"name": "reference", "type": "text"}, + { + "name": "vector", + "type": "vector", + "attrs": { + "algorithm": "flat", + "dims": vector_dims, + "distance_metric": "cosine", + "datatype": "float32", + }, + }, + ], + ) diff --git a/redisvl/extensions/router/semantic.py b/redisvl/extensions/router/semantic.py index 1c34202a..bab69578 100644 --- a/redisvl/extensions/router/semantic.py +++ b/redisvl/extensions/router/semantic.py @@ -13,11 +13,11 @@ Route, RouteMatch, RoutingConfig, + SemanticRouterIndexSchema, ) from redisvl.index import SearchIndex from redisvl.query import RangeQuery from redisvl.redis.utils import convert_bytes, hashify, make_dict -from redisvl.schema import IndexInfo, IndexSchema from redisvl.utils.log import get_logger from redisvl.utils.utils import model_to_dict from redisvl.utils.vectorize import ( @@ -29,39 +29,6 @@ logger = get_logger(__name__) -class SemanticRouterIndexSchema(IndexSchema): - """Customized index schema for SemanticRouter.""" - - @classmethod - def from_params(cls, name: str, vector_dims: int) -> "SemanticRouterIndexSchema": - """Create an index schema based on router name and vector dimensions. - - Args: - name (str): The name of the index. - vector_dims (int): The dimensions of the vectors. - - Returns: - SemanticRouterIndexSchema: The constructed index schema. - """ - return cls( - index=IndexInfo(name=name, prefix=name), - fields=[ # type: ignore - {"name": "route_name", "type": "tag"}, - {"name": "reference", "type": "text"}, - { - "name": "vector", - "type": "vector", - "attrs": { - "algorithm": "flat", - "dims": vector_dims, - "distance_metric": "cosine", - "datatype": "float32", - }, - }, - ], - ) - - class SemanticRouter(BaseModel): """Semantic Router for managing and querying route vectors.""" diff --git a/redisvl/extensions/session_manager/base_session.py b/redisvl/extensions/session_manager/base_session.py index 97c4a9f1..c2cddb2b 100644 --- a/redisvl/extensions/session_manager/base_session.py +++ b/redisvl/extensions/session_manager/base_session.py @@ -1,9 +1,7 @@ from typing import Any, Dict, List, Optional, Union -from uuid import uuid4 -from redis import Redis - -from redisvl.query.filter import FilterExpression +from redisvl.extensions.session_manager.schema import ChatMessage +from redisvl.utils.utils import create_uuid class BaseSessionManager: @@ -32,7 +30,7 @@ def __init__( session. Defaults to instance uuid. """ self._name = name - self._session_tag = session_tag or uuid4().hex + self._session_tag = session_tag or create_uuid() def clear(self) -> None: """Clears the chat session history.""" @@ -85,14 +83,13 @@ def get_recent( raise NotImplementedError def _format_context( - self, hits: List[Dict[str, Any]], as_text: bool + self, messages: List[Dict[str, Any]], as_text: bool ) -> Union[List[str], List[Dict[str, str]]]: """Extracts the prompt and response fields from the Redis hashes and formats them as either flat dictionaries or strings. Args: - hits (List): The hashes containing prompt & response pairs from - recent conversation history. + messages (List[Dict[str, Any]]): The messages from the session index. as_text (bool): Whether to return the conversation as a single string, or list of alternating prompts and responses. @@ -100,29 +97,25 @@ def _format_context( Union[str, List[str]]: A single string transcription of the session or list of strings if as_text is false. """ - if as_text: - text_statements = [] - for hit in hits: - text_statements.append(hit[self.content_field_name]) - return text_statements - else: - statements = [] - for hit in hits: - statements.append( - { - self.role_field_name: hit[self.role_field_name], - self.content_field_name: hit[self.content_field_name], - } - ) - if ( - hasattr(hit, self.tool_field_name) - or isinstance(hit, dict) - and self.tool_field_name in hit - ): - statements[-1].update( - {self.tool_field_name: hit[self.tool_field_name]} - ) - return statements + context = [] + + for message in messages: + + chat_message = ChatMessage(**message) + + if as_text: + context.append(chat_message.content) + else: + chat_message_dict = { + self.role_field_name: chat_message.role, + self.content_field_name: chat_message.content, + } + if chat_message.tool_call_id is not None: + chat_message_dict[self.tool_field_name] = chat_message.tool_call_id + + context.append(chat_message_dict) # type: ignore + + return context def store( self, prompt: str, response: str, session_tag: Optional[str] = None diff --git a/redisvl/extensions/session_manager/schema.py b/redisvl/extensions/session_manager/schema.py new file mode 100644 index 00000000..4394d4cb --- /dev/null +++ b/redisvl/extensions/session_manager/schema.py @@ -0,0 +1,94 @@ +from typing import Dict, List, Optional + +from pydantic.v1 import BaseModel, Field, root_validator + +from redisvl.redis.utils import array_to_buffer +from redisvl.schema import IndexSchema +from redisvl.utils.utils import current_timestamp + + +class ChatMessage(BaseModel): + """A single chat message exchanged between a user and an LLM.""" + + _id: Optional[str] = Field(default=None) + """A unique identifier for the message.""" + role: str # TODO -- do we enumify this? + """The role of the message sender (e.g., 'user' or 'llm').""" + content: str + """The content of the message.""" + session_tag: str + """Tag associated with the current session.""" + timestamp: float = Field(default_factory=current_timestamp) + """The time the message was sent, in UTC, rounded to milliseconds.""" + tool_call_id: Optional[str] = Field(default=None) + """An optional identifier for a tool call associated with the message.""" + vector_field: Optional[List[float]] = Field(default=None) + """The vector representation of the message content.""" + + class Config: + arbitrary_types_allowed = True + + @root_validator(pre=False) + @classmethod + def generate_id(cls, values): + if "_id" not in values: + values["_id"] = f'{values["session_tag"]}:{values["timestamp"]}' + return values + + def to_dict(self) -> Dict: + data = self.dict() + + # handle optional fields + if data["vector_field"] is not None: + data["vector_field"] = array_to_buffer(data["vector_field"]) + else: + del data["vector_field"] + + if self.tool_call_id is None: + del data["tool_call_id"] + + return data + + +class StandardSessionIndexSchema(IndexSchema): + + @classmethod + def from_params(cls, name: str, prefix: str): + + return cls( + index={"name": name, "prefix": prefix}, # type: ignore + fields=[ # type: ignore + {"name": "role", "type": "tag"}, + {"name": "content", "type": "text"}, + {"name": "tool_call_id", "type": "tag"}, + {"name": "timestamp", "type": "numeric"}, + {"name": "session_tag", "type": "tag"}, + ], + ) + + +class SemanticSessionIndexSchema(IndexSchema): + + @classmethod + def from_params(cls, name: str, prefix: str, vectorizer_dims: int): + + return cls( + index={"name": name, "prefix": prefix}, # type: ignore + fields=[ # type: ignore + {"name": "role", "type": "tag"}, + {"name": "content", "type": "text"}, + {"name": "tool_call_id", "type": "tag"}, + {"name": "timestamp", "type": "numeric"}, + {"name": "session_tag", "type": "tag"}, + { + "name": "vector_field", + "type": "vector", + "attrs": { + "dims": vectorizer_dims, + "datatype": "float32", + "distance_metric": "cosine", + "algorithm": "flat", + }, + }, + ], + ) diff --git a/redisvl/extensions/session_manager/semantic_session.py b/redisvl/extensions/session_manager/semantic_session.py index ccfcf2e7..a0903b99 100644 --- a/redisvl/extensions/session_manager/semantic_session.py +++ b/redisvl/extensions/session_manager/semantic_session.py @@ -1,44 +1,19 @@ -from time import time from typing import Any, Dict, List, Optional, Union from redis import Redis from redisvl.extensions.session_manager import BaseSessionManager +from redisvl.extensions.session_manager.schema import ( + ChatMessage, + SemanticSessionIndexSchema, +) from redisvl.index import SearchIndex from redisvl.query import FilterQuery, RangeQuery -from redisvl.query.filter import FilterExpression, Tag -from redisvl.redis.utils import array_to_buffer -from redisvl.schema.schema import IndexSchema +from redisvl.query.filter import Tag +from redisvl.utils.utils import validate_vector_dims from redisvl.utils.vectorize import BaseVectorizer, HFTextVectorizer -class SemanticSessionIndexSchema(IndexSchema): - - @classmethod - def from_params(cls, name: str, prefix: str, vectorizer_dims: int): - - return cls( - index={"name": name, "prefix": prefix}, # type: ignore - fields=[ # type: ignore - {"name": "role", "type": "text"}, - {"name": "content", "type": "text"}, - {"name": "tool_call_id", "type": "text"}, - {"name": "timestamp", "type": "numeric"}, - {"name": "session_tag", "type": "tag"}, - { - "name": "vector_field", - "type": "vector", - "attrs": { - "dims": vectorizer_dims, - "datatype": "float32", - "distance_metric": "cosine", - "algorithm": "flat", - }, - }, - ], - ) - - class SemanticSessionManager(BaseSessionManager): vector_field_name: str = "vector_field" @@ -148,9 +123,11 @@ def messages(self) -> Union[List[str], List[Dict[str, str]]]: sorted_query = query.query sorted_query.sort_by(self.timestamp_field_name, asc=True) - hits = self._index.search(sorted_query, query.params).docs + messages = [ + doc.__dict__ for doc in self._index.search(sorted_query, query.params).docs + ] - return self._format_context(hits, as_text=False) + return self._format_context(messages, as_text=False) def get_relevant( self, @@ -198,7 +175,6 @@ def get_relevant( self.content_field_name, self.timestamp_field_name, self.tool_field_name, - self.vector_field_name, ] session_filter = ( @@ -216,14 +192,14 @@ def get_relevant( return_score=True, filter_expression=session_filter, ) - hits = self._index.query(query) + messages = self._index.query(query) # if we don't find semantic matches fallback to returning recent context - if not hits and fall_back: + if not messages and fall_back: return self.get_recent(as_text=as_text, top_k=top_k, raw=raw) if raw: - return hits - return self._format_context(hits, as_text) + return messages + return self._format_context(messages, as_text) def get_recent( self, @@ -276,11 +252,13 @@ def get_recent( sorted_query = query.query sorted_query.sort_by(self.timestamp_field_name, asc=False) - hits = self._index.search(sorted_query, query.params).docs + messages = [ + doc.__dict__ for doc in self._index.search(sorted_query, query.params).docs + ] if raw: - return hits[::-1] - return self._format_context(hits[::-1], as_text) + return messages[::-1] + return self._format_context(messages[::-1], as_text) @property def distance_threshold(self): @@ -322,26 +300,31 @@ def add_messages( session_tag (Optional[str]): Tag to be added to entries to link to a specific session. Defaults to instance uuid. """ - sep = self._index.key_separator session_tag = session_tag or self._session_tag - payloads = [] + chat_messages: List[Dict[str, Any]] = [] + for message in messages: - vector = self._vectorizer.embed(message[self.content_field_name]) - timestamp = time() - id_field = sep.join([self._session_tag, str(timestamp)]) - payload = { - self.id_field_name: id_field, - self.role_field_name: message[self.role_field_name], - self.content_field_name: message[self.content_field_name], - self.timestamp_field_name: timestamp, - self.vector_field_name: array_to_buffer(vector), - self.session_field_name: session_tag, - } + + content_vector = self._vectorizer.embed(message[self.content_field_name]) + + validate_vector_dims( + len(content_vector), + self._index.schema.fields[self.vector_field_name].attrs.dims, # type: ignore + ) + + chat_message = ChatMessage( + role=message[self.role_field_name], + content=message[self.content_field_name], + session_tag=session_tag, + vector_field=content_vector, + ) if self.tool_field_name in message: - payload.update({self.tool_field_name: message[self.tool_field_name]}) - payloads.append(payload) - self._index.load(data=payloads, id_field=self.id_field_name) + chat_message.tool_call_id = message[self.tool_field_name] + + chat_messages.append(chat_message.to_dict()) + + self._index.load(data=chat_messages, id_field=self.id_field_name) def add_message( self, message: Dict[str, str], session_tag: Optional[str] = None diff --git a/redisvl/extensions/session_manager/standard_session.py b/redisvl/extensions/session_manager/standard_session.py index 0a1a7b25..37d18a3e 100644 --- a/redisvl/extensions/session_manager/standard_session.py +++ b/redisvl/extensions/session_manager/standard_session.py @@ -1,30 +1,15 @@ -from time import time from typing import Any, Dict, List, Optional, Union from redis import Redis from redisvl.extensions.session_manager import BaseSessionManager +from redisvl.extensions.session_manager.schema import ( + ChatMessage, + StandardSessionIndexSchema, +) from redisvl.index import SearchIndex from redisvl.query import FilterQuery from redisvl.query.filter import Tag -from redisvl.schema.schema import IndexSchema - - -class StandardSessionIndexSchema(IndexSchema): - - @classmethod - def from_params(cls, name: str, prefix: str): - - return cls( - index={"name": name, "prefix": prefix}, # type: ignore - fields=[ # type: ignore - {"name": "role", "type": "text"}, - {"name": "content", "type": "text"}, - {"name": "tool_call_id", "type": "text"}, - {"name": "timestamp", "type": "numeric"}, - {"name": "session_tag", "type": "tag"}, - ], - ) class StandardSessionManager(BaseSessionManager): @@ -74,7 +59,7 @@ def __init__( if redis_client: self._index.set_client(redis_client) else: - self._index.connect(redis_url=redis_url) + self._index.connect(redis_url=redis_url, **connection_kwargs) self._index.create(overwrite=False) @@ -121,9 +106,11 @@ def messages(self) -> Union[List[str], List[Dict[str, str]]]: sorted_query = query.query sorted_query.sort_by(self.timestamp_field_name, asc=True) - hits = self._index.search(sorted_query, query.params).docs + messages = [ + doc.__dict__ for doc in self._index.search(sorted_query, query.params).docs + ] - return self._format_context(hits, as_text=False) + return self._format_context(messages, as_text=False) def get_recent( self, @@ -176,11 +163,13 @@ def get_recent( sorted_query = query.query sorted_query.sort_by(self.timestamp_field_name, asc=False) - hits = self._index.search(sorted_query, query.params).docs + messages = [ + doc.__dict__ for doc in self._index.search(sorted_query, query.params).docs + ] if raw: - return hits[::-1] - return self._format_context(hits[::-1], as_text) + return messages[::-1] + return self._format_context(messages[::-1], as_text) def store( self, prompt: str, response: str, session_tag: Optional[str] = None @@ -215,25 +204,23 @@ def add_messages( session_tag (Optional[str]): Tag to be added to entries to link to a specific session. Defaults to instance uuid. """ - sep = self._index.key_separator session_tag = session_tag or self._session_tag - payloads = [] + chat_messages: List[Dict[str, Any]] = [] + for message in messages: - timestamp = time() - id_field = sep.join([self._session_tag, str(timestamp)]) - payload = { - self.id_field_name: id_field, - self.role_field_name: message[self.role_field_name], - self.content_field_name: message[self.content_field_name], - self.timestamp_field_name: timestamp, - self.session_field_name: session_tag, - } + + chat_message = ChatMessage( + role=message[self.role_field_name], + content=message[self.content_field_name], + session_tag=session_tag, + ) if self.tool_field_name in message: - payload.update({self.tool_field_name: message[self.tool_field_name]}) + chat_message.tool_call_id = message[self.tool_field_name] + + chat_messages.append(chat_message.to_dict()) - payloads.append(payload) - self._index.load(data=payloads, id_field=self.id_field_name) + self._index.load(data=chat_messages, id_field=self.id_field_name) def add_message( self, message: Dict[str, str], session_tag: Optional[str] = None diff --git a/redisvl/utils/utils.py b/redisvl/utils/utils.py index 96c5250e..5f5cc882 100644 --- a/redisvl/utils/utils.py +++ b/redisvl/utils/utils.py @@ -1,9 +1,22 @@ +import json from enum import Enum +from time import time from typing import Any, Dict +from uuid import uuid4 from pydantic.v1 import BaseModel +def create_uuid() -> str: + """Generate a unique indentifier to group related Redis documents.""" + return str(uuid4()) + + +def current_timestamp() -> float: + """Generate a unix epoch timestamp to assign to Redis documents.""" + return time() + + def model_to_dict(model: BaseModel) -> Dict[str, Any]: """ Custom serialization function that converts a Pydantic model to a dict, @@ -24,3 +37,23 @@ def serialize_item(item): for key, value in serialized_data.items(): serialized_data[key] = serialize_item(value) return serialized_data + + +def validate_vector_dims(v1: int, v2: int) -> None: + """Check the equality of vector dimensions.""" + if v1 != v2: + raise ValueError( + "Invalid vector dimensions! " f"Vector has dims defined as {v1}", + f"Vector field has dims defined as {v2}", + "Vector dims must be equal in order to perform similarity search.", + ) + + +def serialize(data: Dict[str, Any]) -> str: + """Serlize the input into a string.""" + return json.dumps(data) + + +def deserialize(self, data: str) -> Dict[str, Any]: + """Deserialize the input from a string.""" + return json.loads(data) diff --git a/tests/unit/test_session_schema.py b/tests/unit/test_session_schema.py new file mode 100644 index 00000000..292a0d0a --- /dev/null +++ b/tests/unit/test_session_schema.py @@ -0,0 +1,134 @@ +from uuid import uuid4 + +import pytest +from pydantic.v1 import ValidationError + +from redisvl.extensions.session_manager.schema import ChatMessage +from redisvl.redis.utils import array_to_buffer +from redisvl.utils.utils import create_uuid, current_timestamp + + +def test_chat_message_creation(): + session_tag = create_uuid() + timestamp = current_timestamp() + content = "Hello, world!" + + chat_message = ChatMessage( + _id=f"{session_tag}:{timestamp}", + role="user", + content=content, + session_tag=session_tag, + timestamp=timestamp, + ) + + assert chat_message._id == f"{session_tag}:{timestamp}" + assert chat_message.role == "user" + assert chat_message.content == content + assert chat_message.session_tag == session_tag + assert chat_message.timestamp == timestamp + assert chat_message.tool_call_id is None + assert chat_message.vector_field is None + + +def test_chat_message_default_id_generation(): + session_tag = create_uuid() + timestamp = current_timestamp() + content = "Hello, world!" + + chat_message = ChatMessage( + role="user", + content=content, + session_tag=session_tag, + timestamp=timestamp, + ) + + assert chat_message._id == f"{session_tag}:{timestamp}" + + +def test_chat_message_with_tool_call_id(): + session_tag = create_uuid() + timestamp = current_timestamp() + content = "Hello, world!" + tool_call_id = create_uuid() + + chat_message = ChatMessage( + _id=f"{session_tag}:{timestamp}", + role="user", + content=content, + session_tag=session_tag, + timestamp=timestamp, + tool_call_id=tool_call_id, + ) + + assert chat_message.tool_call_id == tool_call_id + + +def test_chat_message_with_vector_field(): + session_tag = create_uuid() + timestamp = current_timestamp() + content = "Hello, world!" + vector_field = [0.1, 0.2, 0.3] + + chat_message = ChatMessage( + _id=f"{session_tag}:{timestamp}", + role="user", + content=content, + session_tag=session_tag, + timestamp=timestamp, + vector_field=vector_field, + ) + + assert chat_message.vector_field == vector_field + + +def test_chat_message_to_dict(): + session_tag = create_uuid() + timestamp = current_timestamp() + content = "Hello, world!" + vector_field = [0.1, 0.2, 0.3] + + chat_message = ChatMessage( + _id=f"{session_tag}:{timestamp}", + role="user", + content=content, + session_tag=session_tag, + timestamp=timestamp, + vector_field=vector_field, + ) + + data = chat_message.to_dict() + + assert data["_id"] == f"{session_tag}:{timestamp}" + assert data["role"] == "user" + assert data["content"] == content + assert data["session_tag"] == session_tag + assert data["timestamp"] == timestamp + assert data["vector_field"] == array_to_buffer(vector_field) + + +def test_chat_message_missing_fields(): + session_tag = create_uuid() + timestamp = current_timestamp() + content = "Hello, world!" + + with pytest.raises(ValidationError): + ChatMessage( + content=content, + session_tag=session_tag, + timestamp=timestamp, + ) + + +def test_chat_message_invalid_role(): + session_tag = create_uuid() + timestamp = current_timestamp() + content = "Hello, world!" + + with pytest.raises(ValidationError): + ChatMessage( + _id=f"{session_tag}:{timestamp}", + role=[1, 2, 3], # Invalid role type + content=content, + session_tag=session_tag, + timestamp=timestamp, + )