Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implemented appending arbitrary messages #5293

Merged
merged 6 commits into from
May 29, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 1 addition & 9 deletions langchain/memory/chat_message_histories/cassandra.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,8 @@
from typing import List

from langchain.schema import (
AIMessage,
BaseChatMessageHistory,
BaseMessage,
HumanMessage,
_message_to_dict,
messages_from_dict,
)
Expand Down Expand Up @@ -143,13 +141,7 @@ def messages(self) -> List[BaseMessage]: # type: ignore

return messages

def add_user_message(self, message: str) -> None:
self.append(HumanMessage(content=message))

def add_ai_message(self, message: str) -> None:
self.append(AIMessage(content=message))

def append(self, message: BaseMessage) -> None:
def add_message(self, message: BaseMessage) -> None:
"""Append the message to the record in Cassandra"""

import uuid
Expand Down
17 changes: 5 additions & 12 deletions langchain/memory/chat_message_histories/cosmos_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,8 @@
from typing import TYPE_CHECKING, Any, List, Optional, Type

from langchain.schema import (
AIMessage,
BaseChatMessageHistory,
BaseMessage,
HumanMessage,
messages_from_dict,
messages_to_dict,
)
Expand Down Expand Up @@ -140,18 +138,13 @@ def load_messages(self) -> None:
if "messages" in item and len(item["messages"]) > 0:
self.messages = messages_from_dict(item["messages"])

def add_user_message(self, message: str) -> None:
"""Add a user message to the memory."""
self.upsert_messages(HumanMessage(content=message))

def add_ai_message(self, message: str) -> None:
"""Add a AI message to the memory."""
self.upsert_messages(AIMessage(content=message))
def add_message(self, message: BaseMessage) -> None:
"""Add a self-created message to the store"""
self.messages.append(message)
self.upsert_messages()

