Skip to content

Commit

Permalink
feat(graph-service): add entity node handling and update Docker confi…
Browse files Browse the repository at this point in the history
…gurations (#100)

* feat: Add entity node request + service maintenance

* chore: Fix linter
  • Loading branch information
paul-paliychuk committed Sep 10, 2024
1 parent 3f12254 commit ad2962c
Show file tree
Hide file tree
Showing 10 changed files with 59 additions and 23 deletions.
5 changes: 3 additions & 2 deletions Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ COPY ./server /app

# Set environment variables
ENV PYTHONUNBUFFERED=1

ENV PORT=8000
# Command to run the application
CMD ["uvicorn", "graph_service.main:app", "--host", "0.0.0.0", "--port", "8000"]

CMD uvicorn graph_service.main:app --host 0.0.0.0 --port $PORT
3 changes: 2 additions & 1 deletion docker-compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ version: '3.8'

services:
graph:
build: .
image: zepai/graphiti:latest
ports:
- "8000:8000"

Expand All @@ -11,6 +11,7 @@ services:
- NEO4J_URI=bolt://neo4j:${NEO4J_PORT}
- NEO4J_USER=${NEO4J_USER}
- NEO4J_PASSWORD=${NEO4J_PASSWORD}
- PORT=8000
neo4j:
image: neo4j:5.22.0

Expand Down
Empty file added server/graph_service/app.py
Empty file.
3 changes: 2 additions & 1 deletion server/graph_service/dto/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from .common import Message, Result
from .ingest import AddMessagesRequest
from .ingest import AddEntityNodeRequest, AddMessagesRequest
from .retrieve import (
FactResult,
GetMemoryRequest,
Expand All @@ -12,6 +12,7 @@
'SearchQuery',
'Message',
'AddMessagesRequest',
'AddEntityNodeRequest',
'SearchResults',
'FactResult',
'Result',
Expand Down
3 changes: 2 additions & 1 deletion server/graph_service/dto/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,9 @@ class Result(BaseModel):

class Message(BaseModel):
content: str = Field(..., description='The content of the message')
uuid: str | None = Field(default=None, description='The uuid of the message (optional)')
name: str = Field(
default='', description='The name of the episodic node for the message (message uuid)'
default='', description='The name of the episodic node for the message (optional)'
)
role_type: Literal['user', 'assistant', 'system'] = Field(
..., description='The role type of the message (user, assistant or system)'
Expand Down
7 changes: 7 additions & 0 deletions server/graph_service/dto/ingest.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,10 @@
class AddMessagesRequest(BaseModel):
group_id: str = Field(..., description='The group id of the messages to add')
messages: list[Message] = Field(..., description='The messages to add')


class AddEntityNodeRequest(BaseModel):
uuid: str = Field(..., description='The uuid of the node to add')
group_id: str = Field(..., description='The group id of the node to add')
name: str = Field(..., description='The name of the node to add')
summary: str = Field(default='', description='The summary of the node to add')
12 changes: 7 additions & 5 deletions server/graph_service/dto/retrieve.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from datetime import datetime
from typing import Literal
from datetime import datetime, timezone

from pydantic import BaseModel, Field

Expand All @@ -10,9 +9,6 @@ class SearchQuery(BaseModel):
group_id: str = Field(..., description='The group id of the memory to get')
query: str
max_facts: int = Field(default=10, description='The maximum number of facts to retrieve')
search_type: Literal['facts', 'user_centered_facts'] = Field(
default='facts', description='The type of search to perform'
)


class FactResult(BaseModel):
Expand All @@ -24,6 +20,9 @@ class FactResult(BaseModel):
created_at: datetime
expired_at: datetime | None

class Config:
json_encoders = {datetime: lambda v: v.astimezone(timezone.utc).isoformat()}


class SearchResults(BaseModel):
facts: list[FactResult]
Expand All @@ -32,6 +31,9 @@ class SearchResults(BaseModel):
class GetMemoryRequest(BaseModel):
group_id: str = Field(..., description='The group id of the memory to get')
max_facts: int = Field(default=10, description='The maximum number of facts to retrieve')
center_node_uuid: str | None = Field(
..., description='The uuid of the node to center the retrieval on'
)
messages: list[Message] = Field(
..., description='The messages to build the retrieval query from '
)
Expand Down
16 changes: 15 additions & 1 deletion server/graph_service/routers/ingest.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from graphiti_core.nodes import EpisodeType # type: ignore
from graphiti_core.utils import clear_data # type: ignore

from graph_service.dto import AddMessagesRequest, Message, Result
from graph_service.dto import AddEntityNodeRequest, AddMessagesRequest, Message, Result
from graph_service.zep_graphiti import ZepGraphitiDep


Expand Down Expand Up @@ -69,6 +69,20 @@ async def add_messages_task(m: Message):
return Result(message='Messages added to processing queue', success=True)


@router.post('/entity-node', status_code=status.HTTP_201_CREATED)
async def add_entity_node(
request: AddEntityNodeRequest,
graphiti: ZepGraphitiDep,
):
node = await graphiti.save_entity_node(
uuid=request.uuid,
group_id=request.group_id,
name=request.name,
summary=request.summary,
)
return node


@router.post('/clear', status_code=status.HTTP_200_OK)
async def clear(
graphiti: ZepGraphitiDep,
Expand Down
16 changes: 10 additions & 6 deletions server/graph_service/routers/retrieve.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from datetime import datetime

from fastapi import APIRouter, status

from graph_service.dto import (
Expand All @@ -14,23 +16,25 @@

@router.post('/search', status_code=status.HTTP_200_OK)
async def search(query: SearchQuery, graphiti: ZepGraphitiDep):
center_node_uuid: str | None = None
if query.search_type == 'user_centered_facts':
user_node = await graphiti.get_user_node(query.group_id)
if user_node:
center_node_uuid = user_node.uuid
relevant_edges = await graphiti.search(
group_ids=[query.group_id],
query=query.query,
num_results=query.max_facts,
center_node_uuid=center_node_uuid,
)
facts = [get_fact_result_from_edge(edge) for edge in relevant_edges]
return SearchResults(
facts=facts,
)


@router.get('/episodes/{group_id}', status_code=status.HTTP_200_OK)
async def get_episodes(group_id: str, last_n: int, graphiti: ZepGraphitiDep):
episodes = await graphiti.retrieve_episodes(
group_ids=[group_id], last_n=last_n, reference_time=datetime.now()
)
return episodes


@router.post('/get-memory', status_code=status.HTTP_200_OK)
async def get_memory(
request: GetMemoryRequest,
Expand Down
17 changes: 11 additions & 6 deletions server/graph_service/zep_graphiti.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,21 +11,26 @@


class ZepGraphiti(Graphiti):
def __init__(
self, uri: str, user: str, password: str, user_id: str, llm_client: LLMClient | None = None
):
def __init__(self, uri: str, user: str, password: str, llm_client: LLMClient | None = None):
super().__init__(uri, user, password, llm_client)
self.user_id = user_id

async def get_user_node(self, user_id: str) -> EntityNode | None: ...
async def save_entity_node(self, name: str, uuid: str, group_id: str, summary: str = ''):
new_node = EntityNode(
name=name,
uuid=uuid,
group_id=group_id,
summary=summary,
)
await new_node.generate_name_embedding(self.llm_client.get_embedder())
await new_node.save(self.driver)
return new_node


async def get_graphiti(settings: ZepEnvDep):
client = ZepGraphiti(
uri=settings.neo4j_uri,
user=settings.neo4j_user,
password=settings.neo4j_password,
user_id='test1234',
)
try:
yield client
Expand Down

0 comments on commit ad2962c

Please sign in to comment.