From ac96375ac35da2842d36a07a490d8a8b499a0919 Mon Sep 17 00:00:00 2001 From: Tyler Hutcherson Date: Wed, 31 Jul 2024 12:22:23 -0400 Subject: [PATCH 1/5] finish up chatmessage update and add tests --- .../session_manager/base_session.py | 56 ++++---- redisvl/extensions/session_manager/schema.py | 49 +++++++ .../session_manager/semantic_session.py | 62 ++++---- .../session_manager/standard_session.py | 48 ++++--- redisvl/utils/utils.py | 10 ++ tests/unit/test_session_schema.py | 134 ++++++++++++++++++ 6 files changed, 274 insertions(+), 85 deletions(-) create mode 100644 redisvl/extensions/session_manager/schema.py create mode 100644 tests/unit/test_session_schema.py diff --git a/redisvl/extensions/session_manager/base_session.py b/redisvl/extensions/session_manager/base_session.py index 97c4a9f1..ea6264e3 100644 --- a/redisvl/extensions/session_manager/base_session.py +++ b/redisvl/extensions/session_manager/base_session.py @@ -1,9 +1,7 @@ from typing import Any, Dict, List, Optional, Union -from uuid import uuid4 -from redis import Redis - -from redisvl.query.filter import FilterExpression +from redisvl.extensions.session_manager.schema import ChatMessage +from redisvl.utils.utils import create_uuid class BaseSessionManager: @@ -32,7 +30,7 @@ def __init__( session. Defaults to instance uuid. """ self._name = name - self._session_tag = session_tag or uuid4().hex + self._session_tag = session_tag or create_uuid() def clear(self) -> None: """Clears the chat session history.""" @@ -85,14 +83,13 @@ def get_recent( raise NotImplementedError def _format_context( - self, hits: List[Dict[str, Any]], as_text: bool + self, messages: List[Dict[str, Any]], as_text: bool ) -> Union[List[str], List[Dict[str, str]]]: """Extracts the prompt and response fields from the Redis hashes and formats them as either flat dictionaries or strings. Args: - hits (List): The hashes containing prompt & response pairs from - recent conversation history. + messages (List[Dict[str, Any]]): The messages from the session index. as_text (bool): Whether to return the conversation as a single string, or list of alternating prompts and responses. @@ -100,29 +97,26 @@ def _format_context( Union[str, List[str]]: A single string transcription of the session or list of strings if as_text is false. """ - if as_text: - text_statements = [] - for hit in hits: - text_statements.append(hit[self.content_field_name]) - return text_statements - else: - statements = [] - for hit in hits: - statements.append( - { - self.role_field_name: hit[self.role_field_name], - self.content_field_name: hit[self.content_field_name], - } - ) - if ( - hasattr(hit, self.tool_field_name) - or isinstance(hit, dict) - and self.tool_field_name in hit - ): - statements[-1].update( - {self.tool_field_name: hit[self.tool_field_name]} - ) - return statements + context = [] + + for message in messages: + + chat_message = ChatMessage(**message) + + if as_text: + context.append(chat_message.content) + else: + chat_message = ChatMessage(**message) + chat_message_dict = { + self.role_field_name: chat_message.role, + self.content_field_name: chat_message.content, + } + if chat_message.tool_call_id is not None: + chat_message_dict[self.tool_field_name] = chat_message.tool_call_id + + context.append(chat_message_dict) # type: ignore + + return context def store( self, prompt: str, response: str, session_tag: Optional[str] = None diff --git a/redisvl/extensions/session_manager/schema.py b/redisvl/extensions/session_manager/schema.py new file mode 100644 index 00000000..1b8a343d --- /dev/null +++ b/redisvl/extensions/session_manager/schema.py @@ -0,0 +1,49 @@ +from typing import Dict, List, Optional + +from pydantic.v1 import BaseModel, Field, root_validator + +from redisvl.redis.utils import array_to_buffer +from redisvl.utils.utils import current_timestamp + + +class ChatMessage(BaseModel): + """A single chat message exchanged between a user and an LLM.""" + + _id: Optional[str] = Field(default=None) + """A unique identifier for the message.""" + role: str # TODO -- do we enumify this? + """The role of the message sender (e.g., 'user' or 'llm').""" + content: str + """The content of the message.""" + session_tag: str + """Tag associated with the current session.""" + timestamp: float = Field(default_factory=current_timestamp) + """The time the message was sent, in UTC, rounded to milliseconds.""" + tool_call_id: Optional[str] = Field(default=None) + """An optional identifier for a tool call associated with the message.""" + vector_field: Optional[List[float]] = Field(default=None) + """The vector representation of the message content.""" + + class Config: + arbitrary_types_allowed = True + + @root_validator(pre=False) + @classmethod + def generate_id(cls, values): + if "_id" not in values: + values["_id"] = f'{values["session_tag"]}:{values["timestamp"]}' + return values + + def to_dict(self) -> Dict: + data = self.dict() + + # handle optional fields + if data["vector_field"] is not None: + data["vector_field"] = array_to_buffer(data["vector_field"]) + else: + del data["vector_field"] + + if self.tool_call_id is None: + del data["tool_call_id"] + + return data diff --git a/redisvl/extensions/session_manager/semantic_session.py b/redisvl/extensions/session_manager/semantic_session.py index ccfcf2e7..b2ed9933 100644 --- a/redisvl/extensions/session_manager/semantic_session.py +++ b/redisvl/extensions/session_manager/semantic_session.py @@ -1,13 +1,12 @@ -from time import time from typing import Any, Dict, List, Optional, Union from redis import Redis from redisvl.extensions.session_manager import BaseSessionManager +from redisvl.extensions.session_manager.schema import ChatMessage from redisvl.index import SearchIndex from redisvl.query import FilterQuery, RangeQuery -from redisvl.query.filter import FilterExpression, Tag -from redisvl.redis.utils import array_to_buffer +from redisvl.query.filter import Tag from redisvl.schema.schema import IndexSchema from redisvl.utils.vectorize import BaseVectorizer, HFTextVectorizer @@ -20,9 +19,9 @@ def from_params(cls, name: str, prefix: str, vectorizer_dims: int): return cls( index={"name": name, "prefix": prefix}, # type: ignore fields=[ # type: ignore - {"name": "role", "type": "text"}, + {"name": "role", "type": "tag"}, {"name": "content", "type": "text"}, - {"name": "tool_call_id", "type": "text"}, + {"name": "tool_call_id", "type": "tag"}, {"name": "timestamp", "type": "numeric"}, {"name": "session_tag", "type": "tag"}, { @@ -148,9 +147,11 @@ def messages(self) -> Union[List[str], List[Dict[str, str]]]: sorted_query = query.query sorted_query.sort_by(self.timestamp_field_name, asc=True) - hits = self._index.search(sorted_query, query.params).docs + messages = [ + doc.__dict__ for doc in self._index.search(sorted_query, query.params).docs + ] - return self._format_context(hits, as_text=False) + return self._format_context(messages, as_text=False) def get_relevant( self, @@ -198,7 +199,6 @@ def get_relevant( self.content_field_name, self.timestamp_field_name, self.tool_field_name, - self.vector_field_name, ] session_filter = ( @@ -216,14 +216,14 @@ def get_relevant( return_score=True, filter_expression=session_filter, ) - hits = self._index.query(query) + messages = self._index.query(query) # if we don't find semantic matches fallback to returning recent context - if not hits and fall_back: + if not messages and fall_back: return self.get_recent(as_text=as_text, top_k=top_k, raw=raw) if raw: - return hits - return self._format_context(hits, as_text) + return messages + return self._format_context(messages, as_text) def get_recent( self, @@ -276,11 +276,13 @@ def get_recent( sorted_query = query.query sorted_query.sort_by(self.timestamp_field_name, asc=False) - hits = self._index.search(sorted_query, query.params).docs + messages = [ + doc.__dict__ for doc in self._index.search(sorted_query, query.params).docs + ] if raw: - return hits[::-1] - return self._format_context(hits[::-1], as_text) + return messages[::-1] + return self._format_context(messages[::-1], as_text) @property def distance_threshold(self): @@ -322,26 +324,24 @@ def add_messages( session_tag (Optional[str]): Tag to be added to entries to link to a specific session. Defaults to instance uuid. """ - sep = self._index.key_separator session_tag = session_tag or self._session_tag - payloads = [] + chat_messages: List[Dict[str, Any]] = [] + for message in messages: - vector = self._vectorizer.embed(message[self.content_field_name]) - timestamp = time() - id_field = sep.join([self._session_tag, str(timestamp)]) - payload = { - self.id_field_name: id_field, - self.role_field_name: message[self.role_field_name], - self.content_field_name: message[self.content_field_name], - self.timestamp_field_name: timestamp, - self.vector_field_name: array_to_buffer(vector), - self.session_field_name: session_tag, - } + + chat_message = ChatMessage( + role=message[self.role_field_name], + content=message[self.content_field_name], + session_tag=session_tag, + vector_field=self._vectorizer.embed(message[self.content_field_name]), + ) if self.tool_field_name in message: - payload.update({self.tool_field_name: message[self.tool_field_name]}) - payloads.append(payload) - self._index.load(data=payloads, id_field=self.id_field_name) + chat_message.tool_call_id = message[self.tool_field_name] + + chat_messages.append(chat_message.to_dict()) + + self._index.load(data=chat_messages, id_field=self.id_field_name) def add_message( self, message: Dict[str, str], session_tag: Optional[str] = None diff --git a/redisvl/extensions/session_manager/standard_session.py b/redisvl/extensions/session_manager/standard_session.py index 0a1a7b25..54725e7b 100644 --- a/redisvl/extensions/session_manager/standard_session.py +++ b/redisvl/extensions/session_manager/standard_session.py @@ -1,9 +1,9 @@ -from time import time from typing import Any, Dict, List, Optional, Union from redis import Redis from redisvl.extensions.session_manager import BaseSessionManager +from redisvl.extensions.session_manager.schema import ChatMessage from redisvl.index import SearchIndex from redisvl.query import FilterQuery from redisvl.query.filter import Tag @@ -18,9 +18,9 @@ def from_params(cls, name: str, prefix: str): return cls( index={"name": name, "prefix": prefix}, # type: ignore fields=[ # type: ignore - {"name": "role", "type": "text"}, + {"name": "role", "type": "tag"}, {"name": "content", "type": "text"}, - {"name": "tool_call_id", "type": "text"}, + {"name": "tool_call_id", "type": "tag"}, {"name": "timestamp", "type": "numeric"}, {"name": "session_tag", "type": "tag"}, ], @@ -74,7 +74,7 @@ def __init__( if redis_client: self._index.set_client(redis_client) else: - self._index.connect(redis_url=redis_url) + self._index.connect(redis_url=redis_url, **connection_kwargs) self._index.create(overwrite=False) @@ -121,9 +121,11 @@ def messages(self) -> Union[List[str], List[Dict[str, str]]]: sorted_query = query.query sorted_query.sort_by(self.timestamp_field_name, asc=True) - hits = self._index.search(sorted_query, query.params).docs + messages = [ + doc.__dict__ for doc in self._index.search(sorted_query, query.params).docs + ] - return self._format_context(hits, as_text=False) + return self._format_context(messages, as_text=False) def get_recent( self, @@ -176,11 +178,13 @@ def get_recent( sorted_query = query.query sorted_query.sort_by(self.timestamp_field_name, asc=False) - hits = self._index.search(sorted_query, query.params).docs + messages = [ + doc.__dict__ for doc in self._index.search(sorted_query, query.params).docs + ] if raw: - return hits[::-1] - return self._format_context(hits[::-1], as_text) + return messages[::-1] + return self._format_context(messages[::-1], as_text) def store( self, prompt: str, response: str, session_tag: Optional[str] = None @@ -215,25 +219,23 @@ def add_messages( session_tag (Optional[str]): Tag to be added to entries to link to a specific session. Defaults to instance uuid. """ - sep = self._index.key_separator session_tag = session_tag or self._session_tag - payloads = [] + chat_messages: List[Dict[str, Any]] = [] + for message in messages: - timestamp = time() - id_field = sep.join([self._session_tag, str(timestamp)]) - payload = { - self.id_field_name: id_field, - self.role_field_name: message[self.role_field_name], - self.content_field_name: message[self.content_field_name], - self.timestamp_field_name: timestamp, - self.session_field_name: session_tag, - } + + chat_message = ChatMessage( + role=message[self.role_field_name], + content=message[self.content_field_name], + session_tag=session_tag, + ) if self.tool_field_name in message: - payload.update({self.tool_field_name: message[self.tool_field_name]}) + chat_message.tool_call_id = message[self.tool_field_name] + + chat_messages.append(chat_message.to_dict()) - payloads.append(payload) - self._index.load(data=payloads, id_field=self.id_field_name) + self._index.load(data=chat_messages, id_field=self.id_field_name) def add_message( self, message: Dict[str, str], session_tag: Optional[str] = None diff --git a/redisvl/utils/utils.py b/redisvl/utils/utils.py index 96c5250e..f5877030 100644 --- a/redisvl/utils/utils.py +++ b/redisvl/utils/utils.py @@ -1,9 +1,19 @@ from enum import Enum +from time import time from typing import Any, Dict +from uuid import uuid4 from pydantic.v1 import BaseModel +def create_uuid() -> str: + return str(uuid4()) + + +def current_timestamp(): + return time() + + def model_to_dict(model: BaseModel) -> Dict[str, Any]: """ Custom serialization function that converts a Pydantic model to a dict, diff --git a/tests/unit/test_session_schema.py b/tests/unit/test_session_schema.py new file mode 100644 index 00000000..292a0d0a --- /dev/null +++ b/tests/unit/test_session_schema.py @@ -0,0 +1,134 @@ +from uuid import uuid4 + +import pytest +from pydantic.v1 import ValidationError + +from redisvl.extensions.session_manager.schema import ChatMessage +from redisvl.redis.utils import array_to_buffer +from redisvl.utils.utils import create_uuid, current_timestamp + + +def test_chat_message_creation(): + session_tag = create_uuid() + timestamp = current_timestamp() + content = "Hello, world!" + + chat_message = ChatMessage( + _id=f"{session_tag}:{timestamp}", + role="user", + content=content, + session_tag=session_tag, + timestamp=timestamp, + ) + + assert chat_message._id == f"{session_tag}:{timestamp}" + assert chat_message.role == "user" + assert chat_message.content == content + assert chat_message.session_tag == session_tag + assert chat_message.timestamp == timestamp + assert chat_message.tool_call_id is None + assert chat_message.vector_field is None + + +def test_chat_message_default_id_generation(): + session_tag = create_uuid() + timestamp = current_timestamp() + content = "Hello, world!" + + chat_message = ChatMessage( + role="user", + content=content, + session_tag=session_tag, + timestamp=timestamp, + ) + + assert chat_message._id == f"{session_tag}:{timestamp}" + + +def test_chat_message_with_tool_call_id(): + session_tag = create_uuid() + timestamp = current_timestamp() + content = "Hello, world!" + tool_call_id = create_uuid() + + chat_message = ChatMessage( + _id=f"{session_tag}:{timestamp}", + role="user", + content=content, + session_tag=session_tag, + timestamp=timestamp, + tool_call_id=tool_call_id, + ) + + assert chat_message.tool_call_id == tool_call_id + + +def test_chat_message_with_vector_field(): + session_tag = create_uuid() + timestamp = current_timestamp() + content = "Hello, world!" + vector_field = [0.1, 0.2, 0.3] + + chat_message = ChatMessage( + _id=f"{session_tag}:{timestamp}", + role="user", + content=content, + session_tag=session_tag, + timestamp=timestamp, + vector_field=vector_field, + ) + + assert chat_message.vector_field == vector_field + + +def test_chat_message_to_dict(): + session_tag = create_uuid() + timestamp = current_timestamp() + content = "Hello, world!" + vector_field = [0.1, 0.2, 0.3] + + chat_message = ChatMessage( + _id=f"{session_tag}:{timestamp}", + role="user", + content=content, + session_tag=session_tag, + timestamp=timestamp, + vector_field=vector_field, + ) + + data = chat_message.to_dict() + + assert data["_id"] == f"{session_tag}:{timestamp}" + assert data["role"] == "user" + assert data["content"] == content + assert data["session_tag"] == session_tag + assert data["timestamp"] == timestamp + assert data["vector_field"] == array_to_buffer(vector_field) + + +def test_chat_message_missing_fields(): + session_tag = create_uuid() + timestamp = current_timestamp() + content = "Hello, world!" + + with pytest.raises(ValidationError): + ChatMessage( + content=content, + session_tag=session_tag, + timestamp=timestamp, + ) + + +def test_chat_message_invalid_role(): + session_tag = create_uuid() + timestamp = current_timestamp() + content = "Hello, world!" + + with pytest.raises(ValidationError): + ChatMessage( + _id=f"{session_tag}:{timestamp}", + role=[1, 2, 3], # Invalid role type + content=content, + session_tag=session_tag, + timestamp=timestamp, + ) From 8572a496e8dd228d464be8ff7c26678bbfd2c304 Mon Sep 17 00:00:00 2001 From: Tyler Hutcherson Date: Wed, 31 Jul 2024 13:01:34 -0400 Subject: [PATCH 2/5] move schema --- redisvl/extensions/router/schema.py | 35 +++++++++++++++ redisvl/extensions/router/semantic.py | 35 +-------------- .../session_manager/base_session.py | 2 +- redisvl/extensions/session_manager/schema.py | 45 +++++++++++++++++++ .../session_manager/semantic_session.py | 33 ++------------ .../session_manager/standard_session.py | 23 ++-------- 6 files changed, 90 insertions(+), 83 deletions(-) diff --git a/redisvl/extensions/router/schema.py b/redisvl/extensions/router/schema.py index c01f9254..11b88dc6 100644 --- a/redisvl/extensions/router/schema.py +++ b/redisvl/extensions/router/schema.py @@ -3,6 +3,8 @@ from pydantic.v1 import BaseModel, Field, validator +from redisvl.schema import IndexInfo, IndexSchema + class Route(BaseModel): """Model representing a routing path with associated metadata and thresholds.""" @@ -80,3 +82,36 @@ def distance_threshold_must_be_valid(cls, v): if v <= 0 or v > 1: raise ValueError("distance_threshold must be between 0 and 1") return v + + +class SemanticRouterIndexSchema(IndexSchema): + """Customized index schema for SemanticRouter.""" + + @classmethod + def from_params(cls, name: str, vector_dims: int) -> "SemanticRouterIndexSchema": + """Create an index schema based on router name and vector dimensions. + + Args: + name (str): The name of the index. + vector_dims (int): The dimensions of the vectors. + + Returns: + SemanticRouterIndexSchema: The constructed index schema. + """ + return cls( + index=IndexInfo(name=name, prefix=name), + fields=[ # type: ignore + {"name": "route_name", "type": "tag"}, + {"name": "reference", "type": "text"}, + { + "name": "vector", + "type": "vector", + "attrs": { + "algorithm": "flat", + "dims": vector_dims, + "distance_metric": "cosine", + "datatype": "float32", + }, + }, + ], + ) diff --git a/redisvl/extensions/router/semantic.py b/redisvl/extensions/router/semantic.py index 1c34202a..bab69578 100644 --- a/redisvl/extensions/router/semantic.py +++ b/redisvl/extensions/router/semantic.py @@ -13,11 +13,11 @@ Route, RouteMatch, RoutingConfig, + SemanticRouterIndexSchema, ) from redisvl.index import SearchIndex from redisvl.query import RangeQuery from redisvl.redis.utils import convert_bytes, hashify, make_dict -from redisvl.schema import IndexInfo, IndexSchema from redisvl.utils.log import get_logger from redisvl.utils.utils import model_to_dict from redisvl.utils.vectorize import ( @@ -29,39 +29,6 @@ logger = get_logger(__name__) -class SemanticRouterIndexSchema(IndexSchema): - """Customized index schema for SemanticRouter.""" - - @classmethod - def from_params(cls, name: str, vector_dims: int) -> "SemanticRouterIndexSchema": - """Create an index schema based on router name and vector dimensions. - - Args: - name (str): The name of the index. - vector_dims (int): The dimensions of the vectors. - - Returns: - SemanticRouterIndexSchema: The constructed index schema. - """ - return cls( - index=IndexInfo(name=name, prefix=name), - fields=[ # type: ignore - {"name": "route_name", "type": "tag"}, - {"name": "reference", "type": "text"}, - { - "name": "vector", - "type": "vector", - "attrs": { - "algorithm": "flat", - "dims": vector_dims, - "distance_metric": "cosine", - "datatype": "float32", - }, - }, - ], - ) - - class SemanticRouter(BaseModel): """Semantic Router for managing and querying route vectors.""" diff --git a/redisvl/extensions/session_manager/base_session.py b/redisvl/extensions/session_manager/base_session.py index ea6264e3..fb172c44 100644 --- a/redisvl/extensions/session_manager/base_session.py +++ b/redisvl/extensions/session_manager/base_session.py @@ -114,7 +114,7 @@ def _format_context( if chat_message.tool_call_id is not None: chat_message_dict[self.tool_field_name] = chat_message.tool_call_id - context.append(chat_message_dict) # type: ignore + context.append(chat_message_dict) # type: ignore return context diff --git a/redisvl/extensions/session_manager/schema.py b/redisvl/extensions/session_manager/schema.py index 1b8a343d..4394d4cb 100644 --- a/redisvl/extensions/session_manager/schema.py +++ b/redisvl/extensions/session_manager/schema.py @@ -3,6 +3,7 @@ from pydantic.v1 import BaseModel, Field, root_validator from redisvl.redis.utils import array_to_buffer +from redisvl.schema import IndexSchema from redisvl.utils.utils import current_timestamp @@ -47,3 +48,47 @@ def to_dict(self) -> Dict: del data["tool_call_id"] return data + + +class StandardSessionIndexSchema(IndexSchema): + + @classmethod + def from_params(cls, name: str, prefix: str): + + return cls( + index={"name": name, "prefix": prefix}, # type: ignore + fields=[ # type: ignore + {"name": "role", "type": "tag"}, + {"name": "content", "type": "text"}, + {"name": "tool_call_id", "type": "tag"}, + {"name": "timestamp", "type": "numeric"}, + {"name": "session_tag", "type": "tag"}, + ], + ) + + +class SemanticSessionIndexSchema(IndexSchema): + + @classmethod + def from_params(cls, name: str, prefix: str, vectorizer_dims: int): + + return cls( + index={"name": name, "prefix": prefix}, # type: ignore + fields=[ # type: ignore + {"name": "role", "type": "tag"}, + {"name": "content", "type": "text"}, + {"name": "tool_call_id", "type": "tag"}, + {"name": "timestamp", "type": "numeric"}, + {"name": "session_tag", "type": "tag"}, + { + "name": "vector_field", + "type": "vector", + "attrs": { + "dims": vectorizer_dims, + "datatype": "float32", + "distance_metric": "cosine", + "algorithm": "flat", + }, + }, + ], + ) diff --git a/redisvl/extensions/session_manager/semantic_session.py b/redisvl/extensions/session_manager/semantic_session.py index b2ed9933..2b530a30 100644 --- a/redisvl/extensions/session_manager/semantic_session.py +++ b/redisvl/extensions/session_manager/semantic_session.py @@ -3,41 +3,16 @@ from redis import Redis from redisvl.extensions.session_manager import BaseSessionManager -from redisvl.extensions.session_manager.schema import ChatMessage +from redisvl.extensions.session_manager.schema import ( + ChatMessage, + SemanticSessionIndexSchema, +) from redisvl.index import SearchIndex from redisvl.query import FilterQuery, RangeQuery from redisvl.query.filter import Tag -from redisvl.schema.schema import IndexSchema from redisvl.utils.vectorize import BaseVectorizer, HFTextVectorizer -class SemanticSessionIndexSchema(IndexSchema): - - @classmethod - def from_params(cls, name: str, prefix: str, vectorizer_dims: int): - - return cls( - index={"name": name, "prefix": prefix}, # type: ignore - fields=[ # type: ignore - {"name": "role", "type": "tag"}, - {"name": "content", "type": "text"}, - {"name": "tool_call_id", "type": "tag"}, - {"name": "timestamp", "type": "numeric"}, - {"name": "session_tag", "type": "tag"}, - { - "name": "vector_field", - "type": "vector", - "attrs": { - "dims": vectorizer_dims, - "datatype": "float32", - "distance_metric": "cosine", - "algorithm": "flat", - }, - }, - ], - ) - - class SemanticSessionManager(BaseSessionManager): vector_field_name: str = "vector_field" diff --git a/redisvl/extensions/session_manager/standard_session.py b/redisvl/extensions/session_manager/standard_session.py index 54725e7b..37d18a3e 100644 --- a/redisvl/extensions/session_manager/standard_session.py +++ b/redisvl/extensions/session_manager/standard_session.py @@ -3,28 +3,13 @@ from redis import Redis from redisvl.extensions.session_manager import BaseSessionManager -from redisvl.extensions.session_manager.schema import ChatMessage +from redisvl.extensions.session_manager.schema import ( + ChatMessage, + StandardSessionIndexSchema, +) from redisvl.index import SearchIndex from redisvl.query import FilterQuery from redisvl.query.filter import Tag -from redisvl.schema.schema import IndexSchema - - -class StandardSessionIndexSchema(IndexSchema): - - @classmethod - def from_params(cls, name: str, prefix: str): - - return cls( - index={"name": name, "prefix": prefix}, # type: ignore - fields=[ # type: ignore - {"name": "role", "type": "tag"}, - {"name": "content", "type": "text"}, - {"name": "tool_call_id", "type": "tag"}, - {"name": "timestamp", "type": "numeric"}, - {"name": "session_tag", "type": "tag"}, - ], - ) class StandardSessionManager(BaseSessionManager): From dc447b9d03917411f7ae8df1566884fbd0b87b15 Mon Sep 17 00:00:00 2001 From: Tyler Hutcherson Date: Wed, 31 Jul 2024 13:53:06 -0400 Subject: [PATCH 3/5] clean up chat message modeling --- .../session_manager/base_session.py | 1 - .../session_manager/semantic_session.py | 10 ++++++++- redisvl/utils/utils.py | 22 +++++++++++++++++++ 3 files changed, 31 insertions(+), 2 deletions(-) diff --git a/redisvl/extensions/session_manager/base_session.py b/redisvl/extensions/session_manager/base_session.py index fb172c44..c2cddb2b 100644 --- a/redisvl/extensions/session_manager/base_session.py +++ b/redisvl/extensions/session_manager/base_session.py @@ -106,7 +106,6 @@ def _format_context( if as_text: context.append(chat_message.content) else: - chat_message = ChatMessage(**message) chat_message_dict = { self.role_field_name: chat_message.role, self.content_field_name: chat_message.content, diff --git a/redisvl/extensions/session_manager/semantic_session.py b/redisvl/extensions/session_manager/semantic_session.py index 2b530a30..1cc5d129 100644 --- a/redisvl/extensions/session_manager/semantic_session.py +++ b/redisvl/extensions/session_manager/semantic_session.py @@ -10,6 +10,7 @@ from redisvl.index import SearchIndex from redisvl.query import FilterQuery, RangeQuery from redisvl.query.filter import Tag +from redisvl.utils.utils import validate_vector_dims from redisvl.utils.vectorize import BaseVectorizer, HFTextVectorizer @@ -304,11 +305,18 @@ def add_messages( for message in messages: + content_vector = self._vectorizer.embed(message[self.content_field_name]) + + validate_vector_dims( + len(content_vector), + self._index.schema.fields[self.vector_field_name].attrs.dims, # type: ignore + ) + chat_message = ChatMessage( role=message[self.role_field_name], content=message[self.content_field_name], session_tag=session_tag, - vector_field=self._vectorizer.embed(message[self.content_field_name]), + vector_field=content_vector, ) if self.tool_field_name in message: diff --git a/redisvl/utils/utils.py b/redisvl/utils/utils.py index f5877030..aa6da58d 100644 --- a/redisvl/utils/utils.py +++ b/redisvl/utils/utils.py @@ -2,6 +2,7 @@ from time import time from typing import Any, Dict from uuid import uuid4 +import json from pydantic.v1 import BaseModel @@ -34,3 +35,24 @@ def serialize_item(item): for key, value in serialized_data.items(): serialized_data[key] = serialize_item(value) return serialized_data + + +def validate_vector_dims(v1: int, v2: int) -> None: + """Check the equality of vector dimensions.""" + if v1 != v2: + raise ValueError( + "Invalid vector dimensions! " + f"Vector has dims defined as {v1}", + f"Vector field has dims defined as {v2}", + "Vector dims must be equal in order to perform similarity search." + ) + + +def serialize(data: Dict[str, Any]) -> str: + """Serlize the input into a string.""" + return json.dumps(data) + + +def deserialize(self, data: str) -> Dict[str, Any]: + """Deserialize the input from a string.""" + return json.loads(data) \ No newline at end of file From 093dd2a95d1866d4f69bbb602726845ffb1eca05 Mon Sep 17 00:00:00 2001 From: Tyler Hutcherson Date: Wed, 31 Jul 2024 13:58:02 -0400 Subject: [PATCH 4/5] final cleanup --- redisvl/extensions/session_manager/semantic_session.py | 2 +- redisvl/utils/utils.py | 9 ++++----- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/redisvl/extensions/session_manager/semantic_session.py b/redisvl/extensions/session_manager/semantic_session.py index 1cc5d129..a0903b99 100644 --- a/redisvl/extensions/session_manager/semantic_session.py +++ b/redisvl/extensions/session_manager/semantic_session.py @@ -309,7 +309,7 @@ def add_messages( validate_vector_dims( len(content_vector), - self._index.schema.fields[self.vector_field_name].attrs.dims, # type: ignore + self._index.schema.fields[self.vector_field_name].attrs.dims, # type: ignore ) chat_message = ChatMessage( diff --git a/redisvl/utils/utils.py b/redisvl/utils/utils.py index aa6da58d..6c039733 100644 --- a/redisvl/utils/utils.py +++ b/redisvl/utils/utils.py @@ -1,8 +1,8 @@ +import json from enum import Enum from time import time from typing import Any, Dict from uuid import uuid4 -import json from pydantic.v1 import BaseModel @@ -41,10 +41,9 @@ def validate_vector_dims(v1: int, v2: int) -> None: """Check the equality of vector dimensions.""" if v1 != v2: raise ValueError( - "Invalid vector dimensions! " - f"Vector has dims defined as {v1}", + "Invalid vector dimensions! " f"Vector has dims defined as {v1}", f"Vector field has dims defined as {v2}", - "Vector dims must be equal in order to perform similarity search." + "Vector dims must be equal in order to perform similarity search.", ) @@ -55,4 +54,4 @@ def serialize(data: Dict[str, Any]) -> str: def deserialize(self, data: str) -> Dict[str, Any]: """Deserialize the input from a string.""" - return json.loads(data) \ No newline at end of file + return json.loads(data) From 0e5bc6b2249476e1f94d1c5b1546ff90a1442aad Mon Sep 17 00:00:00 2001 From: Justin Cechmanek Date: Wed, 31 Jul 2024 13:12:17 -0700 Subject: [PATCH 5/5] adds doc strings to util functions --- redisvl/utils/utils.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/redisvl/utils/utils.py b/redisvl/utils/utils.py index 6c039733..5f5cc882 100644 --- a/redisvl/utils/utils.py +++ b/redisvl/utils/utils.py @@ -8,10 +8,12 @@ def create_uuid() -> str: + """Generate a unique indentifier to group related Redis documents.""" return str(uuid4()) -def current_timestamp(): +def current_timestamp() -> float: + """Generate a unix epoch timestamp to assign to Redis documents.""" return time()