Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(graph-service): add entity node handling and update Docker configurations #100

Merged
merged 3 commits into from
Sep 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ COPY ./server /app

# Set environment variables
ENV PYTHONUNBUFFERED=1

ENV PORT=8000
# Command to run the application
CMD ["uvicorn", "graph_service.main:app", "--host", "0.0.0.0", "--port", "8000"]

CMD uvicorn graph_service.main:app --host 0.0.0.0 --port $PORT
3 changes: 2 additions & 1 deletion docker-compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ version: '3.8'

services:
graph:
build: .
image: zepai/graphiti:latest
ports:
- "8000:8000"

Expand All @@ -11,6 +11,7 @@ services:
- NEO4J_URI=bolt://neo4j:${NEO4J_PORT}
- NEO4J_USER=${NEO4J_USER}
- NEO4J_PASSWORD=${NEO4J_PASSWORD}
- PORT=8000
neo4j:
image: neo4j:5.22.0

Expand Down
Empty file added server/graph_service/app.py
Empty file.
3 changes: 2 additions & 1 deletion server/graph_service/dto/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from .common import Message, Result
from .ingest import AddMessagesRequest
from .ingest import AddEntityNodeRequest, AddMessagesRequest
from .retrieve import (
FactResult,
GetMemoryRequest,
Expand All @@ -12,6 +12,7 @@
'SearchQuery',
'Message',
'AddMessagesRequest',
'AddEntityNodeRequest',
'SearchResults',
'FactResult',
'Result',
Expand Down
3 changes: 2 additions & 1 deletion server/graph_service/dto/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,9 @@ class Result(BaseModel):

class Message(BaseModel):
content: str = Field(..., description='The content of the message')
uuid: str | None = Field(default=None, description='The uuid of the message (optional)')
name: str = Field(
default='', description='The name of the episodic node for the message (message uuid)'
default='', description='The name of the episodic node for the message (optional)'
)
role_type: Literal['user', 'assistant', 'system'] = Field(
..., description='The role type of the message (user, assistant or system)'
Expand Down
7 changes: 7 additions & 0 deletions server/graph_service/dto/ingest.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,10 @@
class AddMessagesRequest(BaseModel):
group_id: str = Field(..., description='The group id of the messages to add')
messages: list[Message] = Field(..., description='The messages to add')


class AddEntityNodeRequest(BaseModel):
uuid: str = Field(..., description='The uuid of the node to add')
group_id: str = Field(..., description='The group id of the node to add')
name: str = Field(..., description='The name of the node to add')
summary: str = Field(default='', description='The summary of the node to add')
12 changes: 7 additions & 5 deletions server/graph_service/dto/retrieve.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from datetime import datetime
from typing import Literal
from datetime import datetime, timezone

from pydantic import BaseModel, Field

Expand All @@ -10,9 +9,6 @@ class SearchQuery(BaseModel):
group_id: str = Field(..., description='The group id of the memory to get')
query: str
max_facts: int = Field(default=10, description='The maximum number of facts to retrieve')
search_type: Literal['facts', 'user_centered_facts'] = Field(
default='facts', description='The type of search to perform'
)


class FactResult(BaseModel):
Expand All @@ -24,6 +20,9 @@ class FactResult(BaseModel):
created_at: datetime
expired_at: datetime | None

class Config:
json_encoders = {datetime: lambda v: v.astimezone(timezone.utc).isoformat()}


class SearchResults(BaseModel):
facts: list[FactResult]
Expand All @@ -32,6 +31,9 @@ class SearchResults(BaseModel):
class GetMemoryRequest(BaseModel):
group_id: str = Field(..., description='The group id of the memory to get')
max_facts: int = Field(default=10, description='The maximum number of facts to retrieve')
center_node_uuid: str | None = Field(
..., description='The uuid of the node to center the retrieval on'
)
messages: list[Message] = Field(
..., description='The messages to build the retrieval query from '
)
Expand Down
16 changes: 15 additions & 1 deletion server/graph_service/routers/ingest.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from graphiti_core.nodes import EpisodeType # type: ignore
from graphiti_core.utils import clear_data # type: ignore

from graph_service.dto import AddMessagesRequest, Message, Result
from graph_service.dto import AddEntityNodeRequest, AddMessagesRequest, Message, Result
from graph_service.zep_graphiti import ZepGraphitiDep


Expand Down Expand Up @@ -69,6 +69,20 @@ async def add_messages_task(m: Message):
return Result(message='Messages added to processing queue', success=True)


@router.post('/entity-node', status_code=status.HTTP_201_CREATED)
async def add_entity_node(
request: AddEntityNodeRequest,
graphiti: ZepGraphitiDep,
):
node = await graphiti.save_entity_node(
uuid=request.uuid,
group_id=request.group_id,
name=request.name,
summary=request.summary,
)
return node


@router.post('/clear', status_code=status.HTTP_200_OK)
async def clear(
graphiti: ZepGraphitiDep,
Expand Down
16 changes: 10 additions & 6 deletions server/graph_service/routers/retrieve.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from datetime import datetime

from fastapi import APIRouter, status

from graph_service.dto import (
Expand All @@ -14,23 +16,25 @@

@router.post('/search', status_code=status.HTTP_200_OK)
async def search(query: SearchQuery, graphiti: ZepGraphitiDep):
center_node_uuid: str | None = None
if query.search_type == 'user_centered_facts':
user_node = await graphiti.get_user_node(query.group_id)
if user_node:
center_node_uuid = user_node.uuid
relevant_edges = await graphiti.search(
group_ids=[query.group_id],
query=query.query,
num_results=query.max_facts,
center_node_uuid=center_node_uuid,
)
facts = [get_fact_result_from_edge(edge) for edge in relevant_edges]
return SearchResults(
facts=facts,
)


@router.get('/episodes/{group_id}', status_code=status.HTTP_200_OK)
async def get_episodes(group_id: str, last_n: int, graphiti: ZepGraphitiDep):
episodes = await graphiti.retrieve_episodes(
group_ids=[group_id], last_n=last_n, reference_time=datetime.now()
)
return episodes


@router.post('/get-memory', status_code=status.HTTP_200_OK)
async def get_memory(
request: GetMemoryRequest,
Expand Down
17 changes: 11 additions & 6 deletions server/graph_service/zep_graphiti.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,21 +11,26 @@


class ZepGraphiti(Graphiti):
def __init__(
self, uri: str, user: str, password: str, user_id: str, llm_client: LLMClient | None = None
):
def __init__(self, uri: str, user: str, password: str, llm_client: LLMClient | None = None):
super().__init__(uri, user, password, llm_client)
self.user_id = user_id

async def get_user_node(self, user_id: str) -> EntityNode | None: ...
async def save_entity_node(self, name: str, uuid: str, group_id: str, summary: str = ''):
new_node = EntityNode(
name=name,
uuid=uuid,
group_id=group_id,
summary=summary,
)
await new_node.generate_name_embedding(self.llm_client.get_embedder())
await new_node.save(self.driver)
return new_node


async def get_graphiti(settings: ZepEnvDep):
client = ZepGraphiti(
uri=settings.neo4j_uri,
user=settings.neo4j_user,
password=settings.neo4j_password,
user_id='test1234',
)
try:
yield client
Expand Down