def upsert_messages(self, new_message: Optional[BaseMessage] = None) -> None:
def upsert_messages(self) -> None:
"""Update the cosmosdb item."""
if new_message:
self.messages.append(new_message)
if not self._container:
raise ValueError("Container not initialized")
self._container.upsert_item(
Expand Down
10 changes: 1 addition & 9 deletions langchain/memory/chat_message_histories/dynamodb.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,8 @@
from typing import List

from langchain.schema import (
AIMessage,
BaseChatMessageHistory,
BaseMessage,
HumanMessage,
_message_to_dict,
messages_from_dict,
messages_to_dict,
Expand Down Expand Up @@ -53,13 +51,7 @@ def messages(self) -> List[BaseMessage]: # type: ignore
messages = messages_from_dict(items)
return messages

def add_user_message(self, message: str) -> None:
self.append(HumanMessage(content=message))

def add_ai_message(self, message: str) -> None:
self.append(AIMessage(content=message))

def append(self, message: BaseMessage) -> None:
def add_message(self, message: BaseMessage) -> None:
"""Append the message to the record in DynamoDB"""
from botocore.exceptions import ClientError

Expand Down
10 changes: 1 addition & 9 deletions langchain/memory/chat_message_histories/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,8 @@
from typing import List

from langchain.schema import (
AIMessage,
BaseChatMessageHistory,
BaseMessage,
HumanMessage,
messages_from_dict,
messages_to_dict,
)
Expand Down Expand Up @@ -36,13 +34,7 @@ def messages(self) -> List[BaseMessage]: # type: ignore
messages = messages_from_dict(items)
return messages

def add_user_message(self, message: str) -> None:
self.append(HumanMessage(content=message))

def add_ai_message(self, message: str) -> None:
self.append(AIMessage(content=message))

def append(self, message: BaseMessage) -> None:
def add_message(self, message: BaseMessage) -> None:
"""Append the message to the record in the local file"""
messages = messages_to_dict(self.messages)
messages.append(messages_to_dict([message])[0])
Expand Down
14 changes: 3 additions & 11 deletions langchain/memory/chat_message_histories/firestore.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,8 @@
from typing import TYPE_CHECKING, List, Optional

from langchain.schema import (
AIMessage,
BaseChatMessageHistory,
BaseMessage,
HumanMessage,
messages_from_dict,
messages_to_dict,
)
Expand Down Expand Up @@ -81,18 +79,12 @@ def load_messages(self) -> None:
if "messages" in data and len(data["messages"]) > 0:
self.messages = messages_from_dict(data["messages"])

def add_user_message(self, message: str) -> None:
"""Add a user message to the memory."""
self.upsert_messages(HumanMessage(content=message))

def add_ai_message(self, message: str) -> None:
"""Add a AI message to the memory."""
self.upsert_messages(AIMessage(content=message))
def add_message(self, message: BaseMessage) -> None:
self.messages.append(message)
self.upsert_messages()

def upsert_messages(self, new_message: Optional[BaseMessage] = None) -> None:
"""Update the Firestore document."""
if new_message:
self.messages.append(new_message)
if not self._document:
raise ValueError("Document not initialized")
self._document.set(
Expand Down
10 changes: 3 additions & 7 deletions langchain/memory/chat_message_histories/in_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,21 +3,17 @@
from pydantic import BaseModel

from langchain.schema import (
AIMessage,
BaseChatMessageHistory,
BaseMessage,
HumanMessage,
)


class ChatMessageHistory(BaseChatMessageHistory, BaseModel):
messages: List[BaseMessage] = []

def add_user_message(self, message: str) -> None:
self.messages.append(HumanMessage(content=message))

def add_ai_message(self, message: str) -> None:
self.messages.append(AIMessage(content=message))
def add_message(self, message: BaseMessage) -> None:
"""Add a self-created message to the store"""
self.messages.append(message)

def clear(self) -> None:
self.messages = []
20 changes: 1 addition & 19 deletions langchain/memory/chat_message_histories/momento.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,8 @@
from typing import TYPE_CHECKING, Any, Optional

from langchain.schema import (
AIMessage,
BaseChatMessageHistory,
BaseMessage,
HumanMessage,
_message_to_dict,
messages_from_dict,
)
Expand Down Expand Up @@ -143,23 +141,7 @@ def messages(self) -> list[BaseMessage]: # type: ignore[override]
else:
raise Exception(f"Unexpected response: {fetch_response}")

def add_user_message(self, message: str) -> None:
"""Store a user message in the cache.

Args:
message (str): The message to store.
"""
self.__add_message(HumanMessage(content=message))

def add_ai_message(self, message: str) -> None:
"""Store an AI message in the cache.

Args:
message (str): The message to store.
"""
self.__add_message(AIMessage(content=message))

def __add_message(self, message: BaseMessage) -> None:
def add_message(self, message: BaseMessage) -> None:
"""Store a message in the cache.

Args:
Expand Down
10 changes: 1 addition & 9 deletions langchain/memory/chat_message_histories/mongodb.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,8 @@
from typing import List

from langchain.schema import (
AIMessage,
BaseChatMessageHistory,
BaseMessage,
HumanMessage,
_message_to_dict,
messages_from_dict,
)
Expand Down Expand Up @@ -68,13 +66,7 @@ def messages(self) -> List[BaseMessage]: # type: ignore
messages = messages_from_dict(items)
return messages

def add_user_message(self, message: str) -> None:
self.append(HumanMessage(content=message))

def add_ai_message(self, message: str) -> None:
self.append(AIMessage(content=message))

def append(self, message: BaseMessage) -> None:
def add_message(self, message: BaseMessage) -> None:
"""Append the message to the record in MongoDB"""
from pymongo import errors

Expand Down
10 changes: 1 addition & 9 deletions langchain/memory/chat_message_histories/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,8 @@
from typing import List

from langchain.schema import (
AIMessage,
BaseChatMessageHistory,
BaseMessage,
HumanMessage,
_message_to_dict,
messages_from_dict,
)
Expand Down Expand Up @@ -55,13 +53,7 @@ def messages(self) -> List[BaseMessage]: # type: ignore
messages = messages_from_dict(items)
return messages

def add_user_message(self, message: str) -> None:
self.append(HumanMessage(content=message))

def add_ai_message(self, message: str) -> None:
self.append(AIMessage(content=message))

def append(self, message: BaseMessage) -> None:
def add_message(self, message: BaseMessage) -> None:
"""Append the message to the record in PostgreSQL"""
from psycopg import sql

Expand Down
10 changes: 1 addition & 9 deletions langchain/memory/chat_message_histories/redis.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,8 @@
from typing import List, Optional

from langchain.schema import (
AIMessage,
BaseChatMessageHistory,
BaseMessage,
HumanMessage,
_message_to_dict,
messages_from_dict,
)
Expand Down Expand Up @@ -52,13 +50,7 @@ def messages(self) -> List[BaseMessage]: # type: ignore
messages = messages_from_dict(items)
return messages

def add_user_message(self, message: str) -> None:
self.append(HumanMessage(content=message))

def add_ai_message(self, message: str) -> None:
self.append(AIMessage(content=message))

def append(self, message: BaseMessage) -> None:
def add_message(self, message: BaseMessage) -> None:
"""Append the message to the record in Redis"""
self.redis_client.lpush(self.key, json.dumps(_message_to_dict(message)))
if self.ttl:
Expand Down
10 changes: 1 addition & 9 deletions langchain/memory/chat_message_histories/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,8 @@
from sqlalchemy.orm import sessionmaker

from langchain.schema import (
AIMessage,
BaseChatMessageHistory,
BaseMessage,
HumanMessage,
_message_to_dict,
messages_from_dict,
)
Expand Down Expand Up @@ -61,13 +59,7 @@ def messages(self) -> List[BaseMessage]: # type: ignore
messages = messages_from_dict(items)
return messages

def add_user_message(self, message: str) -> None:
self.append(HumanMessage(content=message))

def add_ai_message(self, message: str) -> None:
self.append(AIMessage(content=message))

def append(self, message: BaseMessage) -> None:
def add_message(self, message: BaseMessage) -> None:
"""Append the message to the record in db"""
with self.Session() as session:
jsonstr = json.dumps(_message_to_dict(message))
Expand Down
8 changes: 1 addition & 7 deletions langchain/memory/chat_message_histories/zep.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,13 +116,7 @@ def _get_memory(self) -> Optional[Memory]:
return None
return zep_memory

def add_user_message(self, message: str) -> None:
self.append(HumanMessage(content=message))

def add_ai_message(self, message: str) -> None:
self.append(AIMessage(content=message))

def append(self, message: BaseMessage) -> None:
def add_message(self, message: BaseMessage) -> None:
"""Append the message to the Zep memory history"""
from zep_python import Memory, Message

Expand Down
21 changes: 9 additions & 12 deletions langchain/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,32 +234,29 @@ def messages(self):
messages = json.loads(f.read())
return messages_from_dict(messages)

def add_user_message(self, message: str):
message_ = HumanMessage(content=message)
messages = self.messages.append(_message_to_dict(_message))
def add_message(self, message: BaseMessage) -> None:
messages = self.messages.append(_message_to_dict(message))
with open(os.path.join(storage_path, session_id), 'w') as f:
json.dump(f, messages)

def add_ai_message(self, message: str):
message_ = AIMessage(content=message)
messages = self.messages.append(_message_to_dict(_message))
with open(os.path.join(storage_path, session_id), 'w') as f:
json.dump(f, messages)


def clear(self):
with open(os.path.join(storage_path, session_id), 'w') as f:
f.write("[]")
"""

messages: List[BaseMessage]

@abstractmethod
def add_user_message(self, message: str) -> None:
"""Add a user message to the store"""
self.add_message(HumanMessage(content=message))

@abstractmethod
def add_ai_message(self, message: str) -> None:
"""Add an AI message to the store"""
self.add_message(AIMessage(content=message))

def add_message(self, message: BaseMessage) -> None:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems reasonable to me. Looks more appropriate than add_ai_message or add_human_message for the lowest building block of chat history.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

went a step further now, got rid of all implementations of add_xx_message, instead using add_message in each subclass, and then noticed most had another method (usually append) with the same footprint, so got rid of those as well, much cleaner IMHO!

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i dont think this is super backwards compatible, can we raise NotImplementedError rather than have as abstract? that way if people subclassed it it doesnt break

"""Add a self-created message to the store"""
raise NotImplementedError

@abstractmethod
def clear(self) -> None:
Expand Down
2 changes: 1 addition & 1 deletion tests/unit_tests/memory/chat_message_histories/test_zep.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def test_add_ai_message(mocker: MockerFixture, zep_chat: ZepChatMessageHistory)

@pytest.mark.requires("zep_python")
def test_append(mocker: MockerFixture, zep_chat: ZepChatMessageHistory) -> None:
zep_chat.append(AIMessage(content="test message"))
zep_chat.add_message(AIMessage(content="test message"))
zep_chat.zep_client.add_memory.assert_called_once() # type: ignore


Expand Down