-
Notifications
You must be signed in to change notification settings - Fork 15.7k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
3 changed files
with
358 additions
and
0 deletions.
There are no files selected for viewing
61 changes: 61 additions & 0 deletions
61
docs/docs/integrations/memory/falkordb_chat_message_history.ipynb
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"# FalkorDB\n", | ||
"\n", | ||
"<a href='https://docs.falkordb.com/' target='_blank'>FalkorDB</a> is an open-source graph database management system, renowned for its efficient management of highly connected data. Unlike traditional databases that store data in tables, FalkorDB uses a graph structure with nodes, edges, and properties to represent and store data. This design allows for high-performance queries on complex data relationships.\n", | ||
"\n", | ||
"This notebook goes over how to use `FalkorDB` to store chat message history\n", | ||
"\n", | ||
"**NOTE**: You can use FalkorDB locally or use FalkorDB Cloud. <a href='https://docs.falkordb.com/' target='blank'>See installation instructions</a>" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# For this example notebook we will be using FalkorDB locally\n", | ||
"host = \"localhost\"\n", | ||
"port = 6379" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"from langchain_community.chat_message_histories.falkordb import (\n", | ||
" FalkorDBChatMessageHistory,\n", | ||
")\n", | ||
"\n", | ||
"history = FalkorDBChatMessageHistory(host=host, port=port, session_id=\"session_id_1\")\n", | ||
"\n", | ||
"history.add_user_message(\"hi!\")\n", | ||
"\n", | ||
"history.add_ai_message(\"whats up?\")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"history.messages" | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"language_info": { | ||
"name": "python" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 2 | ||
} |
220 changes: 220 additions & 0 deletions
220
libs/community/langchain_community/chat_message_histories/falkordb.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,220 @@ | ||
import os | ||
from typing import List, Optional, Union | ||
|
||
from langchain_core.chat_history import BaseChatMessageHistory | ||
from langchain_core.messages import ( | ||
AIMessage, | ||
BaseMessage, | ||
HumanMessage, | ||
) | ||
|
||
from langchain_community.graphs import FalkorDBGraph | ||
|
||
|
||
class FalkorDBChatMessageHistory(BaseChatMessageHistory): | ||
"""Chat message history stored in a Falkor database. | ||
This class handles storing and retrieving chat messages in a FalkorDB database. | ||
It creates a session and stores messages in a message chain, maintaining a link | ||
between subsequent messages. | ||
Args: | ||
session_id (Union[str, int]): The session ID for storing and retrieving messages | ||
also the name of the database. | ||
username (Optional[str]): Username for authenticating with FalkorDB. | ||
password (Optional[str]): Password for authenticating with FalkorDB. | ||
host (str): Host where the FalkorDB is running. Defaults to 'localhost'. | ||
port (int): Port number where the FalkorDB is running. Defaults to 6379. | ||
node_label (str): Label for the session node | ||
in the graph. Defaults to "Session". | ||
window (int): The number of messages to retrieve when querying | ||
the history. Defaults to 3. | ||
ssl (bool): Whether to use SSL for connecting | ||
to the database. Defaults to False. | ||
graph (Optional[FalkorDBGraph]): Optionally provide an existing | ||
FalkorDBGraph object for connecting. | ||
Example: | ||
.. code-block:: python | ||
from langchain_community.chat_message_histories import ( | ||
FalkorDBChatMessageHistory | ||
) | ||
history = FalkorDBChatMessageHistory( | ||
session_id="1234", | ||
host="localhost", | ||
port=6379, | ||
) | ||
history.add_message(HumanMessage(content="Hello!")) | ||
""" | ||
|
||
def __init__( | ||
self, | ||
session_id: Union[str, int], | ||
username: Optional[str] = None, | ||
password: Optional[str] = None, | ||
host: str = "localhost", | ||
port: int = 6379, | ||
node_label: str = "Session", | ||
window: int = 3, | ||
ssl: bool = False, | ||
*, | ||
graph: Optional[FalkorDBGraph] = None, | ||
) -> None: | ||
""" | ||
Initialize the FalkorDBChatMessageHistory | ||
class with the session and connection details. | ||
""" | ||
try: | ||
import falkordb | ||
except ImportError: | ||
raise ImportError( | ||
"Could not import falkordb python package." | ||
"Please install it with `pip install falkordb`." | ||
) | ||
|
||
if not session_id: | ||
raise ValueError("Please ensure that the session_id parameter is provided.") | ||
|
||
if graph: | ||
self._database = graph._graph | ||
self._driver = graph._driver | ||
else: | ||
self._host = host | ||
self._port = port | ||
self._username = username or os.environ.get("FALKORDB_USERNAME") | ||
self._password = password or os.environ.get("FALKORDB_PASSWORD") | ||
self._ssl = ssl | ||
|
||
try: | ||
self._driver = falkordb.FalkorDB( | ||
host=self._host, | ||
port=self._port, | ||
username=self._username, | ||
password=self._password, | ||
ssl=self._ssl, | ||
) | ||
except Exception as e: | ||
raise ValueError( | ||
f"Error: {e}" | ||
"Could not connect to FalkorDB database. " | ||
"Please ensure that the host, port, username, and password are correct." | ||
Check failure on line 101 in libs/community/langchain_community/chat_message_histories/falkordb.py GitHub Actions / cd libs/community / make lint #3.13Ruff (E501)
|
||
) | ||
|
||
self._database = self._driver.select_graph(session_id) | ||
self._session_id = session_id | ||
self._node_label = node_label | ||
self._window = window | ||
|
||
self._database.query( | ||
f"MERGE (s:{self._node_label} {{id:$session_id}})", | ||
{"session_id": self._session_id}, | ||
) | ||
|
||
try: | ||
self._database.create_node_vector_index( | ||
f"{self._node_label}", "id", dim=5, similarity_function="cosine" | ||
) | ||
except Exception as e: | ||
if "already indexed" in str(e): | ||
raise ValueError(f"{self._node_label} has already been indexed") | ||
|
||
def _process_records(self, records: list) -> List[BaseMessage]: | ||
"""Process the records from FalkorDB and convert them into BaseMessage objects. | ||
Args: | ||
records (list): The raw records fetched from the FalkorDB query. | ||
Returns: | ||
List[BaseMessage]: A list of `BaseMessage` objects. | ||
""" | ||
# Explicitly set messages as a list of BaseMessage | ||
messages: List[BaseMessage] = [] | ||
|
||
for record in records: | ||
content = record[0].get("data", {}).get("content", "") | ||
message_type = record[0].get("type", "").lower() | ||
|
||
# Append the correct message type to the list | ||
if message_type == "human": | ||
messages.append( | ||
HumanMessage( | ||
content=content, additional_kwargs={}, response_metadata={} | ||
) | ||
) | ||
elif message_type == "ai": | ||
messages.append( | ||
AIMessage( | ||
content=content, additional_kwargs={}, response_metadata={} | ||
) | ||
) | ||
else: | ||
raise ValueError(f"Unknown message type: {message_type}") | ||
|
||
return messages | ||
|
||
@property | ||
def messages(self) -> List[BaseMessage]: | ||
"""Retrieve the messages from FalkorDB for the session. | ||
Returns: | ||
List[BaseMessage]: A list of messages in the current session. | ||
""" | ||
query = ( | ||
f"MATCH (s:{self._node_label})-[:LAST_MESSAGE]->(last_message) " | ||
"MATCH p=(last_message)<-[:NEXT*0.." | ||
f"{self._window*2}]-() WITH p, length(p) AS length " | ||
"ORDER BY length DESC LIMIT 1 UNWIND reverse(nodes(p)) AS node " | ||
"RETURN {data:{content: node.content}, type:node.type} AS result" | ||
) | ||
|
||
records = self._database.query(query).result_set | ||
|
||
messages = self._process_records(records) | ||
return messages | ||
|
||
@messages.setter | ||
def messages(self, messages: List[BaseMessage]) -> None: | ||
"""Block direct assignment to 'messages' to prevent misuse.""" | ||
raise NotImplementedError( | ||
"Direct assignment to 'messages' is not allowed." | ||
" Use the 'add_message' method instead." | ||
) | ||
|
||
def add_message(self, message: BaseMessage) -> None: | ||
"""Append a message to the session in FalkorDB. | ||
Args: | ||
message (BaseMessage): The message object to add to the session. | ||
""" | ||
create_query = ( | ||
f"MATCH (s:{self._node_label}) " | ||
"CREATE (new:Message {type: $type, content: $content}) " | ||
"WITH s, new " | ||
"OPTIONAL MATCH (s)-[lm:LAST_MESSAGE]->(last_message:Message) " | ||
"FOREACH (_ IN CASE WHEN last_message IS NULL THEN [] ELSE [1] END | " | ||
" MERGE (last_message)-[:NEXT]->(new)) " | ||
"MERGE (s)-[:LAST_MESSAGE]->(new) " | ||
) | ||
|
||
self._database.query( | ||
create_query, | ||
{ | ||
"type": message.type, | ||
"content": message.content, | ||
}, | ||
) | ||
|
||
def clear(self) -> None: | ||
"""Clear all messages from the session in FalkorDB. | ||
Deletes all messages linked to the session and resets the message history. | ||
Raises: | ||
ValueError: If there is an issue with the query or the session. | ||
""" | ||
query = ( | ||
f"MATCH (s:{self._node_label})-[:LAST_MESSAGE|NEXT*0..]->(m:Message) " | ||
"WITH m DELETE m" | ||
) | ||
self._database.query(query) |
77 changes: 77 additions & 0 deletions
77
...nity/tests/integration_tests/chat_message_histories/test_falkordb_chat_message_history.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,77 @@ | ||
""" | ||
Integration tests for FalkorDB Chat History/Memory functionality. | ||
Note: | ||
These tests are conducted using a local FalkorDB instance but can also | ||
be run against a Cloud FalkorDB instance. Ensure that appropriate host,port | ||
cusername, and password configurations are set up | ||
before running the tests. | ||
Test Cases: | ||
1. test_add_messages: Test basic functionality of adding and retrieving | ||
chat messages from FalkorDB. | ||
2. test_add_messages_graph_object: Test chat message functionality | ||
when passing the FalkorDB driver through a graph object. | ||
""" | ||
|
||
from langchain_core.messages import AIMessage, HumanMessage | ||
|
||
from langchain_community.chat_message_histories.falkordb import ( | ||
FalkorDBChatMessageHistory, | ||
) | ||
from langchain_community.graphs import FalkorDBGraph | ||
|
||
|
||
def test_add_messages() -> None: | ||
"""Basic testing: add messages to the FalkorDBChatMessageHistory.""" | ||
message_store = FalkorDBChatMessageHistory("500daysofSadiya") | ||
message_store.clear() | ||
assert len(message_store.messages) == 0 | ||
message_store.add_user_message("Hello! Language Chain!") | ||
message_store.add_ai_message("Hi Guys!") | ||
|
||
# create another message store to check if the messages are stored correctly | ||
message_store_another = FalkorDBChatMessageHistory("Shebrokemyheart") | ||
message_store_another.clear() | ||
assert len(message_store_another.messages) == 0 | ||
message_store_another.add_user_message("Hello! Bot!") | ||
message_store_another.add_ai_message("Hi there!") | ||
message_store_another.add_user_message("How's this pr going?") | ||
|
||
# Now check if the messages are stored in the database correctly | ||
assert len(message_store.messages) == 2 | ||
assert isinstance(message_store.messages[0], HumanMessage) | ||
assert isinstance(message_store.messages[1], AIMessage) | ||
assert message_store.messages[0].content == "Hello! Language Chain!" | ||
assert message_store.messages[1].content == "Hi Guys!" | ||
|
||
assert len(message_store_another.messages) == 3 | ||
assert isinstance(message_store_another.messages[0], HumanMessage) | ||
assert isinstance(message_store_another.messages[1], AIMessage) | ||
assert isinstance(message_store_another.messages[2], HumanMessage) | ||
assert message_store_another.messages[0].content == "Hello! Bot!" | ||
assert message_store_another.messages[1].content == "Hi there!" | ||
assert message_store_another.messages[2].content == "How's this pr going?" | ||
|
||
# Now clear the first history | ||
message_store.clear() | ||
assert len(message_store.messages) == 0 | ||
assert len(message_store_another.messages) == 3 | ||
message_store_another.clear() | ||
assert len(message_store.messages) == 0 | ||
assert len(message_store_another.messages) == 0 | ||
|
||
|
||
def test_add_messages_graph_object() -> None: | ||
"""Basic testing: Passing driver through graph object.""" | ||
graph = FalkorDBGraph("NeverGonnaLetYouDownNevergonnagiveyouup") | ||
message_store = FalkorDBChatMessageHistory( | ||
"Gonnahavetoteachmehowtoloveyouagain", graph=graph | ||
) | ||
message_store.clear() | ||
assert len(message_store.messages) == 0 | ||
message_store.add_user_message("Hello! Language Chain!") | ||
message_store.add_ai_message("Hi Guys!") | ||
# Now check if the messages are stored in the database correctly | ||
assert len(message_store.messages) == 2 |