Update Maintenance LLM Queries and Partial Schema Retrieval (#6)
* search updates

* add search_utils

* updates

* graph maintenance updates

* revert extract_new_nodes

* revert extract_new_edges

* parallelize node searching

* add edge fulltext search

* search optimizations
prasmussen15 committed Aug 18, 2024
1 parent ad552b5 commit 4db3906
Showing 16 changed files with 953 additions and 119 deletions.
26 changes: 11 additions & 15 deletions core/edges.py
@@ -5,15 +5,16 @@
from uuid import uuid4
import logging

from core.llm_client.config import EMBEDDING_DIM
from core.nodes import Node

logger = logging.getLogger(__name__)


class Edge(BaseModel, ABC):
uuid: str = Field(default_factory=lambda: str(uuid4()))
source_node: Node
target_node: Node
uuid: str = Field(default_factory=lambda: uuid4().hex)
source_node_uuid: str
target_node_uuid: str
created_at: datetime

@abstractmethod
@@ -30,20 +31,15 @@ def __eq__(self, other):

class EpisodicEdge(Edge):
async def save(self, driver: AsyncDriver):
if self.uuid is None:
uuid = uuid4()
logger.info(f"Created uuid: {uuid} for episodic edge")
self.uuid = str(uuid)

result = await driver.execute_query(
"""
MATCH (episode:Episodic {uuid: $episode_uuid})
MATCH (node:Entity {uuid: $entity_uuid})
MERGE (episode)-[r:MENTIONS {uuid: $uuid}]->(node)
SET r = {uuid: $uuid, created_at: $created_at}
RETURN r.uuid AS uuid""",
episode_uuid=self.source_node.uuid,
entity_uuid=self.target_node.uuid,
episode_uuid=self.source_node_uuid,
entity_uuid=self.target_node_uuid,
uuid=self.uuid,
created_at=self.created_at,
)
@@ -79,10 +75,10 @@ class EntityEdge(Edge):
default=None, description="datetime of when the fact stopped being true"
)

def generate_embedding(self, embedder, model="text-embedding-3-large"):
async def generate_embedding(self, embedder, model="text-embedding-3-small"):
text = self.fact.replace("\n", " ")
embedding = embedder.create(input=[text], model=model).data[0].embedding
self.fact_embedding = embedding
embedding = (await embedder.create(input=[text], model=model)).data[0].embedding
self.fact_embedding = embedding[:EMBEDDING_DIM]

return embedding

@@ -96,8 +92,8 @@ async def save(self, driver: AsyncDriver):
episodes: $episodes, created_at: $created_at, expired_at: $expired_at,
valid_at: $valid_at, invalid_at: $invalid_at}
RETURN r.uuid AS uuid""",
source_uuid=self.source_node.uuid,
target_uuid=self.target_node.uuid,
source_uuid=self.source_node_uuid,
target_uuid=self.target_node_uuid,
uuid=self.uuid,
name=self.name,
fact=self.fact,
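For reference, a minimal standalone sketch (not part of this commit) of the async embedding pattern the reworked EntityEdge.generate_embedding now follows: the OpenAI call is awaited and the resulting vector is truncated to EMBEDDING_DIM before it is stored on the edge. The AsyncOpenAI client and the example fact string are assumptions for illustration.

import asyncio
from openai import AsyncOpenAI  # assumed async embedder; any client with an awaitable .create works

from core.llm_client.config import EMBEDDING_DIM

async def embed_fact(fact: str) -> list[float]:
    # Mirrors EntityEdge.generate_embedding: newline-stripped text, awaited call,
    # and truncation of the returned vector to the first EMBEDDING_DIM floats.
    embedder = AsyncOpenAI().embeddings
    text = fact.replace("\n", " ")
    response = await embedder.create(input=[text], model="text-embedding-3-small")
    return response.data[0].embedding[:EMBEDDING_DIM]

print(len(asyncio.run(embed_fact("Alice works at Acme Corp"))))  # 1024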
147 changes: 105 additions & 42 deletions core/graphiti.py
@@ -1,12 +1,14 @@
import asyncio
from datetime import datetime
import logging
from typing import Callable, LiteralString, Tuple
from typing import Callable, LiteralString
from neo4j import AsyncGraphDatabase
from dotenv import load_dotenv
import os

from core.llm_client.config import EMBEDDING_DIM
from core.nodes import EntityNode, EpisodicNode, Node
from core.edges import EntityEdge, Edge
from core.edges import EntityEdge, Edge, EpisodicEdge
from core.utils import (
build_episodic_edges,
retrieve_relevant_schema,
@@ -16,6 +18,15 @@
retrieve_episodes,
)
from core.llm_client import LLMClient, OpenAIClient, LLMConfig
from core.utils.maintenance.edge_operations import extract_edges, dedupe_extracted_edges
from core.utils.maintenance.node_operations import dedupe_extracted_nodes, extract_nodes
from core.utils.search.search_utils import (
edge_similarity_search,
entity_fulltext_search,
bfs,
get_relevant_nodes,
get_relevant_edges,
)

logger = logging.getLogger(__name__)

@@ -34,7 +45,7 @@ def __init__(
self.llm_client = OpenAIClient(
LLMConfig(
api_key=os.getenv("OPENAI_API_KEY"),
model="gpt-4o",
model="gpt-4o-mini",
base_url="https://api.openai.com/v1",
)
)
@@ -75,47 +86,78 @@ async def add_episode(
):
"""Process an episode and update the graph"""
try:
nodes: list[Node] = []
edges: list[Edge] = []
nodes: list[EntityNode] = []
entity_edges: list[EntityEdge] = []
episodic_edges: list[EpisodicEdge] = []
embedder = self.llm_client.client.embeddings
now = datetime.now()

previous_episodes = await self.retrieve_episodes(last_n=3)
episode = EpisodicNode(
name=name,
labels=[],
source="messages",
content=episode_body,
source_description=source_description,
created_at=datetime.now(),
created_at=now,
valid_at=reference_time,
)
# await episode.save(self.driver)
relevant_schema = await self.retrieve_relevant_schema(episode.content)
new_nodes = await extract_new_nodes(
self.llm_client, episode, relevant_schema, previous_episodes
# relevant_schema = await self.retrieve_relevant_schema(episode.content)

extracted_nodes = await extract_nodes(
self.llm_client, episode, previous_episodes
)

# Calculate Embeddings

await asyncio.gather(
*[node.generate_name_embedding(embedder) for node in extracted_nodes]
)

existing_nodes = await get_relevant_nodes(extracted_nodes, self.driver)

new_nodes = await dedupe_extracted_nodes(
self.llm_client, extracted_nodes, existing_nodes
)
nodes.extend(new_nodes)
new_edges, affected_nodes = await extract_new_edges(
self.llm_client, episode, new_nodes, relevant_schema, previous_episodes

extracted_edges = await extract_edges(
self.llm_client, episode, new_nodes, previous_episodes
)
edges.extend(new_edges)
episodic_edges = build_episodic_edges(
# There may be an overlap between new_nodes and affected_nodes, so we're deduplicating them
list(set(nodes + affected_nodes)),
episode,
datetime.now(),

await asyncio.gather(
*[edge.generate_embedding(embedder) for edge in extracted_edges]
)

existing_edges = await get_relevant_edges(extracted_edges, self.driver)

new_edges = await dedupe_extracted_edges(
self.llm_client, extracted_edges, existing_edges
)

entity_edges.extend(new_edges)
episodic_edges.extend(
build_episodic_edges(
# There may be an overlap between new_nodes and affected_nodes, so we're deduplicating them
nodes,
episode,
now,
)
)
# Important to append the episode to the nodes at the end so that self referencing episodic edges are not built
nodes.append(episode)
logger.info(f"Built episodic edges: {episodic_edges}")
edges.extend(episodic_edges)

# invalidated_edges = await self.invalidate_edges(
# episode, new_nodes, new_edges, relevant_schema, previous_episodes
# )

# edges.extend(invalidated_edges)

# Future optimization would be using batch operations to save nodes and edges
await episode.save(self.driver)
await asyncio.gather(*[node.save(self.driver) for node in nodes])
await asyncio.gather(*[edge.save(self.driver) for edge in edges])
await asyncio.gather(*[edge.save(self.driver) for edge in episodic_edges])
await asyncio.gather(*[edge.save(self.driver) for edge in entity_edges])
# for node in nodes:
# if isinstance(node, EntityNode):
# await node.update_summary(self.driver)
@@ -129,6 +171,10 @@ async def add_episode(

async def build_indices(self):
index_queries: list[LiteralString] = [
"CREATE INDEX entity_uuid IF NOT EXISTS FOR (n:Entity) ON (n.uuid)",
"CREATE INDEX episode_uuid IF NOT EXISTS FOR (n:Episodic) ON (n.uuid)",
"CREATE INDEX relation_uuid IF NOT EXISTS FOR ()-[r:RELATES_TO]-() ON (r.uuid)",
"CREATE INDEX mention_uuid IF NOT EXISTS FOR ()-[r:MENTIONS]-() ON (r.uuid)",
"CREATE INDEX name_entity_index IF NOT EXISTS FOR (n:Entity) ON (n.name)",
"CREATE INDEX created_at_entity_index IF NOT EXISTS FOR (n:Entity) ON (n.created_at)",
"CREATE INDEX created_at_episodic_index IF NOT EXISTS FOR (n:Episodic) ON (n.created_at)",
@@ -143,13 +189,19 @@ async def build_indices(self):
for query in index_queries:
await self.driver.execute_query(query)

# Add the entity indices
# Add the semantic indices
await self.driver.execute_query(
"""
CREATE FULLTEXT INDEX name_and_summary IF NOT EXISTS FOR (n:Entity) ON EACH [n.name, n.summary]
"""
)

await self.driver.execute_query(
"""
CREATE FULLTEXT INDEX name_and_fact IF NOT EXISTS FOR ()-[r:RELATES_TO]-() ON EACH [r.name, r.fact]
"""
)

await self.driver.execute_query(
"""
CREATE VECTOR INDEX fact_embedding IF NOT EXISTS
@@ -161,29 +213,40 @@
"""
)

async def search(
self, query: str, config
) -> (list)[tuple[EntityNode, list[EntityEdge]]]:
(vec_nodes, vec_edges) = similarity_search(query, embedder)
(text_nodes, text_edges) = fulltext_search(query)
await self.driver.execute_query(
"""
CREATE VECTOR INDEX name_embedding IF NOT EXISTS
FOR (n:Entity) ON (n.name_embedding)
OPTIONS {indexConfig: {
`vector.dimensions`: 1024,
`vector.similarity_function`: 'cosine'
}}
"""
)

async def search(self, query: str) -> list[tuple[EntityNode, list[EntityEdge]]]:
text = query.replace("\n", " ")
search_vector = (
(
await self.llm_client.client.embeddings.create(
input=[text], model="text-embedding-3-small"
)
)
.data[0]
.embedding[:EMBEDDING_DIM]
)

nodes = vec_nodes.extend(text_nodes)
edges = vec_edges.extend(text_edges)
edges = await edge_similarity_search(search_vector, self.driver)
nodes = await entity_fulltext_search(query, self.driver)

results = bfs(nodes, edges, k=1)
node_ids = [node.uuid for node in nodes]

episode_ids = ["Mode of episode ids"]
for edge in edges:
node_ids.append(edge.source_node_uuid)
node_ids.append(edge.target_node_uuid)

episodes = get_episodes(episode_ids[:episode_count])
node_ids = list(dict.fromkeys(node_ids))

return [(node, edges)], episodes
context = await bfs(node_ids, self.driver)

# Invalidate edges that are no longer valid
async def invalidate_edges(
self,
episode: EpisodicNode,
new_nodes: list[EntityNode],
new_edges: list[EntityEdge],
relevant_schema: dict[str, any],
previous_episodes: list[EpisodicNode],
): ...
return context
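
A hedged end-to-end sketch of the updated flow in this file. The Graphiti constructor arguments and the exact add_episode parameter names are inferred from the surrounding diff rather than shown in full, so treat them as assumptions.

import asyncio
from datetime import datetime

from core.graphiti import Graphiti

async def main():
    # Constructor arguments are assumed (Neo4j URI and credentials); the full
    # __init__ signature is not visible in this diff.
    client = Graphiti("bolt://localhost:7687", "neo4j", "password")

    await client.build_indices()  # range, fulltext, and vector indices

    # Extract nodes/edges, embed them, dedupe against the graph, then persist.
    await client.add_episode(
        name="conversation-1",
        episode_body="Alice told Bob she has moved to Berlin.",
        source_description="chat transcript",
        reference_time=datetime.now(),
    )

    # New search path: embed the query, run edge similarity search and entity
    # fulltext search, collect node uuids, then expand the neighborhood via bfs.
    context = await client.search("Where does Alice live?")
    print(context)

asyncio.run(main())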
7 changes: 5 additions & 2 deletions core/llm_client/config.py
@@ -1,3 +1,6 @@
EMBEDDING_DIM = 1024


class LLMConfig:
"""
Configuration class for the Language Learning Model (LLM).
@@ -10,7 +13,7 @@ class LLMConfig:
def __init__(
self,
api_key: str,
model: str = "gpt-4o",
model: str = "gpt-4o-mini",
base_url: str = "https://api.openai.com",
):
"""
@@ -21,7 +24,7 @@ def __init__(
This is required for making authorized requests.
model (str, optional): The specific LLM model to use for generating responses.
Defaults to "gpt-4o", which appears to be a custom model name.
Defaults to "gpt-4o-mini", which appears to be a custom model name.
Common values might include "gpt-3.5-turbo" or "gpt-4".
base_url (str, optional): The base URL of the LLM API service.
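A short usage sketch of the updated defaults; the client wiring mirrors the constructor shown in core/graphiti.py above.

import os

from core.llm_client import LLMConfig, OpenAIClient

config = LLMConfig(
    api_key=os.getenv("OPENAI_API_KEY"),
    model="gpt-4o-mini",                  # new default model
    base_url="https://api.openai.com/v1",
)
llm_client = OpenAIClient(config)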
17 changes: 15 additions & 2 deletions core/nodes.py
@@ -8,11 +8,13 @@
from neo4j import AsyncDriver
import logging

from core.llm_client.config import EMBEDDING_DIM

logger = logging.getLogger(__name__)


class Node(BaseModel, ABC):
uuid: str = Field(default_factory=lambda: str(uuid4()))
uuid: str = Field(default_factory=lambda: uuid4().hex)
name: str
labels: list[str] = Field(default_factory=list)
created_at: datetime
@@ -66,21 +68,32 @@ async def save(self, driver: AsyncDriver):


class EntityNode(Node):
name_embedding: list[float] | None = Field(
default=None, description="embedding of the name"
)
summary: str = Field(description="regional summary of surrounding edges")

async def update_summary(self, driver: AsyncDriver): ...

async def refresh_summary(self, driver: AsyncDriver, llm_client: OpenAI): ...

async def generate_name_embedding(self, embedder, model="text-embedding-3-small"):
text = self.name.replace("\n", " ")
embedding = (await embedder.create(input=[text], model=model)).data[0].embedding
self.name_embedding = embedding[:EMBEDDING_DIM]

return embedding

async def save(self, driver: AsyncDriver):
result = await driver.execute_query(
"""
MERGE (n:Entity {uuid: $uuid})
SET n = {uuid: $uuid, name: $name, summary: $summary, created_at: $created_at}
SET n = {uuid: $uuid, name: $name, name_embedding: $name_embedding, summary: $summary, created_at: $created_at}
RETURN n.uuid AS uuid""",
uuid=self.uuid,
name=self.name,
summary=self.summary,
name_embedding=self.name_embedding,
created_at=self.created_at,
)

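The new name_embedding field pairs with the name_embedding vector index created in build_indices above. The entity similarity lookup itself lives in core/utils/search/search_utils.py, which is not included in this excerpt; the sketch below is a hypothetical version of that query and may differ from the real one.

from neo4j import AsyncDriver

async def similar_entities(driver: AsyncDriver, search_vector: list[float], limit: int = 10):
    # Hypothetical lookup against the name_embedding vector index; assumes the
    # query vector has already been truncated to EMBEDDING_DIM (1024) floats.
    records, _, _ = await driver.execute_query(
        """
        CALL db.index.vector.queryNodes('name_embedding', $limit, $search_vector)
        YIELD node, score
        RETURN node.uuid AS uuid, node.name AS name, node.summary AS summary, score
        """,
        search_vector=search_vector,
        limit=limit,
    )
    return records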