From fdcb3d99ba2be42942d1d16bfcf2b6a4ab2bf1ed Mon Sep 17 00:00:00 2001 From: paulpaliychuk Date: Mon, 9 Sep 2024 21:27:42 -0400 Subject: [PATCH 1/2] feat: Add entity node request + service maintenance --- Dockerfile | 5 +++-- docker-compose.yml | 3 ++- server/graph_service/app.py | 0 server/graph_service/dto/__init__.py | 3 ++- server/graph_service/dto/common.py | 3 ++- server/graph_service/dto/ingest.py | 7 +++++++ server/graph_service/dto/retrieve.py | 12 +++++++----- server/graph_service/routers/ingest.py | 16 +++++++++++++++- server/graph_service/routers/retrieve.py | 16 ++++++++++------ server/graph_service/zep_graphiti.py | 17 +++++++++++------ 10 files changed, 59 insertions(+), 23 deletions(-) create mode 100644 server/graph_service/app.py diff --git a/Dockerfile b/Dockerfile index 536f541e..724e1222 100644 --- a/Dockerfile +++ b/Dockerfile @@ -37,6 +37,7 @@ COPY ./server /app # Set environment variables ENV PYTHONUNBUFFERED=1 - +ENV PORT=8000 # Command to run the application -CMD ["uvicorn", "graph_service.main:app", "--host", "0.0.0.0", "--port", "8000"] \ No newline at end of file + +CMD uvicorn graph_service.main:app --host 0.0.0.0 --port $PORT \ No newline at end of file diff --git a/docker-compose.yml b/docker-compose.yml index f12ff00a..04889e33 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -2,7 +2,7 @@ version: '3.8' services: graph: - build: . + image: zepai/graphiti:latest ports: - "8000:8000" @@ -11,6 +11,7 @@ services: - NEO4J_URI=bolt://neo4j:${NEO4J_PORT} - NEO4J_USER=${NEO4J_USER} - NEO4J_PASSWORD=${NEO4J_PASSWORD} + - PORT=8000 neo4j: image: neo4j:5.22.0 diff --git a/server/graph_service/app.py b/server/graph_service/app.py new file mode 100644 index 00000000..e69de29b diff --git a/server/graph_service/dto/__init__.py b/server/graph_service/dto/__init__.py index 381ae68b..284b26b1 100644 --- a/server/graph_service/dto/__init__.py +++ b/server/graph_service/dto/__init__.py @@ -1,5 +1,5 @@ from .common import Message, Result -from .ingest import AddMessagesRequest +from .ingest import AddEntityNodeRequest, AddMessagesRequest from .retrieve import ( FactResult, GetMemoryRequest, @@ -12,6 +12,7 @@ 'SearchQuery', 'Message', 'AddMessagesRequest', + 'AddEntityNodeRequest', 'SearchResults', 'FactResult', 'Result', diff --git a/server/graph_service/dto/common.py b/server/graph_service/dto/common.py index a1c379be..9d9e4b76 100644 --- a/server/graph_service/dto/common.py +++ b/server/graph_service/dto/common.py @@ -11,8 +11,9 @@ class Result(BaseModel): class Message(BaseModel): content: str = Field(..., description='The content of the message') + uuid: str | None = Field(default=None, description='The uuid of the message (optional)') name: str = Field( - default='', description='The name of the episodic node for the message (message uuid)' + default='', description='The name of the episodic node for the message (optional)' ) role_type: Literal['user', 'assistant', 'system'] = Field( ..., description='The role type of the message (user, assistant or system)' diff --git a/server/graph_service/dto/ingest.py b/server/graph_service/dto/ingest.py index 23a7d31c..ae6c1f80 100644 --- a/server/graph_service/dto/ingest.py +++ b/server/graph_service/dto/ingest.py @@ -6,3 +6,10 @@ class AddMessagesRequest(BaseModel): group_id: str = Field(..., description='The group id of the messages to add') messages: list[Message] = Field(..., description='The messages to add') + + +class AddEntityNodeRequest(BaseModel): + uuid: str = Field(..., description='The uuid of the node to add') + group_id: str = Field(..., description='The group id of the node to add') + name: str = Field(..., description='The name of the node to add') + summary: str | None = Field(None, description='The summary of the node to add') diff --git a/server/graph_service/dto/retrieve.py b/server/graph_service/dto/retrieve.py index 6cbecf04..5b02a3f3 100644 --- a/server/graph_service/dto/retrieve.py +++ b/server/graph_service/dto/retrieve.py @@ -1,5 +1,4 @@ -from datetime import datetime -from typing import Literal +from datetime import datetime, timezone from pydantic import BaseModel, Field @@ -10,9 +9,6 @@ class SearchQuery(BaseModel): group_id: str = Field(..., description='The group id of the memory to get') query: str max_facts: int = Field(default=10, description='The maximum number of facts to retrieve') - search_type: Literal['facts', 'user_centered_facts'] = Field( - default='facts', description='The type of search to perform' - ) class FactResult(BaseModel): @@ -24,6 +20,9 @@ class FactResult(BaseModel): created_at: datetime expired_at: datetime | None + class Config: + json_encoders = {datetime: lambda v: v.astimezone(timezone.utc).isoformat()} + class SearchResults(BaseModel): facts: list[FactResult] @@ -32,6 +31,9 @@ class SearchResults(BaseModel): class GetMemoryRequest(BaseModel): group_id: str = Field(..., description='The group id of the memory to get') max_facts: int = Field(default=10, description='The maximum number of facts to retrieve') + center_node_uuid: str | None = Field( + ..., description='The uuid of the node to center the retrieval on' + ) messages: list[Message] = Field( ..., description='The messages to build the retrieval query from ' ) diff --git a/server/graph_service/routers/ingest.py b/server/graph_service/routers/ingest.py index e98013c1..590f2eac 100644 --- a/server/graph_service/routers/ingest.py +++ b/server/graph_service/routers/ingest.py @@ -6,7 +6,7 @@ from graphiti_core.nodes import EpisodeType # type: ignore from graphiti_core.utils import clear_data # type: ignore -from graph_service.dto import AddMessagesRequest, Message, Result +from graph_service.dto import AddEntityNodeRequest, AddMessagesRequest, Message, Result from graph_service.zep_graphiti import ZepGraphitiDep @@ -69,6 +69,20 @@ async def add_messages_task(m: Message): return Result(message='Messages added to processing queue', success=True) +@router.post('/entity-node', status_code=status.HTTP_201_CREATED) +async def add_entity_node( + request: AddEntityNodeRequest, + graphiti: ZepGraphitiDep, +): + node = await graphiti.save_entity_node( + uuid=request.uuid, + group_id=request.group_id, + name=request.name, + summary=request.summary, + ) + return node + + @router.post('/clear', status_code=status.HTTP_200_OK) async def clear( graphiti: ZepGraphitiDep, diff --git a/server/graph_service/routers/retrieve.py b/server/graph_service/routers/retrieve.py index 8f449edf..4ee3a4ac 100644 --- a/server/graph_service/routers/retrieve.py +++ b/server/graph_service/routers/retrieve.py @@ -1,3 +1,5 @@ +from datetime import datetime + from fastapi import APIRouter, status from graph_service.dto import ( @@ -14,16 +16,10 @@ @router.post('/search', status_code=status.HTTP_200_OK) async def search(query: SearchQuery, graphiti: ZepGraphitiDep): - center_node_uuid: str | None = None - if query.search_type == 'user_centered_facts': - user_node = await graphiti.get_user_node(query.group_id) - if user_node: - center_node_uuid = user_node.uuid relevant_edges = await graphiti.search( group_ids=[query.group_id], query=query.query, num_results=query.max_facts, - center_node_uuid=center_node_uuid, ) facts = [get_fact_result_from_edge(edge) for edge in relevant_edges] return SearchResults( @@ -31,6 +27,14 @@ async def search(query: SearchQuery, graphiti: ZepGraphitiDep): ) +@router.get('/episodes/{group_id}', status_code=status.HTTP_200_OK) +async def get_episodes(group_id: str, last_n: int, graphiti: ZepGraphitiDep): + episodes = await graphiti.retrieve_episodes( + group_ids=[group_id], last_n=last_n, reference_time=datetime.now() + ) + return episodes + + @router.post('/get-memory', status_code=status.HTTP_200_OK) async def get_memory( request: GetMemoryRequest, diff --git a/server/graph_service/zep_graphiti.py b/server/graph_service/zep_graphiti.py index af1b8c0c..9b98b35a 100644 --- a/server/graph_service/zep_graphiti.py +++ b/server/graph_service/zep_graphiti.py @@ -11,13 +11,19 @@ class ZepGraphiti(Graphiti): - def __init__( - self, uri: str, user: str, password: str, user_id: str, llm_client: LLMClient | None = None - ): + def __init__(self, uri: str, user: str, password: str, llm_client: LLMClient | None = None): super().__init__(uri, user, password, llm_client) - self.user_id = user_id - async def get_user_node(self, user_id: str) -> EntityNode | None: ... + async def save_entity_node(self, name: str, uuid: str, group_id: str, summary: str = ''): + new_node = EntityNode( + name=name, + uuid=uuid, + group_id=group_id, + summary=summary, + ) + await new_node.generate_name_embedding(self.llm_client.get_embedder()) + await new_node.save(self.driver) + return new_node async def get_graphiti(settings: ZepEnvDep): @@ -25,7 +31,6 @@ async def get_graphiti(settings: ZepEnvDep): uri=settings.neo4j_uri, user=settings.neo4j_user, password=settings.neo4j_password, - user_id='test1234', ) try: yield client From a73f9443d67ae7749ce0ecef07f0bdd3d89dc4d4 Mon Sep 17 00:00:00 2001 From: paulpaliychuk Date: Mon, 9 Sep 2024 21:29:04 -0400 Subject: [PATCH 2/2] chore: Fix linter --- server/graph_service/dto/ingest.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/server/graph_service/dto/ingest.py b/server/graph_service/dto/ingest.py index ae6c1f80..9b0159c8 100644 --- a/server/graph_service/dto/ingest.py +++ b/server/graph_service/dto/ingest.py @@ -12,4 +12,4 @@ class AddEntityNodeRequest(BaseModel): uuid: str = Field(..., description='The uuid of the node to add') group_id: str = Field(..., description='The group id of the node to add') name: str = Field(..., description='The name of the node to add') - summary: str | None = Field(None, description='The summary of the node to add') + summary: str = Field(default='', description='The summary of the node to add')