diff --git a/core/edges.py b/core/edges.py index 48d8fb29..fee458f9 100644 --- a/core/edges.py +++ b/core/edges.py @@ -5,15 +5,16 @@ from uuid import uuid4 import logging +from core.llm_client.config import EMBEDDING_DIM from core.nodes import Node logger = logging.getLogger(__name__) class Edge(BaseModel, ABC): - uuid: str = Field(default_factory=lambda: str(uuid4())) - source_node: Node - target_node: Node + uuid: str = Field(default_factory=lambda: uuid4().hex) + source_node_uuid: str + target_node_uuid: str created_at: datetime @abstractmethod @@ -30,11 +31,6 @@ def __eq__(self, other): class EpisodicEdge(Edge): async def save(self, driver: AsyncDriver): - if self.uuid is None: - uuid = uuid4() - logger.info(f"Created uuid: {uuid} for episodic edge") - self.uuid = str(uuid) - result = await driver.execute_query( """ MATCH (episode:Episodic {uuid: $episode_uuid}) @@ -42,8 +38,8 @@ async def save(self, driver: AsyncDriver): MERGE (episode)-[r:MENTIONS {uuid: $uuid}]->(node) SET r = {uuid: $uuid, created_at: $created_at} RETURN r.uuid AS uuid""", - episode_uuid=self.source_node.uuid, - entity_uuid=self.target_node.uuid, + episode_uuid=self.source_node_uuid, + entity_uuid=self.target_node_uuid, uuid=self.uuid, created_at=self.created_at, ) @@ -79,10 +75,10 @@ class EntityEdge(Edge): default=None, description="datetime of when the fact stopped being true" ) - def generate_embedding(self, embedder, model="text-embedding-3-large"): + async def generate_embedding(self, embedder, model="text-embedding-3-small"): text = self.fact.replace("\n", " ") - embedding = embedder.create(input=[text], model=model).data[0].embedding - self.fact_embedding = embedding + embedding = (await embedder.create(input=[text], model=model)).data[0].embedding + self.fact_embedding = embedding[:EMBEDDING_DIM] return embedding @@ -96,8 +92,8 @@ async def save(self, driver: AsyncDriver): episodes: $episodes, created_at: $created_at, expired_at: $expired_at, valid_at: $valid_at, invalid_at: $invalid_at} RETURN r.uuid AS uuid""", - source_uuid=self.source_node.uuid, - target_uuid=self.target_node.uuid, + source_uuid=self.source_node_uuid, + target_uuid=self.target_node_uuid, uuid=self.uuid, name=self.name, fact=self.fact, diff --git a/core/graphiti.py b/core/graphiti.py index bc01866b..83b40d1f 100644 --- a/core/graphiti.py +++ b/core/graphiti.py @@ -1,12 +1,14 @@ import asyncio from datetime import datetime import logging -from typing import Callable, LiteralString, Tuple +from typing import Callable, LiteralString from neo4j import AsyncGraphDatabase from dotenv import load_dotenv import os + +from core.llm_client.config import EMBEDDING_DIM from core.nodes import EntityNode, EpisodicNode, Node -from core.edges import EntityEdge, Edge +from core.edges import EntityEdge, Edge, EpisodicEdge from core.utils import ( build_episodic_edges, retrieve_relevant_schema, @@ -16,6 +18,15 @@ retrieve_episodes, ) from core.llm_client import LLMClient, OpenAIClient, LLMConfig +from core.utils.maintenance.edge_operations import extract_edges, dedupe_extracted_edges +from core.utils.maintenance.node_operations import dedupe_extracted_nodes, extract_nodes +from core.utils.search.search_utils import ( + edge_similarity_search, + entity_fulltext_search, + bfs, + get_relevant_nodes, + get_relevant_edges, +) logger = logging.getLogger(__name__) @@ -34,7 +45,7 @@ def __init__( self.llm_client = OpenAIClient( LLMConfig( api_key=os.getenv("OPENAI_API_KEY"), - model="gpt-4o", + model="gpt-4o-mini", base_url="https://api.openai.com/v1", ) ) @@ 
-75,8 +86,12 @@ async def add_episode(
     ):
         """Process an episode and update the graph"""
         try:
-            nodes: list[Node] = []
-            edges: list[Edge] = []
+            nodes: list[EntityNode] = []
+            entity_edges: list[EntityEdge] = []
+            episodic_edges: list[EpisodicEdge] = []
+            embedder = self.llm_client.client.embeddings
+            now = datetime.now()
+
             previous_episodes = await self.retrieve_episodes(last_n=3)
             episode = EpisodicNode(
                 name=name,
@@ -84,38 +99,64 @@
                 source="messages",
                 content=episode_body,
                 source_description=source_description,
-                created_at=datetime.now(),
+                created_at=now,
                 valid_at=reference_time,
             )
-            # await episode.save(self.driver)
-            relevant_schema = await self.retrieve_relevant_schema(episode.content)
-            new_nodes = await extract_new_nodes(
-                self.llm_client, episode, relevant_schema, previous_episodes
+            # relevant_schema = await self.retrieve_relevant_schema(episode.content)
+
+            extracted_nodes = await extract_nodes(
+                self.llm_client, episode, previous_episodes
+            )
+
+            # Calculate embeddings
+
+            await asyncio.gather(
+                *[node.generate_name_embedding(embedder) for node in extracted_nodes]
+            )
+
+            existing_nodes = await get_relevant_nodes(extracted_nodes, self.driver)
+
+            new_nodes = await dedupe_extracted_nodes(
+                self.llm_client, extracted_nodes, existing_nodes
             )
             nodes.extend(new_nodes)
-            new_edges, affected_nodes = await extract_new_edges(
-                self.llm_client, episode, new_nodes, relevant_schema, previous_episodes
+
+            extracted_edges = await extract_edges(
+                self.llm_client, episode, new_nodes, previous_episodes
             )
-            edges.extend(new_edges)
-            episodic_edges = build_episodic_edges(
-                # There may be an overlap between new_nodes and affected_nodes, so we're deduplicating them
-                list(set(nodes + affected_nodes)),
-                episode,
-                datetime.now(),
+
+            await asyncio.gather(
+                *[edge.generate_embedding(embedder) for edge in extracted_edges]
+            )
+
+            existing_edges = await get_relevant_edges(extracted_edges, self.driver)
+
+            new_edges = await dedupe_extracted_edges(
+                self.llm_client, extracted_edges, existing_edges
+            )
+
+            entity_edges.extend(new_edges)
+            episodic_edges.extend(
+                build_episodic_edges(
+                    # Connect the episode to every deduplicated entity node it mentions
+                    nodes,
+                    episode,
+                    now,
+                )
             )
-            # Important to append the episode to the nodes at the end so that self referencing episodic edges are not built
-            nodes.append(episode)
             logger.info(f"Built episodic edges: {episodic_edges}")
-            edges.extend(episodic_edges)

             # invalidated_edges = await self.invalidate_edges(
             #     episode, new_nodes, new_edges, relevant_schema, previous_episodes
             # )
             # edges.extend(invalidated_edges)

+            # Future optimization would be using batch operations to save nodes and edges
+            await episode.save(self.driver)
             await asyncio.gather(*[node.save(self.driver) for node in nodes])
-            await asyncio.gather(*[edge.save(self.driver) for edge in edges])
+            await asyncio.gather(*[edge.save(self.driver) for edge in episodic_edges])
+            await asyncio.gather(*[edge.save(self.driver) for edge in entity_edges])
             # for node in nodes:
             #     if isinstance(node, EntityNode):
             #         await node.update_summary(self.driver)
@@ -129,6 +170,10 @@ async def build_indices(self):
         index_queries: list[LiteralString] = [
+            "CREATE INDEX entity_uuid IF NOT EXISTS FOR (n:Entity) ON (n.uuid)",
+            "CREATE INDEX episode_uuid IF NOT EXISTS FOR (n:Episodic) ON (n.uuid)",
+            "CREATE INDEX relation_uuid IF NOT EXISTS FOR ()-[r:RELATES_TO]-() ON (r.uuid)",
+            "CREATE INDEX mention_uuid IF NOT EXISTS FOR ()-[r:MENTIONS]-() ON 
(r.uuid)",
             "CREATE INDEX name_entity_index IF NOT EXISTS FOR (n:Entity) ON (n.name)",
             "CREATE INDEX created_at_entity_index IF NOT EXISTS FOR (n:Entity) ON (n.created_at)",
             "CREATE INDEX created_at_episodic_index IF NOT EXISTS FOR (n:Episodic) ON (n.created_at)",
@@ -143,13 +188,19 @@ async def build_indices(self):
         for query in index_queries:
             await self.driver.execute_query(query)

-        # Add the entity indices
+        # Add the semantic indices
         await self.driver.execute_query(
             """
             CREATE FULLTEXT INDEX name_and_summary IF NOT EXISTS FOR (n:Entity) ON EACH [n.name, n.summary]
             """
         )

+        await self.driver.execute_query(
+            """
+            CREATE FULLTEXT INDEX name_and_fact IF NOT EXISTS FOR ()-[r:RELATES_TO]-() ON EACH [r.name, r.fact]
+            """
+        )
+
         await self.driver.execute_query(
             """
             CREATE VECTOR INDEX fact_embedding IF NOT EXISTS
@@ -161,29 +212,41 @@
             """
         )

-    async def search(
-        self, query: str, config
-    ) -> (list)[tuple[EntityNode, list[EntityEdge]]]:
-        (vec_nodes, vec_edges) = similarity_search(query, embedder)
-        (text_nodes, text_edges) = fulltext_search(query)
+        await self.driver.execute_query(
+            """
+            CREATE VECTOR INDEX name_embedding IF NOT EXISTS
+            FOR (n:Entity) ON (n.name_embedding)
+            OPTIONS {indexConfig: {
+                `vector.dimensions`: 1024,
+                `vector.similarity_function`: 'cosine'
+            }}
+            """
+        )
+
+    async def search(self, query: str) -> dict[str, any]:
+        text = query.replace("\n", " ")
+        search_vector = (
+            (
+                await self.llm_client.client.embeddings.create(
+                    input=[text], model="text-embedding-3-small"
+                )
+            )
+            .data[0]
+            .embedding[:EMBEDDING_DIM]
+        )

-        nodes = vec_nodes.extend(text_nodes)
-        edges = vec_edges.extend(text_edges)
+        edges = await edge_similarity_search(search_vector, self.driver)
+        nodes = await entity_fulltext_search(query, self.driver)

-        results = bfs(nodes, edges, k=1)
+        node_ids = [node.uuid for node in nodes]

-        episode_ids = ["Mode of episode ids"]
+        for edge in edges:
+            node_ids.append(edge.source_node_uuid)
+            node_ids.append(edge.target_node_uuid)

-        episodes = get_episodes(episode_ids[:episode_count])
+        # Deduplicate node ids while preserving order
+        node_ids = list(dict.fromkeys(node_ids))

-        return [(node, edges)], episodes
+        context = await bfs(node_ids, self.driver)

-    # Invalidate edges that are no longer valid
-    async def invalidate_edges(
-        self,
-        episode: EpisodicNode,
-        new_nodes: list[EntityNode],
-        new_edges: list[EntityEdge],
-        relevant_schema: dict[str, any],
-        previous_episodes: list[EpisodicNode],
-    ): ...
+        return context
diff --git a/core/llm_client/config.py b/core/llm_client/config.py
index af9dad41..d372a5b3 100644
--- a/core/llm_client/config.py
+++ b/core/llm_client/config.py
@@ -1,3 +1,6 @@
+EMBEDDING_DIM = 1024
+
+
 class LLMConfig:
     """
     Configuration class for the Language Learning Model (LLM).
@@ -10,7 +13,7 @@ class LLMConfig:
     def __init__(
         self,
         api_key: str,
-        model: str = "gpt-4o",
+        model: str = "gpt-4o-mini",
         base_url: str = "https://api.openai.com",
     ):
         """
@@ -21,7 +24,7 @@ def __init__(
             This is required for making authorized requests.

             model (str, optional): The specific LLM model to use for generating responses.
-                Defaults to "gpt-4o", which appears to be a custom model name.
+                Defaults to "gpt-4o-mini".
                 Common values might include "gpt-3.5-turbo" or "gpt-4".

             base_url (str, optional): The base URL of the LLM API service. 
diff --git a/core/nodes.py b/core/nodes.py
index 30129b51..93fd4b97 100644
--- a/core/nodes.py
+++ b/core/nodes.py
@@ -8,11 +8,13 @@
 from neo4j import AsyncDriver
 import logging

+from core.llm_client.config import EMBEDDING_DIM
+
 logger = logging.getLogger(__name__)


 class Node(BaseModel, ABC):
-    uuid: str = Field(default_factory=lambda: str(uuid4()))
+    uuid: str = Field(default_factory=lambda: uuid4().hex)
     name: str
     labels: list[str] = Field(default_factory=list)
     created_at: datetime
@@ -66,21 +68,32 @@ async def save(self, driver: AsyncDriver):


 class EntityNode(Node):
+    name_embedding: list[float] | None = Field(
+        default=None, description="embedding of the name"
+    )
     summary: str = Field(description="regional summary of surrounding edges")

     async def update_summary(self, driver: AsyncDriver): ...

     async def refresh_summary(self, driver: AsyncDriver, llm_client: OpenAI): ...

+    async def generate_name_embedding(self, embedder, model="text-embedding-3-small"):
+        text = self.name.replace("\n", " ")
+        embedding = (await embedder.create(input=[text], model=model)).data[0].embedding
+        self.name_embedding = embedding[:EMBEDDING_DIM]
+
+        return embedding
+
     async def save(self, driver: AsyncDriver):
         result = await driver.execute_query(
             """
         MERGE (n:Entity {uuid: $uuid})
-        SET n = {uuid: $uuid, name: $name, summary: $summary, created_at: $created_at}
+        SET n = {uuid: $uuid, name: $name, name_embedding: $name_embedding, summary: $summary, created_at: $created_at}
         RETURN n.uuid AS uuid""",
             uuid=self.uuid,
             name=self.name,
             summary=self.summary,
+            name_embedding=self.name_embedding,
             created_at=self.created_at,
         )
diff --git a/core/prompts/dedupe_edges.py b/core/prompts/dedupe_edges.py
new file mode 100644
index 00000000..3506a3b7
--- /dev/null
+++ b/core/prompts/dedupe_edges.py
@@ -0,0 +1,56 @@
+import json
+from typing import TypedDict, Protocol
+
+from .models import Message, PromptVersion, PromptFunction
+
+
+class Prompt(Protocol):
+    v1: PromptVersion
+
+
+class Versions(TypedDict):
+    v1: PromptFunction
+
+
+def v1(context: dict[str, any]) -> list[Message]:
+    return [
+        Message(
+            role="system",
+            content="You are a helpful assistant that de-duplicates relationships from edge lists.",
+        ),
+        Message(
+            role="user",
+            content=f"""
+        Given the following context, deduplicate edges from a list of new edges given a list of existing edges:
+
+        Existing Edges:
+        {json.dumps(context['existing_edges'], indent=2)}
+
+        New Edges:
+        {json.dumps(context['extracted_edges'], indent=2)}
+
+        Task:
+        1. Start with the list of edges from New Edges
+        2. If any edge in New Edges is a duplicate of an edge in Existing Edges, replace the new edge with the existing
+        edge in the list
+        3. Respond with the resulting list of edges
+
+        Guidelines:
+        1. Use both the name and fact of edges to determine if they are duplicates;
+        duplicate edges may have different names
+
+        Respond with a JSON object in the following format:
+        {{
+            "new_edges": [
+                {{
+                    "name": "Unique identifier for the edge",
+                    "fact": "one sentence description of the fact"
+                }}
+            ]
+        }}
+    """,
+        ),
+    ]
+
+
+versions: Versions = {"v1": v1}
diff --git a/core/prompts/dedupe_nodes.py b/core/prompts/dedupe_nodes.py
new file mode 100644
index 00000000..798942b5
--- /dev/null
+++ b/core/prompts/dedupe_nodes.py
@@ -0,0 +1,56 @@
+import json
+from typing import TypedDict, Protocol
+
+from .models import Message, PromptVersion, PromptFunction
+
+
+class Prompt(Protocol):
+    v1: PromptVersion
+
+
+class Versions(TypedDict):
+    v1: PromptFunction
+
+
+def v1(context: dict[str, any]) -> list[Message]:
+    return [
+        Message(
+            role="system",
+            content="You are a helpful assistant that de-duplicates nodes from node lists.",
+        ),
+        Message(
+            role="user",
+            content=f"""
+        Given the following context, deduplicate nodes from a list of new nodes given a list of existing nodes:
+
+        Existing Nodes:
+        {json.dumps(context['existing_nodes'], indent=2)}
+
+        New Nodes:
+        {json.dumps(context['extracted_nodes'], indent=2)}
+
+        Task:
+        1. Start with the list of nodes from New Nodes
+        2. If any node in New Nodes is a duplicate of a node in Existing Nodes, replace the new node with the existing
+        node in the list
+        3. Respond with the resulting list of nodes
+
+        Guidelines:
+        1. Use both the name and summary of nodes to determine if they are duplicates;
+        duplicate nodes may have different names
+
+        Respond with a JSON object in the following format:
+        {{
+            "new_nodes": [
+                {{
+                    "name": "Unique identifier for the node",
+                    "summary": "Brief summary of the node's role or significance"
+                }}
+            ]
+        }}
+    """,
+        ),
+    ]
+
+
+versions: Versions = {"v1": v1}
diff --git a/core/prompts/extract_edges.py b/core/prompts/extract_edges.py
index 9e9c3750..cc12206a 100644
--- a/core/prompts/extract_edges.py
+++ b/core/prompts/extract_edges.py
@@ -6,10 +6,12 @@

 class Prompt(Protocol):
     v1: PromptVersion
+    v2: PromptVersion


 class Versions(TypedDict):
     v1: PromptFunction
+    v2: PromptFunction


 def v1(context: dict[str, any]) -> list[Message]:
@@ -68,6 +70,108 @@ def v1(context: dict[str, any]) -> list[Message]:
     ]


-versions: Versions = {
-    "v1": v1,
-}
+def v1(context: dict[str, any]) -> list[Message]:
+    return [
+        Message(
+            role="system",
+            content="You are a helpful assistant that extracts graph edges from provided context.",
+        ),
+        Message(
+            role="user",
+            content=f"""
+        Given the following context, extract new semantic edges (relationships) that need to be added to the knowledge graph:
+
+        Current Graph Structure:
+        {context['relevant_schema']}
+
+        New Nodes:
+        {json.dumps(context['new_nodes'], indent=2)}
+
+        New Episode:
+        Content: {context['episode_content']}
+        Timestamp: {context['episode_timestamp']}
+
+        Previous Episodes:
+        {json.dumps([ep['content'] for ep in context['previous_episodes']], indent=2)}
+
+        Extract new semantic edges based on the content of the current episode, considering the existing graph structure, new nodes, and context from previous episodes.
+
+        Guidelines:
+        1. Create edges only between semantic nodes (not episodic nodes like messages).
+        2. Each edge should represent a clear relationship between two semantic nodes.
+        3. The relation_type should be a concise, all-caps description of the relationship (e.g., LOVES, IS_FRIENDS_WITH, WORKS_FOR).
+        4. 
Provide a more detailed fact describing the relationship.
+        5. If a relationship seems to update an existing one, create a new edge with the updated information.
+        6. Consider temporal aspects of relationships when relevant.
+        7. Do not create edges involving episodic nodes (like Message 1 or Message 2).
+        8. Use existing nodes from the current graph structure when appropriate.
+
+        Respond with a JSON object in the following format:
+        {{
+            "new_edges": [
+                {{
+                    "relation_type": "RELATION_TYPE_IN_CAPS",
+                    "source_node": "Name of the source semantic node",
+                    "target_node": "Name of the target semantic node",
+                    "fact": "Detailed description of the relationship",
+                    "valid_at": "YYYY-MM-DDTHH:MM:SSZ or null if not explicitly mentioned",
+                    "invalid_at": "YYYY-MM-DDTHH:MM:SSZ or null if ongoing or not explicitly mentioned"
+                }}
+            ]
+        }}
+
+        If no new edges need to be added, return an empty list for "new_edges".
+    """,
+        ),
+    ]
+
+
+def v2(context: dict[str, any]) -> list[Message]:
+    return [
+        Message(
+            role="system",
+            content="You are a helpful assistant that extracts graph edges from provided context.",
+        ),
+        Message(
+            role="user",
+            content=f"""
+        Given the following context, extract new edges (relationships) that need to be added to the knowledge graph:
+        Nodes:
+        {json.dumps(context['nodes'], indent=2)}
+
+        New Episode:
+        Content: {context['episode_content']}
+
+        Previous Episodes:
+        {json.dumps([ep['content'] for ep in context['previous_episodes']], indent=2)}
+
+        Extract new entity edges based on the content of the current episode, the given nodes, and context from previous episodes.
+
+        Guidelines:
+        1. Create edges only between the provided nodes.
+        2. Each edge should represent a clear relationship between two nodes.
+        3. The relation_type should be a concise, all-caps description of the relationship (e.g., LOVES, IS_FRIENDS_WITH, WORKS_FOR).
+        4. Provide a more detailed fact describing the relationship.
+        5. Consider temporal aspects of relationships when relevant.
+
+        Respond with a JSON object in the following format:
+        {{
+            "edges": [
+                {{
+                    "relation_type": "RELATION_TYPE_IN_CAPS",
+                    "source_node_uuid": "uuid of the source entity node",
+                    "target_node_uuid": "uuid of the target entity node",
+                    "fact": "Detailed description of the relationship",
+                    "valid_at": "YYYY-MM-DDTHH:MM:SSZ or null if not explicitly mentioned",
+                    "invalid_at": "YYYY-MM-DDTHH:MM:SSZ or null if ongoing or not explicitly mentioned"
+                }}
+            ]
+        }}
+
+        If no new edges need to be added, return an empty list for "edges". 
+ """, + ), + ] + + +versions: Versions = {"v1": v1, "v2": v2} diff --git a/core/prompts/extract_nodes.py b/core/prompts/extract_nodes.py index 1d171943..8e3b1f55 100644 --- a/core/prompts/extract_nodes.py +++ b/core/prompts/extract_nodes.py @@ -6,10 +6,12 @@ class Prompt(Protocol): v1: PromptVersion + v2: PromptVersion class Versions(TypedDict): v1: PromptFunction + v2: PromptFunction def v1(context: dict[str, any]) -> list[Message]: @@ -60,6 +62,45 @@ def v1(context: dict[str, any]) -> list[Message]: ] -versions: Versions = { - "v1": v1, -} +def v2(context: dict[str, any]) -> list[Message]: + return [ + Message( + role="system", + content="You are a helpful assistant that extracts graph nodes from provided context.", + ), + Message( + role="user", + content=f""" + Given the following context, extract new entity nodes that need to be added to the knowledge graph: + + Previous Episodes: + {json.dumps([ep['content'] for ep in context['previous_episodes']], indent=2)} + + New Episode: + Content: {context["episode_content"]} + + Extract new entity nodes based on the content of the current episode, while considering the context from previous episodes. + + Guidelines: + 1. Focus on entities, concepts, or actors that are central to the current episode. + 2. Avoid creating nodes for relationships or actions (these will be handled as edges later). + 3. Provide a brief but informative summary for each node. + + Respond with a JSON object in the following format: + {{ + "new_nodes": [ + {{ + "name": "Unique identifier for the node", + "labels": ["Entity", "OptionalAdditionalLabel"], + "summary": "Brief summary of the node's role or significance" + }} + ] + }} + + If no new nodes need to be added, return an empty list for "new_nodes". + """, + ), + ] + + +versions: Versions = {"v1": v1, "v2": v2} diff --git a/core/prompts/lib.py b/core/prompts/lib.py index 7d47a650..20591a76 100644 --- a/core/prompts/lib.py +++ b/core/prompts/lib.py @@ -8,21 +8,37 @@ versions as extract_nodes_versions, ) +from .dedupe_nodes import ( + Prompt as DedupeNodesPrompt, + Versions as DedupeNodesVersions, + versions as dedupe_nodes_versions, +) + from .extract_edges import ( Prompt as ExtractEdgesPrompt, Versions as ExtractEdgesVersions, versions as extract_edges_versions, ) +from .dedupe_edges import ( + Prompt as DedupeEdgesPrompt, + Versions as DedupeEdgesVersions, + versions as dedupe_edges_versions, +) + class PromptLibrary(Protocol): extract_nodes: ExtractNodesPrompt + dedupe_nodes: DedupeNodesPrompt extract_edges: ExtractEdgesPrompt + dedupe_edges: DedupeEdgesPrompt class PromptLibraryImpl(TypedDict): extract_nodes: ExtractNodesVersions + dedupe_nodes: DedupeNodesVersions extract_edges: ExtractEdgesVersions + dedupe_edges: DedupeEdgesVersions class VersionWrapper: @@ -47,7 +63,9 @@ def __init__(self, library: PromptLibraryImpl): PROMPT_LIBRARY_IMPL: PromptLibraryImpl = { "extract_nodes": extract_nodes_versions, + "dedupe_nodes": dedupe_nodes_versions, "extract_edges": extract_edges_versions, + "dedupe_edges": dedupe_edges_versions, } prompt_library: PromptLibrary = PromptLibraryWrapper(PROMPT_LIBRARY_IMPL) diff --git a/core/utils.py b/core/utils.py deleted file mode 100644 index 72ed57d8..00000000 --- a/core/utils.py +++ /dev/null @@ -1,45 +0,0 @@ -from typing import Tuple - -from core.edges import EpisodicEdge, EntityEdge, Edge -from core.nodes import EntityNode, EpisodicNode, Node - - -async def bfs( - nodes: list[Node], edges: list[Edge], k: int -) -> Tuple[list[EntityNode], list[EntityEdge]]: ... 
-
-
-# Breadth first search over nodes and edges with desired depth
-
-
-async def similarity_search(
-    query: str, embedder
-) -> Tuple[list[EntityNode], list[EntityEdge]]: ...
-
-
-# vector similarity search over embedded facts
-
-
-async def fulltext_search(
-    query: str,
-) -> Tuple[list[EntityNode], list[EntityEdge]]: ...
-
-
-# fulltext search over names and summary
-
-
-def build_episodic_edges(
-    entity_nodes: list[EntityNode], episode: EpisodicNode
-) -> list[EpisodicEdge]:
-    edges: list[EpisodicEdge] = []
-
-    for node in entity_nodes:
-        edges.append(
-            EpisodicEdge(
-                source_node=episode,
-                target_node=node,
-                created_at=episode.created_at,
-            )
-        )
-
-    return edges
diff --git a/core/utils/maintenance/edge_operations.py b/core/utils/maintenance/edge_operations.py
index 0fb27c59..47e004cc 100644
--- a/core/utils/maintenance/edge_operations.py
+++ b/core/utils/maintenance/edge_operations.py
@@ -13,15 +13,17 @@


 def build_episodic_edges(
-    semantic_nodes: List[EntityNode],
+    entity_nodes: List[EntityNode],
     episode: EpisodicNode,
     transaction_from: datetime,
 ) -> List[EpisodicEdge]:
     edges: List[EpisodicEdge] = []

-    for node in semantic_nodes:
+    for node in entity_nodes:
         edge = EpisodicEdge(
-            source_node=episode, target_node=node, created_at=transaction_from
+            source_node_uuid=episode.uuid,
+            target_node_uuid=node.uuid,
+            created_at=transaction_from,
         )
         edges.append(edge)

@@ -132,3 +134,94 @@ async def extract_new_edges(
         affected_nodes.add(edge.source_node)
         affected_nodes.add(edge.target_node)
     return new_edges, list(affected_nodes)
+
+
+async def extract_edges(
+    llm_client: LLMClient,
+    episode: EpisodicNode,
+    nodes: list[EntityNode],
+    previous_episodes: list[EpisodicNode],
+) -> list[EntityEdge]:
+    # Prepare context for LLM
+    context = {
+        "episode_content": episode.content,
+        "episode_timestamp": (
+            episode.valid_at.isoformat() if episode.valid_at else None
+        ),
+        "nodes": [
+            {"uuid": node.uuid, "name": node.name, "summary": node.summary}
+            for node in nodes
+        ],
+        "previous_episodes": [
+            {
+                "content": ep.content,
+                "timestamp": ep.valid_at.isoformat() if ep.valid_at else None,
+            }
+            for ep in previous_episodes
+        ],
+    }
+
+    llm_response = await llm_client.generate_response(
+        prompt_library.extract_edges.v2(context)
+    )
+    edges_data = llm_response.get("edges", [])
+    logger.info(f"Extracted new edges: {edges_data}")
+
+    # Convert the extracted data into EntityEdge objects
+    edges = []
+    for edge_data in edges_data:
+        edge = EntityEdge(
+            source_node_uuid=edge_data["source_node_uuid"],
+            target_node_uuid=edge_data["target_node_uuid"],
+            name=edge_data["relation_type"],
+            fact=edge_data["fact"],
+            episodes=[episode.uuid],
+            created_at=datetime.now(),
+            valid_at=edge_data["valid_at"],
+            invalid_at=edge_data["invalid_at"],
+        )
+        edges.append(edge)
+        logger.info(
+            f"Created new edge: {edge.name} from (UUID: {edge.source_node_uuid}) to (UUID: {edge.target_node_uuid})"
+        )
+
+    return edges
+
+
+async def dedupe_extracted_edges(
+    llm_client: LLMClient,
+    extracted_edges: list[EntityEdge],
+    existing_edges: list[EntityEdge],
+) -> list[EntityEdge]:
+    # Build an edge map keyed by name; extracted edges must not shadow existing ones
+    edge_map = {}
+    for edge in existing_edges:
+        edge_map[edge.name] = edge
+    for edge in extracted_edges:
+        if edge.name in edge_map:
+            continue
+        edge_map[edge.name] = edge
+
+    # Prepare context for LLM
+    context = {
+        "extracted_edges": [
+            {"name": edge.name, "fact": edge.fact} for edge in extracted_edges
+        ],
+        "existing_edges": [
+            {"name": edge.name, "fact": edge.fact} for edge in existing_edges
+        ],
+    }
+
+    llm_response = await llm_client.generate_response(
+        prompt_library.dedupe_edges.v1(context)
+    )
+    new_edges_data = llm_response.get("new_edges", [])
+    logger.info(f"Deduplicated edges: {new_edges_data}")
+
+    # Get full edge data
+    edges = []
+    for edge_data in new_edges_data:
+        edge = edge_map[edge_data["name"]]
+        edges.append(edge)
+
+    return edges
diff --git a/core/utils/maintenance/node_operations.py b/core/utils/maintenance/node_operations.py
index 4c184f90..571c3316 100644
--- a/core/utils/maintenance/node_operations.py
+++ b/core/utils/maintenance/node_operations.py
@@ -61,3 +61,87 @@ async def extract_new_nodes(
         logger.info(f"Node {node_data['name']} already exists, skipping creation.")

     return new_nodes
+
+
+async def extract_nodes(
+    llm_client: LLMClient,
+    episode: EpisodicNode,
+    previous_episodes: list[EpisodicNode],
+) -> list[EntityNode]:
+    # Prepare context for LLM
+    context = {
+        "episode_content": episode.content,
+        "episode_timestamp": (
+            episode.valid_at.isoformat() if episode.valid_at else None
+        ),
+        "previous_episodes": [
+            {
+                "content": ep.content,
+                "timestamp": ep.valid_at.isoformat() if ep.valid_at else None,
+            }
+            for ep in previous_episodes
+        ],
+    }
+
+    llm_response = await llm_client.generate_response(
+        prompt_library.extract_nodes.v2(context)
+    )
+    new_nodes_data = llm_response.get("new_nodes", [])
+    logger.info(f"Extracted new nodes: {new_nodes_data}")
+
+    # Convert the extracted data into EntityNode objects
+    new_nodes = []
+    for node_data in new_nodes_data:
+        new_node = EntityNode(
+            name=node_data["name"],
+            labels=node_data["labels"],
+            summary=node_data["summary"],
+            created_at=datetime.now(),
+        )
+        new_nodes.append(new_node)
+        logger.info(f"Created new node: {new_node.name} (UUID: {new_node.uuid})")
+
+    return new_nodes
+
+
+async def dedupe_extracted_nodes(
+    llm_client: LLMClient,
+    extracted_nodes: list[EntityNode],
+    existing_nodes: list[EntityNode],
+) -> list[EntityNode]:
+    # Build a node map keyed by name; extracted nodes must not shadow existing ones
+    node_map = {}
+    for node in existing_nodes:
+        node_map[node.name] = node
+    for node in extracted_nodes:
+        if node.name in node_map:
+            continue
+        node_map[node.name] = node
+
+    # Prepare context for LLM
+    existing_nodes_context = [
+        {"name": node.name, "summary": node.summary} for node in existing_nodes
+    ]
+
+    extracted_nodes_context = [
+        {"name": node.name, "summary": node.summary} for node in extracted_nodes
+    ]
+
+    context = {
+        "existing_nodes": existing_nodes_context,
+        "extracted_nodes": extracted_nodes_context,
+    }
+
+    llm_response = await llm_client.generate_response(
+        prompt_library.dedupe_nodes.v1(context)
+    )
+
+    new_nodes_data = llm_response.get("new_nodes", [])
+    logger.info(f"Deduplicated nodes: {new_nodes_data}")
+
+    # Get full node data
+    nodes = []
+    for node_data in new_nodes_data:
+        node = node_map[node_data["name"]]
+        nodes.append(node)
+
+    return nodes
diff --git a/core/utils/search/__init__.py b/core/utils/search/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/core/utils/search/search_utils.py b/core/utils/search/search_utils.py
new file mode 100644
index 00000000..25ca6e6c
--- /dev/null
+++ b/core/utils/search/search_utils.py
@@ -0,0 +1,274 @@
+import asyncio
+import logging
+from datetime import datetime
+
+from neo4j import AsyncDriver
+
+from core.edges import EntityEdge
+from core.nodes import EntityNode
+
+logger = logging.getLogger(__name__)
+
+
+async def bfs(node_ids: list[str], driver: AsyncDriver):
+    records, _, _ = await driver.execute_query(
+        """
+        MATCH (n WHERE n.uuid in 
$node_ids)-[r]->(m)
+        RETURN
+            n.uuid AS source_node_uuid,
+            n.name AS source_name,
+            n.summary AS source_summary,
+            m.uuid AS target_node_uuid,
+            m.name AS target_name,
+            m.summary AS target_summary,
+            r.uuid AS uuid,
+            r.created_at AS created_at,
+            r.name AS name,
+            r.fact AS fact,
+            r.fact_embedding AS fact_embedding,
+            r.episodes AS episodes,
+            r.expired_at AS expired_at,
+            r.valid_at AS valid_at,
+            r.invalid_at AS invalid_at
+
+        """,
+        node_ids=node_ids,
+    )
+
+    context = {}
+
+    for record in records:
+        n_uuid = record["source_node_uuid"]
+        if n_uuid in context:
+            context[n_uuid]["facts"].append(record["fact"])
+        else:
+            context[n_uuid] = {
+                "name": record["source_name"],
+                "summary": record["source_summary"],
+                "facts": [record["fact"]],
+            }
+
+        m_uuid = record["target_node_uuid"]
+        if m_uuid not in context:
+            context[m_uuid] = {
+                "name": record["target_name"],
+                "summary": record["target_summary"],
+                "facts": [],
+            }
+    logger.info(f"bfs search returned context: {context}")
+    return context
+
+
+async def edge_similarity_search(
+    search_vector: list[float], driver: AsyncDriver
+) -> list[EntityEdge]:
+    # vector similarity search over embedded facts
+    records, _, _ = await driver.execute_query(
+        """
+        CALL db.index.vector.queryRelationships("fact_embedding", 5, $search_vector)
+        YIELD relationship AS r, score
+        MATCH (n)-[r:RELATES_TO]->(m)
+        RETURN
+            r.uuid AS uuid,
+            n.uuid AS source_node_uuid,
+            m.uuid AS target_node_uuid,
+            r.created_at AS created_at,
+            r.name AS name,
+            r.fact AS fact,
+            r.fact_embedding AS fact_embedding,
+            r.episodes AS episodes,
+            r.expired_at AS expired_at,
+            r.valid_at AS valid_at,
+            r.invalid_at AS invalid_at
+        ORDER BY score DESC LIMIT 10
+        """,
+        search_vector=search_vector,
+    )
+
+    edges: list[EntityEdge] = []
+
+    now = datetime.now()
+
+    for record in records:
+        edge = EntityEdge(
+            uuid=record["uuid"],
+            source_node_uuid=record["source_node_uuid"],
+            target_node_uuid=record["target_node_uuid"],
+            fact=record["fact"],
+            name=record["name"],
+            episodes=record["episodes"],
+            fact_embedding=record["fact_embedding"],
+            created_at=now,
+            expired_at=now,
+            valid_at=now,
+            invalid_at=now,
+        )
+
+        edges.append(edge)
+
+    logger.info(f"similarity search results. RESULT: {[edge.uuid for edge in edges]}")
+
+    return edges
+
+
+async def entity_similarity_search(
+    search_vector: list[float], driver: AsyncDriver
+) -> list[EntityNode]:
+    # vector similarity search over entity names
+    records, _, _ = await driver.execute_query(
+        """
+        CALL db.index.vector.queryNodes("name_embedding", 5, $search_vector)
+        YIELD node AS n, score
+        RETURN
+            n.uuid AS uuid,
+            n.name AS name,
+            n.created_at AS created_at,
+            n.summary AS summary
+        ORDER BY score DESC
+        """,
+        search_vector=search_vector,
+    )
+    nodes: list[EntityNode] = []
+
+    for record in records:
+        nodes.append(
+            EntityNode(
+                uuid=record["uuid"],
+                name=record["name"],
+                labels=[],
+                created_at=datetime.now(),
+                summary=record["summary"],
+            )
+        )
+
+    logger.info(f"name semantic search results. 
RESULT: {nodes}")
+
+    return nodes
+
+
+async def entity_fulltext_search(query: str, driver: AsyncDriver) -> list[EntityNode]:
+    # BM25 search to get top nodes
+    fuzzy_query = query + "~"
+    records, _, _ = await driver.execute_query(
+        """
+        CALL db.index.fulltext.queryNodes("name_and_summary", $query) YIELD node, score
+        RETURN
+            node.uuid AS uuid,
+            node.name AS name,
+            node.created_at AS created_at,
+            node.summary AS summary
+        ORDER BY score DESC
+        LIMIT 10
+        """,
+        query=fuzzy_query,
+    )
+    nodes: list[EntityNode] = []
+
+    for record in records:
+        nodes.append(
+            EntityNode(
+                uuid=record["uuid"],
+                name=record["name"],
+                labels=[],
+                created_at=datetime.now(),
+                summary=record["summary"],
+            )
+        )
+
+    logger.info(f"fulltext search results. QUERY:{query}. RESULT: {nodes}")
+
+    return nodes
+
+
+async def edge_fulltext_search(query: str, driver: AsyncDriver) -> list[EntityEdge]:
+    # fulltext search over facts
+    fuzzy_query = query + "~"
+
+    records, _, _ = await driver.execute_query(
+        """
+        CALL db.index.fulltext.queryRelationships("name_and_fact", $query)
+        YIELD relationship AS r, score
+        MATCH (n:Entity)-[r]->(m:Entity)
+        RETURN
+            r.uuid AS uuid,
+            n.uuid AS source_node_uuid,
+            m.uuid AS target_node_uuid,
+            r.created_at AS created_at,
+            r.name AS name,
+            r.fact AS fact,
+            r.fact_embedding AS fact_embedding,
+            r.episodes AS episodes,
+            r.expired_at AS expired_at,
+            r.valid_at AS valid_at,
+            r.invalid_at AS invalid_at
+        ORDER BY score DESC LIMIT 10
+        """,
+        query=fuzzy_query,
+    )
+
+    edges: list[EntityEdge] = []
+
+    now = datetime.now()
+
+    for record in records:
+        edge = EntityEdge(
+            uuid=record["uuid"],
+            source_node_uuid=record["source_node_uuid"],
+            target_node_uuid=record["target_node_uuid"],
+            fact=record["fact"],
+            name=record["name"],
+            episodes=record["episodes"],
+            fact_embedding=record["fact_embedding"],
+            created_at=now,
+            expired_at=now,
+            valid_at=now,
+            invalid_at=now,
+        )
+
+        edges.append(edge)
+
+    logger.info(
+        f"fulltext search results. QUERY:{query}. RESULT: {[edge.uuid for edge in edges]}"
+    )
+
+    return edges
+
+
+async def get_relevant_nodes(
+    nodes: list[EntityNode],
+    driver: AsyncDriver,
+) -> list[EntityNode]:
+    relevant_nodes: dict[str, EntityNode] = {}
+
+    results = await asyncio.gather(
+        *[entity_fulltext_search(node.name, driver) for node in nodes],
+        *[entity_similarity_search(node.name_embedding, driver) for node in nodes],
+    )
+
+    for result in results:
+        for node in result:
+            relevant_nodes[node.uuid] = node
+
+    logger.info(f"Found relevant nodes: {relevant_nodes.keys()}")
+
+    return list(relevant_nodes.values())
+
+
+async def get_relevant_edges(
+    edges: list[EntityEdge],
+    driver: AsyncDriver,
+) -> list[EntityEdge]:
+    relevant_edges: dict[str, EntityEdge] = {}
+
+    results = await asyncio.gather(
+        *[edge_similarity_search(edge.fact_embedding, driver) for edge in edges],
+        *[edge_fulltext_search(edge.fact, driver) for edge in edges],
+    )
+
+    for result in results:
+        for edge in result:
+            relevant_edges[edge.uuid] = edge
+
+    logger.info(f"Found relevant edges: {relevant_edges.keys()}")
+
+    return list(relevant_edges.values())
diff --git a/core/utils/utils.py b/core/utils/utils.py
new file mode 100644
index 00000000..43746600
--- /dev/null
+++ b/core/utils/utils.py
@@ -0,0 +1,25 @@
+import logging
+
+from neo4j import AsyncDriver
+
+from core.edges import EpisodicEdge, EntityEdge, Edge
+from core.nodes import EntityNode, EpisodicNode, Node
+
+logger = logging.getLogger(__name__)
+
+
+def build_episodic_edges(
+    entity_nodes: list[EntityNode], episode: EpisodicNode
+) -> list[EpisodicEdge]:
+    edges: list[EpisodicEdge] = []
+
+    for node in entity_nodes:
+        edges.append(
+            EpisodicEdge(
+                source_node_uuid=episode.uuid,
+                target_node_uuid=node.uuid,
+                created_at=episode.created_at,
+            )
+        )
+
+    return edges
diff --git a/tests/graphiti_int_tests.py b/tests/graphiti_int_tests.py
index 1915c81f..f8319ea4 100644
--- a/tests/graphiti_int_tests.py
+++ b/tests/graphiti_int_tests.py
@@ -1,3 +1,5 @@
+import logging
+import sys
 import os

 import pytest

@@ -9,9 +11,11 @@

 from core.edges import EpisodicEdge, EntityEdge
 from core.graphiti import Graphiti
+from core.llm_client.config import EMBEDDING_DIM
 from core.nodes import EpisodicNode, EntityNode
 from datetime import datetime

+
 pytest_plugins = ("pytest_asyncio",)

 load_dotenv()
@@ -21,10 +25,59 @@
 NEO4j_PASSWORD = os.getenv("NEO4J_PASSWORD")


+def setup_logging():
+    # Create a logger
+    logger = logging.getLogger()
+    logger.setLevel(logging.INFO)  # Set the logging level to INFO
+
+    # Create console handler and set level to INFO
+    console_handler = logging.StreamHandler(sys.stdout)
+    console_handler.setLevel(logging.INFO)
+
+    # Create formatter
+    formatter = logging.Formatter(
+        "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
+    )
+
+    # Add formatter to console handler
+    console_handler.setFormatter(formatter)
+
+    # Add console handler to logger
+    logger.addHandler(console_handler)
+
+    return logger
+
+
+def format_context(context):
+    formatted_string = ""
+    for uuid, data in context.items():
+        formatted_string += f"UUID: {uuid}\n"
+        formatted_string += f"  Name: {data['name']}\n"
+        formatted_string += f"  Summary: {data['summary']}\n"
+        formatted_string += "  Facts:\n"
+        for fact in data["facts"]:
+            formatted_string += f"    - {fact}\n"
+        formatted_string += "\n"
+    return formatted_string.strip()
+
+
 @pytest.mark.asyncio
 async def test_graphiti_init():
+    logger = setup_logging()
     graphiti = Graphiti(NEO4J_URI, NEO4j_USER, NEO4j_PASSWORD, None)
     await graphiti.build_indices()
+
+    context = await 
graphiti.search("Freakonomics guest")
+
+    logger.info("QUERY: Freakonomics guest" + " RESULT:\n" + format_context(context))
+
+    context = await graphiti.search("tania tetlow")
+
+    logger.info("QUERY: Tania Tetlow" + " RESULT:\n" + format_context(context))
+
+    context = await graphiti.search("issues with higher ed")
+
+    logger.info("QUERY: issues with higher ed" + " RESULT:\n" + format_context(context))
     graphiti.close()


@@ -57,16 +110,16 @@ async def test_graph_integration():
     bob_node = EntityNode(name="Bob", labels=[], created_at=now, summary="Bob summary")

     episodic_edge_1 = EpisodicEdge(
-        source_node=episode, target_node=alice_node, created_at=now
+        source_node_uuid=episode.uuid, target_node_uuid=alice_node.uuid, created_at=now
     )

     episodic_edge_2 = EpisodicEdge(
-        source_node=episode, target_node=bob_node, created_at=now
+        source_node_uuid=episode.uuid, target_node_uuid=bob_node.uuid, created_at=now
     )

     entity_edge = EntityEdge(
-        source_node=alice_node,
-        target_node=bob_node,
+        source_node_uuid=alice_node.uuid,
+        target_node_uuid=bob_node.uuid,
         created_at=now,
         name="likes",
         fact="Alice likes Bob",
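
The end-to-end flow this diff introduces is: extract nodes → embed names → retrieve candidate duplicates → LLM dedupe → extract edges → embed facts → dedupe edges → save episode, nodes, and edges. A minimal driver for that pipeline might look like the sketch below; the constructor arguments mirror the integration test, but `add_episode`'s keyword names are inferred from the fields it populates (the signature is not shown in this diff), so treat them as assumptions rather than the definitive API.

```python
import asyncio
import os
from datetime import datetime

from core.graphiti import Graphiti


async def main():
    # Credentials as in tests/graphiti_int_tests.py; the final None lets
    # Graphiti build its default OpenAIClient from OPENAI_API_KEY.
    graphiti = Graphiti(
        os.environ["NEO4J_URI"],
        os.environ["NEO4J_USER"],
        os.environ["NEO4J_PASSWORD"],
        None,
    )
    await graphiti.build_indices()

    # Extraction, embedding, dedupe, and persistence all happen inside add_episode.
    await graphiti.add_episode(
        name="episode-1",
        episode_body="Tania Tetlow is the president of Fordham University.",
        source_description="podcast transcript",
        reference_time=datetime.now(),
    )

    # search() embeds the query, combines vector and fulltext retrieval, and
    # expands the hits via bfs() into {uuid: {name, summary, facts}}.
    context = await graphiti.search("Who runs Fordham?")
    print(context)


asyncio.run(main())
```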
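Both `generate_embedding` and `generate_name_embedding` slice the returned vector to `EMBEDDING_DIM` (1024) so it fits the 1024-dimension vector indexes created in `build_indices`. Truncating `text-embedding-3` vectors this way is supported by the model family (prefixes of the vector remain meaningful), and because both indexes use cosine similarity, the lost normalization does not change rankings. A sketch of the two options, client-side slicing as in this diff versus the embeddings API's `dimensions` parameter:

```python
from openai import AsyncOpenAI

EMBEDDING_DIM = 1024  # matches core/llm_client/config.py
client = AsyncOpenAI()  # assumes OPENAI_API_KEY is set


async def embed_truncated(text: str) -> list[float]:
    # Client-side truncation, mirroring generate_embedding()/generate_name_embedding().
    result = await client.embeddings.create(
        input=[text.replace("\n", " ")], model="text-embedding-3-small"
    )
    return result.data[0].embedding[:EMBEDDING_DIM]


async def embed_reduced(text: str) -> list[float]:
    # Alternative: the API returns 1024-dim (renormalized) vectors directly.
    result = await client.embeddings.create(
        input=[text.replace("\n", " ")],
        model="text-embedding-3-small",
        dimensions=EMBEDDING_DIM,
    )
    return result.data[0].embedding
```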
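`dedupe_extracted_nodes` and `dedupe_extracted_edges` both resolve the LLM's answer back to full objects through a name-keyed dict, which raises `KeyError` if the model paraphrases a name instead of echoing it verbatim. A defensive variant of that resolution step (a hypothetical helper, not part of this diff) could skip unknown names instead of crashing:

```python
import logging
from typing import TypeVar

logger = logging.getLogger(__name__)

T = TypeVar("T")


def resolve_by_name(items_by_name: dict[str, T], llm_rows: list[dict]) -> list[T]:
    resolved: list[T] = []
    for row in llm_rows:
        item = items_by_name.get(row["name"])
        if item is None:
            # The LLM did not echo this name verbatim; drop it rather than crash.
            logger.warning("LLM returned unknown name, skipping: %r", row["name"])
            continue
        resolved.append(item)
    return resolved
```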