Update Maintenance LLM Queries and Partial Schema Retrieval #6

Merged: 9 commits, Aug 18, 2024

26 changes: 11 additions & 15 deletions core/edges.py
@@ -5,15 +5,16 @@
from uuid import uuid4
import logging

from core.llm_client.config import EMBEDDING_DIM
from core.nodes import Node

logger = logging.getLogger(__name__)


class Edge(BaseModel, ABC):
uuid: str = Field(default_factory=lambda: str(uuid4()))
source_node: Node
target_node: Node
uuid: str = Field(default_factory=lambda: uuid4().hex)
source_node_uuid: str
target_node_uuid: str
created_at: datetime

@abstractmethod
@@ -30,20 +31,15 @@ def __eq__(self, other):

class EpisodicEdge(Edge):
async def save(self, driver: AsyncDriver):
if self.uuid is None:
uuid = uuid4()
logger.info(f"Created uuid: {uuid} for episodic edge")
self.uuid = str(uuid)

result = await driver.execute_query(
"""
MATCH (episode:Episodic {uuid: $episode_uuid})
MATCH (node:Entity {uuid: $entity_uuid})
MERGE (episode)-[r:MENTIONS {uuid: $uuid}]->(node)
SET r = {uuid: $uuid, created_at: $created_at}
RETURN r.uuid AS uuid""",
episode_uuid=self.source_node.uuid,
entity_uuid=self.target_node.uuid,
episode_uuid=self.source_node_uuid,
entity_uuid=self.target_node_uuid,
uuid=self.uuid,
created_at=self.created_at,
)
@@ -79,10 +75,10 @@ class EntityEdge(Edge):
default=None, description="datetime of when the fact stopped being true"
)

def generate_embedding(self, embedder, model="text-embedding-3-large"):
async def generate_embedding(self, embedder, model="text-embedding-3-small"):
text = self.fact.replace("\n", " ")
embedding = embedder.create(input=[text], model=model).data[0].embedding
self.fact_embedding = embedding
embedding = (await embedder.create(input=[text], model=model)).data[0].embedding
self.fact_embedding = embedding[:EMBEDDING_DIM]

return embedding

@@ -96,8 +92,8 @@ async def save(self, driver: AsyncDriver):
episodes: $episodes, created_at: $created_at, expired_at: $expired_at,
valid_at: $valid_at, invalid_at: $invalid_at}
RETURN r.uuid AS uuid""",
source_uuid=self.source_node.uuid,
target_uuid=self.target_node.uuid,
source_uuid=self.source_node_uuid,
target_uuid=self.target_node_uuid,
uuid=self.uuid,
name=self.name,
fact=self.fact,
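The generate_embedding change above makes the call async and truncates the result to EMBEDDING_DIM. A minimal standalone sketch of that pattern, assuming an openai.AsyncOpenAI client (the concrete embedder in the codebase is whatever llm_client.client.embeddings provides) and OPENAI_API_KEY set in the environment:

import asyncio

from openai import AsyncOpenAI

EMBEDDING_DIM = 1024  # mirrors core/llm_client/config.py in this PR


async def embed_fact(fact: str) -> list[float]:
    # Same pattern as EntityEdge.generate_embedding: normalize whitespace,
    # await the embeddings call, keep only the first EMBEDDING_DIM components.
    text = fact.replace("\n", " ")
    embedder = AsyncOpenAI().embeddings
    response = await embedder.create(input=[text], model="text-embedding-3-small")
    # text-embedding-3-small returns 1536 components; truncating keeps the
    # vector compatible with the 1024-dimension Neo4j vector indices below.
    return response.data[0].embedding[:EMBEDDING_DIM]


if __name__ == "__main__":
    print(len(asyncio.run(embed_fact("Alice works at Acme Corp"))))  # 1024
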
147 changes: 105 additions & 42 deletions core/graphiti.py
@@ -1,12 +1,14 @@
import asyncio
from datetime import datetime
import logging
from typing import Callable, LiteralString, Tuple
from typing import Callable, LiteralString
from neo4j import AsyncGraphDatabase
from dotenv import load_dotenv
import os

from core.llm_client.config import EMBEDDING_DIM
from core.nodes import EntityNode, EpisodicNode, Node
from core.edges import EntityEdge, Edge
from core.edges import EntityEdge, Edge, EpisodicEdge
from core.utils import (
build_episodic_edges,
retrieve_relevant_schema,
@@ -16,6 +18,15 @@
retrieve_episodes,
)
from core.llm_client import LLMClient, OpenAIClient, LLMConfig
from core.utils.maintenance.edge_operations import extract_edges, dedupe_extracted_edges
from core.utils.maintenance.node_operations import dedupe_extracted_nodes, extract_nodes
from core.utils.search.search_utils import (
edge_similarity_search,
entity_fulltext_search,
bfs,
get_relevant_nodes,
get_relevant_edges,
)

logger = logging.getLogger(__name__)

@@ -34,7 +45,7 @@ def __init__(
self.llm_client = OpenAIClient(
LLMConfig(
api_key=os.getenv("OPENAI_API_KEY"),
model="gpt-4o",
model="gpt-4o-mini",
base_url="https://api.openai.com/v1",
)
)
@@ -75,47 +86,78 @@ async def add_episode(
):
"""Process an episode and update the graph"""
try:
nodes: list[Node] = []
edges: list[Edge] = []
nodes: list[EntityNode] = []
entity_edges: list[EntityEdge] = []
episodic_edges: list[EpisodicEdge] = []
embedder = self.llm_client.client.embeddings
now = datetime.now()

previous_episodes = await self.retrieve_episodes(last_n=3)
episode = EpisodicNode(
name=name,
labels=[],
source="messages",
content=episode_body,
source_description=source_description,
created_at=datetime.now(),
created_at=now,
valid_at=reference_time,
)
# await episode.save(self.driver)
relevant_schema = await self.retrieve_relevant_schema(episode.content)
new_nodes = await extract_new_nodes(
self.llm_client, episode, relevant_schema, previous_episodes
# relevant_schema = await self.retrieve_relevant_schema(episode.content)

extracted_nodes = await extract_nodes(
self.llm_client, episode, previous_episodes
)

# Calculate Embeddings

await asyncio.gather(
*[node.generate_name_embedding(embedder) for node in extracted_nodes]
)

existing_nodes = await get_relevant_nodes(extracted_nodes, self.driver)

new_nodes = await dedupe_extracted_nodes(
self.llm_client, extracted_nodes, existing_nodes
)
nodes.extend(new_nodes)
new_edges, affected_nodes = await extract_new_edges(
self.llm_client, episode, new_nodes, relevant_schema, previous_episodes

extracted_edges = await extract_edges(
self.llm_client, episode, new_nodes, previous_episodes
)
edges.extend(new_edges)
episodic_edges = build_episodic_edges(
# There may be an overlap between new_nodes and affected_nodes, so we're deduplicating them
list(set(nodes + affected_nodes)),
episode,
datetime.now(),

await asyncio.gather(
*[edge.generate_embedding(embedder) for edge in extracted_edges]
)

existing_edges = await get_relevant_edges(extracted_edges, self.driver)

new_edges = await dedupe_extracted_edges(
self.llm_client, extracted_edges, existing_edges
)

entity_edges.extend(new_edges)
episodic_edges.extend(
build_episodic_edges(
# There may be an overlap between new_nodes and affected_nodes, so we're deduplicating them
nodes,
episode,
now,
)
)
# Important to append the episode to the nodes at the end so that self referencing episodic edges are not built
nodes.append(episode)
logger.info(f"Built episodic edges: {episodic_edges}")
edges.extend(episodic_edges)

# invalidated_edges = await self.invalidate_edges(
# episode, new_nodes, new_edges, relevant_schema, previous_episodes
# )

# edges.extend(invalidated_edges)

# Future optimization would be using batch operations to save nodes and edges
await episode.save(self.driver)
await asyncio.gather(*[node.save(self.driver) for node in nodes])
await asyncio.gather(*[edge.save(self.driver) for edge in edges])
await asyncio.gather(*[edge.save(self.driver) for edge in episodic_edges])
await asyncio.gather(*[edge.save(self.driver) for edge in entity_edges])
# for node in nodes:
# if isinstance(node, EntityNode):
# await node.update_summary(self.driver)
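The add_episode flow above fans out embedding generation with asyncio.gather instead of awaiting each node or edge in turn. A small sketch of that concurrency pattern; fake_embed is a stand-in for node.generate_name_embedding(embedder) or edge.generate_embedding(embedder) and is not part of the repository:

import asyncio


async def fake_embed(text: str) -> list[float]:
    await asyncio.sleep(0.1)  # stands in for one embedding API round trip
    return [0.0] * 1024


async def embed_all(texts: list[str]) -> list[list[float]]:
    # One coroutine per item, run concurrently: wall time is roughly one
    # round trip rather than one per item.
    return await asyncio.gather(*[fake_embed(t) for t in texts])


if __name__ == "__main__":
    vectors = asyncio.run(embed_all(["Alice", "Acme Corp", "Bob"]))
    print(len(vectors), len(vectors[0]))  # 3 1024
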
@@ -129,6 +171,10 @@ async def add_episode(

async def build_indices(self):
index_queries: list[LiteralString] = [
"CREATE INDEX entity_uuid IF NOT EXISTS FOR (n:Entity) ON (n.uuid)",
"CREATE INDEX episode_uuid IF NOT EXISTS FOR (n:Episodic) ON (n.uuid)",
"CREATE INDEX relation_uuid IF NOT EXISTS FOR ()-[r:RELATES_TO]-() ON (r.uuid)",
"CREATE INDEX mention_uuid IF NOT EXISTS FOR ()-[r:MENTIONS]-() ON (r.uuid)",
"CREATE INDEX name_entity_index IF NOT EXISTS FOR (n:Entity) ON (n.name)",
"CREATE INDEX created_at_entity_index IF NOT EXISTS FOR (n:Entity) ON (n.created_at)",
"CREATE INDEX created_at_episodic_index IF NOT EXISTS FOR (n:Episodic) ON (n.created_at)",
@@ -143,13 +189,19 @@ async def build_indices(self):
for query in index_queries:
await self.driver.execute_query(query)

# Add the entity indices
# Add the semantic indices
await self.driver.execute_query(
"""
CREATE FULLTEXT INDEX name_and_summary IF NOT EXISTS FOR (n:Entity) ON EACH [n.name, n.summary]
"""
)

await self.driver.execute_query(
"""
CREATE FULLTEXT INDEX name_and_fact IF NOT EXISTS FOR ()-[r:RELATES_TO]-() ON EACH [r.name, r.fact]
"""
)

await self.driver.execute_query(
"""
CREATE VECTOR INDEX fact_embedding IF NOT EXISTS
@@ -161,29 +213,40 @@ async def build_indices(self):
"""
)

async def search(
self, query: str, config
) -> (list)[tuple[EntityNode, list[EntityEdge]]]:
(vec_nodes, vec_edges) = similarity_search(query, embedder)
(text_nodes, text_edges) = fulltext_search(query)
await self.driver.execute_query(
"""
CREATE VECTOR INDEX name_embedding IF NOT EXISTS
FOR (n:Entity) ON (n.name_embedding)
OPTIONS {indexConfig: {
`vector.dimensions`: 1024,
`vector.similarity_function`: 'cosine'
}}
"""
)

async def search(self, query: str) -> list[tuple[EntityNode, list[EntityEdge]]]:
text = query.replace("\n", " ")
search_vector = (
(
await self.llm_client.client.embeddings.create(
input=[text], model="text-embedding-3-small"
)
)
.data[0]
.embedding[:EMBEDDING_DIM]
)

nodes = vec_nodes.extend(text_nodes)
edges = vec_edges.extend(text_edges)
edges = await edge_similarity_search(search_vector, self.driver)
nodes = await entity_fulltext_search(query, self.driver)

results = bfs(nodes, edges, k=1)
node_ids = [node.uuid for node in nodes]

episode_ids = ["Mode of episode ids"]
for edge in edges:
node_ids.append(edge.source_node_uuid)
node_ids.append(edge.target_node_uuid)

episodes = get_episodes(episode_ids[:episode_count])
node_ids = list(dict.fromkeys(node_ids))

return [(node, edges)], episodes
context = await bfs(node_ids, self.driver)

# Invalidate edges that are no longer valid
async def invalidate_edges(
self,
episode: EpisodicNode,
new_nodes: list[EntityNode],
new_edges: list[EntityEdge],
relevant_schema: dict[str, any],
previous_episodes: list[EpisodicNode],
): ...
return context
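The rewritten search combines a vector similarity search over RELATES_TO edges with a fulltext search over Entity nodes, merges the resulting node UUIDs, and expands them with a BFS. A schematic sketch of that orchestration; the three stub helpers and their signatures are invented for illustration, while the real edge_similarity_search, entity_fulltext_search, and bfs live in core.utils.search.search_utils:

import asyncio


async def edge_similarity_search_stub(vector: list[float]) -> list[dict]:
    # Stand-in for the vector search over RELATES_TO edges.
    return [{"source_node_uuid": "n1", "target_node_uuid": "n2"}]


async def entity_fulltext_search_stub(query: str) -> list[dict]:
    # Stand-in for the keyword search over Entity nodes.
    return [{"uuid": "n3"}]


async def bfs_stub(node_ids: list[str]) -> dict:
    # Stand-in for the graph expansion around the seed nodes.
    return {"context": node_ids}


async def hybrid_search(query: str, query_vector: list[float]) -> dict:
    edges = await edge_similarity_search_stub(query_vector)  # semantic hits on facts
    nodes = await entity_fulltext_search_stub(query)         # keyword hits on entities

    node_ids = [node["uuid"] for node in nodes]
    for edge in edges:
        node_ids.append(edge["source_node_uuid"])
        node_ids.append(edge["target_node_uuid"])

    node_ids = list(dict.fromkeys(node_ids))  # deduplicate, preserving order

    return await bfs_stub(node_ids)


if __name__ == "__main__":
    print(asyncio.run(hybrid_search("Who works at Acme?", [0.0] * 1024)))
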
7 changes: 5 additions & 2 deletions core/llm_client/config.py
@@ -1,3 +1,6 @@
EMBEDDING_DIM = 1024


class LLMConfig:
"""
Configuration class for the Language Learning Model (LLM).
@@ -10,7 +13,7 @@ class LLMConfig:
def __init__(
self,
api_key: str,
model: str = "gpt-4o",
model: str = "gpt-4o-mini",
base_url: str = "https://api.openai.com",
):
"""
@@ -21,7 +24,7 @@ def __init__(
This is required for making authorized requests.

model (str, optional): The specific LLM model to use for generating responses.
Defaults to "gpt-4o", which appears to be a custom model name.
Defaults to "gpt-4o-mini", which appears to be a custom model name.
Common values might include "gpt-3.5-turbo" or "gpt-4".

base_url (str, optional): The base URL of the LLM API service.
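For reference, this default is what the Graphiti constructor now builds, mirroring the client construction earlier in this diff; the only assumption is that OPENAI_API_KEY is set in the environment:

import os

from core.llm_client import LLMConfig, OpenAIClient

# Mirrors the default client construction in core/graphiti.py above.
llm_client = OpenAIClient(
    LLMConfig(
        api_key=os.getenv("OPENAI_API_KEY"),
        model="gpt-4o-mini",  # new default model in this PR
        base_url="https://api.openai.com/v1",
    )
)
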
17 changes: 15 additions & 2 deletions core/nodes.py
@@ -8,11 +8,13 @@
from neo4j import AsyncDriver
import logging

from core.llm_client.config import EMBEDDING_DIM

logger = logging.getLogger(__name__)


class Node(BaseModel, ABC):
uuid: str = Field(default_factory=lambda: str(uuid4()))
uuid: str = Field(default_factory=lambda: uuid4().hex)
name: str
labels: list[str] = Field(default_factory=list)
created_at: datetime
@@ -66,21 +68,32 @@ async def save(self, driver: AsyncDriver):


class EntityNode(Node):
name_embedding: list[float] | None = Field(
default=None, description="embedding of the name"
)
summary: str = Field(description="regional summary of surrounding edges")

async def update_summary(self, driver: AsyncDriver): ...

async def refresh_summary(self, driver: AsyncDriver, llm_client: OpenAI): ...

async def generate_name_embedding(self, embedder, model="text-embedding-3-small"):
text = self.name.replace("\n", " ")
embedding = (await embedder.create(input=[text], model=model)).data[0].embedding
self.name_embedding = embedding[:EMBEDDING_DIM]

return embedding

async def save(self, driver: AsyncDriver):
result = await driver.execute_query(
"""
MERGE (n:Entity {uuid: $uuid})
SET n = {uuid: $uuid, name: $name, summary: $summary, created_at: $created_at}
SET n = {uuid: $uuid, name: $name, name_embedding: $name_embedding, summary: $summary, created_at: $created_at}
RETURN n.uuid AS uuid""",
uuid=self.uuid,
name=self.name,
summary=self.summary,
name_embedding=self.name_embedding,
created_at=self.created_at,
)

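A minimal end-to-end sketch for the new name embedding on entities: build an EntityNode, generate its name_embedding, and persist it with the updated save query. The Neo4j URI, credentials, and the example entity are placeholders; the field names and methods come from the diff above, though EntityNode's exact required fields may differ:

import asyncio
from datetime import datetime

from neo4j import AsyncGraphDatabase
from openai import AsyncOpenAI

from core.nodes import EntityNode


async def main() -> None:
    # Placeholder connection details; substitute a real Neo4j instance.
    driver = AsyncGraphDatabase.driver("bolt://localhost:7687", auth=("neo4j", "password"))
    embedder = AsyncOpenAI().embeddings

    node = EntityNode(
        name="Acme Corp",
        labels=["Entity"],
        created_at=datetime.now(),
        summary="A company mentioned in recent episodes",
    )
    await node.generate_name_embedding(embedder)  # fills node.name_embedding (1024 floats)
    await node.save(driver)                       # MERGE on uuid, now also sets name_embedding

    await driver.close()


if __name__ == "__main__":
    asyncio.run(main())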