From 80612c4e1395885ac55f9bc37516e2520960bdae Mon Sep 17 00:00:00 2001
From: paulpaliychuk
Date: Thu, 22 Aug 2024 17:48:08 -0400
Subject: [PATCH] fix: Linter errors

---
 core/graphiti.py                              | 465 +++++++++---------
 core/utils/maintenance/edge_operations.py     |   5 +-
 core/utils/maintenance/node_operations.py     | 295 +++++------
 core/utils/maintenance/temporal_operations.py | 106 ++--
 core/utils/search/search_utils.py             |   5 +-
 runner.py                                     |  40 +-
 6 files changed, 474 insertions(+), 442 deletions(-)

diff --git a/core/graphiti.py b/core/graphiti.py
index dde4791..28cb92c 100644
--- a/core/graphiti.py
+++ b/core/graphiti.py
@@ -13,37 +13,33 @@ from core.nodes import EntityNode, EpisodicNode
 from core.search.search import SearchConfig, hybrid_search
 from core.search.search_utils import (
-    get_relevant_edges,
-    get_relevant_nodes,
+    get_relevant_edges,
+    get_relevant_nodes,
 )
 from core.utils import (
-    build_episodic_edges,
-    retrieve_episodes,
+    build_episodic_edges,
+    retrieve_episodes,
 )
 from core.utils.bulk_utils import (
-    BulkEpisode,
-    dedupe_edges_bulk,
-    dedupe_nodes_bulk,
-    extract_nodes_and_edges_bulk,
-    resolve_edge_pointers,
-    retrieve_previous_episodes_bulk,
-)
-from core.utils.maintenance.edge_operations import dedupe_extracted_edges, extract_edges
-from core.utils.maintenance.graph_data_operations import (
-    EPISODE_WINDOW_LEN,
-    build_indices_and_constraints,
+    BulkEpisode,
+    dedupe_edges_bulk,
+    dedupe_nodes_bulk,
+    extract_nodes_and_edges_bulk,
+    resolve_edge_pointers,
+    retrieve_previous_episodes_bulk,
 )
 from core.utils.maintenance.edge_operations import (
-    extract_edges,
-    dedupe_extracted_edges_v2,
     dedupe_extracted_edges,
+    extract_edges,
+)
+from core.utils.maintenance.graph_data_operations import (
+    EPISODE_WINDOW_LEN,
+    build_indices_and_constraints,
 )
-from core.utils.maintenance.graph_data_operations import EPISODE_WINDOW_LEN
 from core.utils.maintenance.node_operations import dedupe_extracted_nodes, extract_nodes
 from core.utils.maintenance.temporal_operations import (
-    invalidate_edges,
-    prepare_edges_for_invalidation,
-    extract_node_and_edge_triplets,
+    invalidate_edges,
+    prepare_edges_for_invalidation,
 )

 logger = logging.getLogger(__name__)


@@ -52,82 +48,86 @@ class Graphiti:
-    def __init__(self, uri: str, user: str, password: str, llm_client: LLMClient | None = None):
-        self.driver = AsyncGraphDatabase.driver(uri, auth=(user, password))
-        self.database = 'neo4j'
-        if llm_client:
-            self.llm_client = llm_client
-        else:
-            self.llm_client = OpenAIClient(
-                LLMConfig(
-                    api_key=os.getenv('OPENAI_API_KEY'),
-                    model='gpt-4o-mini',
-                    base_url='https://api.openai.com/v1',
-                )
-            )
-
-    def close(self):
-        self.driver.close()
-
-    async def build_indices_and_constraints(self):
-        await build_indices_and_constraints(self.driver)
-
-    async def retrieve_episodes(
-        self,
-        reference_time: datetime,
-        last_n: int = EPISODE_WINDOW_LEN,
-        sources: list[str] | None = 'messages',
-    ) -> list[EpisodicNode]:
-        """Retrieve the last n episodic nodes from the graph"""
-        return await retrieve_episodes(self.driver, reference_time, last_n, sources)
-
-    # Invalidate edges that are no longer valid
-    async def invalidate_edges(
-        self,
-        episode: EpisodicNode,
-        new_nodes: list[EntityNode],
-        new_edges: list[EntityEdge],
-        relevant_schema: dict[str, any],
-        previous_episodes: list[EpisodicNode],
-    ): ...
-
-    async def add_episode(
-        self,
-        name: str,
-        episode_body: str,
-        source_description: str,
-        reference_time: datetime | None = None,
-        episode_type: str | None = 'string',  # TODO: this field isn't used yet?
-        success_callback: Callable | None = None,
-        error_callback: Callable | None = None,
-    ):
-        """Process an episode and update the graph"""
-        try:
-            start = time()
-
-            nodes: list[EntityNode] = []
-            entity_edges: list[EntityEdge] = []
-            episodic_edges: list[EpisodicEdge] = []
-            embedder = self.llm_client.client.embeddings
-            now = datetime.now()
-
-            previous_episodes = await self.retrieve_episodes(reference_time)
-            episode = EpisodicNode(
-                name=name,
-                labels=[],
-                source='messages',
-                content=episode_body,
-                source_description=source_description,
-                created_at=now,
-                valid_at=reference_time,
-            )
-
-            extracted_nodes = await extract_nodes(self.llm_client, episode, previous_episodes)
+    def __init__(
+        self, uri: str, user: str, password: str, llm_client: LLMClient | None = None
+    ):
+        self.driver = AsyncGraphDatabase.driver(uri, auth=(user, password))
+        self.database = "neo4j"
+        if llm_client:
+            self.llm_client = llm_client
+        else:
+            self.llm_client = OpenAIClient(
+                LLMConfig(
+                    api_key=os.getenv("OPENAI_API_KEY"),
+                    model="gpt-4o-mini",
+                    base_url="https://api.openai.com/v1",
+                )
+            )
+
+    def close(self):
+        self.driver.close()
+
+    async def build_indices_and_constraints(self):
+        await build_indices_and_constraints(self.driver)
+
+    async def retrieve_episodes(
+        self,
+        reference_time: datetime,
+        last_n: int = EPISODE_WINDOW_LEN,
+        sources: list[str] | None = "messages",
+    ) -> list[EpisodicNode]:
+        """Retrieve the last n episodic nodes from the graph"""
+        return await retrieve_episodes(self.driver, reference_time, last_n, sources)
+
+    # Invalidate edges that are no longer valid
+    async def invalidate_edges(
+        self,
+        episode: EpisodicNode,
+        new_nodes: list[EntityNode],
+        new_edges: list[EntityEdge],
+        relevant_schema: dict[str, any],
+        previous_episodes: list[EpisodicNode],
+    ): ...
+
+    async def add_episode(
+        self,
+        name: str,
+        episode_body: str,
+        source_description: str,
+        reference_time: datetime | None = None,
+        episode_type: str | None = "string",  # TODO: this field isn't used yet?
+        success_callback: Callable | None = None,
+        error_callback: Callable | None = None,
+    ):
+        """Process an episode and update the graph"""
+        try:
+            start = time()
+
+            nodes: list[EntityNode] = []
+            entity_edges: list[EntityEdge] = []
+            episodic_edges: list[EpisodicEdge] = []
+            embedder = self.llm_client.client.embeddings
+            now = datetime.now()
+
+            previous_episodes = await self.retrieve_episodes(reference_time)
+            episode = EpisodicNode(
+                name=name,
+                labels=[],
+                source="messages",
+                content=episode_body,
+                source_description=source_description,
+                created_at=now,
+                valid_at=reference_time,
+            )
+
+            extracted_nodes = await extract_nodes(
+                self.llm_client, episode, previous_episodes
+            )

             logger.info(
                 f"Extracted nodes: {[(n.name, n.uuid) for n in extracted_nodes]}"
             )

-            # Calculate Embeddings
+            # Calculate Embeddings

             await asyncio.gather(
                 *[node.generate_name_embedding(embedder) for node in extracted_nodes]
@@ -144,17 +144,21 @@ async def add_episode(
             )
             nodes.extend(touched_nodes)

-            extracted_edges = await extract_edges(
-                self.llm_client, episode, touched_nodes, previous_episodes
-            )
+            extracted_edges = await extract_edges(
+                self.llm_client, episode, touched_nodes, previous_episodes
+            )

-            await asyncio.gather(*[edge.generate_embedding(embedder) for edge in extracted_edges])
+            await asyncio.gather(
+                *[edge.generate_embedding(embedder) for edge in extracted_edges]
+            )

-            existing_edges = await get_relevant_edges(extracted_edges, self.driver)
-            logger.info(f'Existing edges: {[(e.name, e.uuid) for e in existing_edges]}')
-            logger.info(f'Extracted edges: {[(e.name, e.uuid) for e in extracted_edges]}')
+            existing_edges = await get_relevant_edges(extracted_edges, self.driver)
+            logger.info(f"Existing edges: {[(e.name, e.uuid) for e in existing_edges]}")
+            logger.info(
+                f"Extracted edges: {[(e.name, e.uuid) for e in extracted_edges]}"
+            )

-            # deduped_edges = await dedupe_extracted_edges_v2(
+            # deduped_edges = await dedupe_extracted_edges_v2(
             #     self.llm_client,
             #     extract_node_and_edge_triplets(extracted_edges, nodes),
             #     extract_node_and_edge_triplets(existing_edges, nodes),
@@ -171,12 +175,12 @@ async def add_episode(
                 edge_touched_node_uuids.append(edge.source_node_uuid)
                 edge_touched_node_uuids.append(edge.target_node_uuid)

-            (
-                old_edges_with_nodes_pending_invalidation,
-                new_edges_with_nodes,
-            ) = prepare_edges_for_invalidation(
-                existing_edges=existing_edges, new_edges=deduped_edges, nodes=nodes
-            )
+            (
+                old_edges_with_nodes_pending_invalidation,
+                new_edges_with_nodes,
+            ) = prepare_edges_for_invalidation(
+                existing_edges=existing_edges, new_edges=deduped_edges, nodes=nodes
+            )

             invalidated_edges = await invalidate_edges(
                 self.llm_client,
@@ -190,9 +194,9 @@ async def add_episode(
                 edge_touched_node_uuids.append(edge.source_node_uuid)
                 edge_touched_node_uuids.append(edge.target_node_uuid)

-            entity_edges.extend(invalidated_edges)
+            entity_edges.extend(invalidated_edges)

-            edge_touched_node_uuids = list(set(edge_touched_node_uuids))
+            edge_touched_node_uuids = list(set(edge_touched_node_uuids))
             involved_nodes = [
                 node for node in nodes if node.uuid in edge_touched_node_uuids
             ]
@@ -202,10 +206,11 @@ async def add_episode(
             )

             logger.info(
-                f'Invalidated edges: {[(e.name, e.uuid) for e in invalidated_edges]}')
+                f"Invalidated edges: {[(e.name, e.uuid) for e in invalidated_edges]}"
+            )

-            logger.info(f'Deduped edges: {[(e.name, e.uuid) for e in deduped_edges]}')
-            entity_edges.extend(deduped_edges)
+            logger.info(f"Deduped edges: {[(e.name, e.uuid) for e in deduped_edges]}")
+            entity_edges.extend(deduped_edges)

             episodic_edges.extend(
                 build_episodic_edges(
@@ -218,117 +223,125 @@ async def add_episode(
             # Important to append the episode to the nodes at the end so that self-referencing episodic edges are not built
             logger.info(f"Built episodic edges: {episodic_edges}")

-        # invalidated_edges = await self.invalidate_edges(
-        #     episode, new_nodes, new_edges, relevant_schema, previous_episodes
-        # )
-
-        # edges.extend(invalidated_edges)
-
-        # Future optimization would be using batch operations to save nodes and edges
-        await episode.save(self.driver)
-        await asyncio.gather(*[node.save(self.driver) for node in nodes])
-        await asyncio.gather(*[edge.save(self.driver) for edge in episodic_edges])
-        await asyncio.gather(*[edge.save(self.driver) for edge in entity_edges])
-
-        end = time()
-        logger.info(f'Completed add_episode in {(end-start) * 1000} ms')
-        # for node in nodes:
-        #     if isinstance(node, EntityNode):
-        #         await node.update_summary(self.driver)
-        if success_callback:
-            await success_callback(episode)
-    except Exception as e:
-        if error_callback:
-            await error_callback(episode, e)
-        else:
-            raise e
-
-    async def add_episode_bulk(
-        self,
-        bulk_episodes: list[BulkEpisode],
-    ):
-        try:
-            start = time()
-            embedder = self.llm_client.client.embeddings
-            now = datetime.now()
-
-            episodes = [
-                EpisodicNode(
-                    name=episode.name,
-                    labels=[],
-                    source='messages',
-                    content=episode.content,
-                    source_description=episode.source_description,
-                    created_at=now,
-                    valid_at=episode.reference_time,
-                )
-                for episode in bulk_episodes
-            ]
-
-            # Save all the episodes
-            await asyncio.gather(*[episode.save(self.driver) for episode in episodes])
-
-            # Get previous episode context for each episode
-            episode_pairs = await retrieve_previous_episodes_bulk(self.driver, episodes)
-
-            # Extract all nodes and edges
-            (
-                extracted_nodes,
-                extracted_edges,
-                episodic_edges,
-            ) = await extract_nodes_and_edges_bulk(self.llm_client, episode_pairs)
-
-            # Generate embeddings
-            await asyncio.gather(
-                *[node.generate_name_embedding(embedder) for node in extracted_nodes],
-                *[edge.generate_embedding(embedder) for edge in extracted_edges],
-            )
-
-            # Dedupe extracted nodes
-            nodes, uuid_map = await dedupe_nodes_bulk(self.driver, self.llm_client, extracted_nodes)
-
-            # save nodes to KG
-            await asyncio.gather(*[node.save(self.driver) for node in nodes])
-
-            # re-map edge pointers so that they don't point to discarded dupe nodes
-            extracted_edges: list[EntityEdge] = resolve_edge_pointers(extracted_edges, uuid_map)
-            episodic_edges: list[EpisodicEdge] = resolve_edge_pointers(episodic_edges, uuid_map)
-
-            # save episodic edges to KG
-            await asyncio.gather(*[edge.save(self.driver) for edge in episodic_edges])
-
-            # Dedupe extracted edges
-            edges = await dedupe_edges_bulk(self.driver, self.llm_client, extracted_edges)
-            logger.info(f'extracted edge length: {len(edges)}')
-
-            # invalidate edges
-
-            # save edges to KG
-            await asyncio.gather(*[edge.save(self.driver) for edge in edges])
-
-            end = time()
-            logger.info(f'Completed add_episode_bulk in {(end-start) * 1000} ms')
-
-        except Exception as e:
-            raise e
-
-    async def search(self, query: str, num_results=10):
-        search_config = SearchConfig(num_episodes=0, num_results=num_results)
-        edges = (
-            await hybrid_search(
-                self.driver,
-                self.llm_client.client.embeddings,
-                query,
-                datetime.now(),
-                search_config,
-            )
-        )['edges']
-
-        facts = [edge.fact for edge in edges]
-
-        return facts
-
-    async def _search(self, query: str, timestamp: datetime, config: SearchConfig):
-        return await hybrid_search(
-            self.driver, self.llm_client.client.embeddings, query, timestamp, config
-        )
+            # invalidated_edges = await self.invalidate_edges(
+            #     episode, new_nodes, new_edges, relevant_schema, previous_episodes
+            # )
+
+            # edges.extend(invalidated_edges)
+
+            # Future optimization would be using batch operations to save nodes and edges
+            await episode.save(self.driver)
+            await asyncio.gather(*[node.save(self.driver) for node in nodes])
+            await asyncio.gather(*[edge.save(self.driver) for edge in episodic_edges])
+            await asyncio.gather(*[edge.save(self.driver) for edge in entity_edges])
+
+            end = time()
+            logger.info(f"Completed add_episode in {(end-start) * 1000} ms")
+            # for node in nodes:
+            #     if isinstance(node, EntityNode):
+            #         await node.update_summary(self.driver)
+            if success_callback:
+                await success_callback(episode)
+        except Exception as e:
+            if error_callback:
+                await error_callback(episode, e)
+            else:
+                raise e
+
+    async def add_episode_bulk(
+        self,
+        bulk_episodes: list[BulkEpisode],
+    ):
+        try:
+            start = time()
+            embedder = self.llm_client.client.embeddings
+            now = datetime.now()
+
+            episodes = [
+                EpisodicNode(
+                    name=episode.name,
+                    labels=[],
+                    source="messages",
+                    content=episode.content,
+                    source_description=episode.source_description,
+                    created_at=now,
+                    valid_at=episode.reference_time,
+                )
+                for episode in bulk_episodes
+            ]
+
+            # Save all the episodes
+            await asyncio.gather(*[episode.save(self.driver) for episode in episodes])
+
+            # Get previous episode context for each episode
+            episode_pairs = await retrieve_previous_episodes_bulk(self.driver, episodes)
+
+            # Extract all nodes and edges
+            (
+                extracted_nodes,
+                extracted_edges,
+                episodic_edges,
+            ) = await extract_nodes_and_edges_bulk(self.llm_client, episode_pairs)
+
+            # Generate embeddings
+            await asyncio.gather(
+                *[node.generate_name_embedding(embedder) for node in extracted_nodes],
+                *[edge.generate_embedding(embedder) for edge in extracted_edges],
+            )
+
+            # Dedupe extracted nodes
+            nodes, uuid_map = await dedupe_nodes_bulk(
+                self.driver, self.llm_client, extracted_nodes
+            )
+
+            # save nodes to KG
+            await asyncio.gather(*[node.save(self.driver) for node in nodes])
+
+            # re-map edge pointers so that they don't point to discarded dupe nodes
+            extracted_edges: list[EntityEdge] = resolve_edge_pointers(
+                extracted_edges, uuid_map
+            )
+            episodic_edges: list[EpisodicEdge] = resolve_edge_pointers(
+                episodic_edges, uuid_map
+            )
+
+            # save episodic edges to KG
+            await asyncio.gather(*[edge.save(self.driver) for edge in episodic_edges])
+
+            # Dedupe extracted edges
+            edges = await dedupe_edges_bulk(
+                self.driver, self.llm_client, extracted_edges
+            )
+            logger.info(f"extracted edge length: {len(edges)}")
+
+            # invalidate edges
+
+            # save edges to KG
+            await asyncio.gather(*[edge.save(self.driver) for edge in edges])
+
+            end = time()
+            logger.info(f"Completed add_episode_bulk in {(end-start) * 1000} ms")
+
+        except Exception as e:
+            raise e
+
+    async def search(self, query: str, num_results=10):
+        search_config = SearchConfig(num_episodes=0, num_results=num_results)
+        edges = (
+            await hybrid_search(
+                self.driver,
+                self.llm_client.client.embeddings,
+                query,
+                datetime.now(),
+                search_config,
+            )
+        )["edges"]
+
+        facts = [edge.fact for edge in edges]
+
+        return facts
+
+    async def _search(self, query: str, timestamp: datetime, config: SearchConfig):
+        return await hybrid_search(
+            self.driver, self.llm_client.client.embeddings, query, timestamp, config
+        )
diff --git a/core/utils/maintenance/edge_operations.py b/core/utils/maintenance/edge_operations.py
index 614e9f3..dede92d 100644
--- a/core/utils/maintenance/edge_operations.py
+++ b/core/utils/maintenance/edge_operations.py
@@ -7,9 +7,8 @@
 from core.edges import EntityEdge, EpisodicEdge
 from core.llm_client import LLMClient
 from core.nodes import EntityNode, EpisodicNode
-from core.edges import EpisodicEdge, EntityEdge
-from core.utils.maintenance.temporal_operations import NodeEdgeNodeTriplet
 from core.prompts import prompt_library
+from core.utils.maintenance.temporal_operations import NodeEdgeNodeTriplet

 logger = logging.getLogger(__name__)

@@ -197,7 +196,7 @@ async def dedupe_extracted_edges_v2(
     for n1, edge, n2 in existing_edges:
         edge_map[create_edge_identifier(n1, edge, n2)] = edge
     for n1, edge, n2 in extracted_edges:
-        if create_edge_identifier(n1, edge, n2) in edge_map.keys():
+        if create_edge_identifier(n1, edge, n2) in edge_map:
             continue
         edge_map[create_edge_identifier(n1, edge, n2)] = edge

diff --git a/core/utils/maintenance/node_operations.py b/core/utils/maintenance/node_operations.py
index 6c26770..ab8c34f 100644
--- a/core/utils/maintenance/node_operations.py
+++ b/core/utils/maintenance/node_operations.py
@@ -10,192 +10,205 @@
 async def extract_new_nodes(
-    llm_client: LLMClient,
-    episode: EpisodicNode,
-    relevant_schema: dict[str, any],
-    previous_episodes: list[EpisodicNode],
+    llm_client: LLMClient,
+    episode: EpisodicNode,
+    relevant_schema: dict[str, any],
+    previous_episodes: list[EpisodicNode],
 ) -> list[EntityNode]:
-    # Prepare context for LLM
-    existing_nodes = [
-        {'name': node_name, 'label': node_info['label'], 'uuid': node_info['uuid']}
-        for node_name, node_info in relevant_schema['nodes'].items()
-    ]
-
-    context = {
-        'episode_content': episode.content,
-        'episode_timestamp': (episode.valid_at.isoformat() if episode.valid_at else None),
-        'existing_nodes': existing_nodes,
-        'previous_episodes': [
-            {
-                'content': ep.content,
-                'timestamp': ep.valid_at.isoformat() if ep.valid_at else None,
-            }
-            for ep in previous_episodes
-        ],
-    }
-
-    llm_response = await llm_client.generate_response(prompt_library.extract_nodes.v1(context))
-    new_nodes_data = llm_response.get('new_nodes', [])
-    logger.info(f'Extracted new nodes: {new_nodes_data}')
-    # Convert the extracted data into EntityNode objects
-    new_nodes = []
-    for node_data in new_nodes_data:
-        # Check if the node already exists
-        if not any(existing_node['name'] == node_data['name'] for existing_node in existing_nodes):
-            new_node = EntityNode(
-                name=node_data['name'],
-                labels=node_data['labels'],
-                summary=node_data['summary'],
-                created_at=datetime.now(),
-            )
-            new_nodes.append(new_node)
-            logger.info(f'Created new node: {new_node.name} (UUID: {new_node.uuid})')
-        else:
-            logger.info(f"Node {node_data['name']} already exists, skipping creation.")
-
-    return new_nodes
+    # Prepare context for LLM
+    existing_nodes = [
+        {"name": node_name, "label": node_info["label"], "uuid": node_info["uuid"]}
+        for node_name, node_info in relevant_schema["nodes"].items()
+    ]
+
+    context = {
+        "episode_content": episode.content,
+        "episode_timestamp": (
+            episode.valid_at.isoformat() if episode.valid_at else None
+        ),
+        "existing_nodes": existing_nodes,
+        "previous_episodes": [
+            {
+                "content": ep.content,
+                "timestamp": ep.valid_at.isoformat() if ep.valid_at else None,
+            }
+            for ep in previous_episodes
+        ],
+    }
+
+    llm_response = await llm_client.generate_response(
+        prompt_library.extract_nodes.v1(context)
+    )
+    new_nodes_data = llm_response.get("new_nodes", [])
+    logger.info(f"Extracted new nodes: {new_nodes_data}")
+    # Convert the extracted data into EntityNode objects
+    new_nodes = []
+    for node_data in new_nodes_data:
+        # Check if the node already exists
+        if not any(
+            existing_node["name"] == node_data["name"]
+            for existing_node in existing_nodes
+        ):
+            new_node = EntityNode(
+                name=node_data["name"],
+                labels=node_data["labels"],
+                summary=node_data["summary"],
+                created_at=datetime.now(),
+            )
+            new_nodes.append(new_node)
+            logger.info(f"Created new node: {new_node.name} (UUID: {new_node.uuid})")
+        else:
+            logger.info(f"Node {node_data['name']} already exists, skipping creation.")
+
+    return new_nodes

 async def extract_nodes(
-    llm_client: LLMClient,
-    episode: EpisodicNode,
-    previous_episodes: list[EpisodicNode],
+    llm_client: LLMClient,
+    episode: EpisodicNode,
+    previous_episodes: list[EpisodicNode],
 ) -> list[EntityNode]:
-    start = time()
-
-    # Prepare context for LLM
-    context = {
-        'episode_content': episode.content,
-        'episode_timestamp': (episode.valid_at.isoformat() if episode.valid_at else None),
-        'previous_episodes': [
-            {
-                'content': ep.content,
-                'timestamp': ep.valid_at.isoformat() if ep.valid_at else None,
-            }
-            for ep in previous_episodes
-        ],
-    }
-
-    llm_response = await llm_client.generate_response(prompt_library.extract_nodes.v3(context))
-    new_nodes_data = llm_response.get('new_nodes', [])
-
-    end = time()
-    logger.info(f'Extracted new nodes: {new_nodes_data} in {(end - start) * 1000} ms')
-    # Convert the extracted data into EntityNode objects
-    new_nodes = []
-    for node_data in new_nodes_data:
-        new_node = EntityNode(
-            name=node_data['name'],
-            labels=node_data['labels'],
-            summary=node_data['summary'],
-            created_at=datetime.now(),
-        )
-        new_nodes.append(new_node)
-        logger.info(f'Created new node: {new_node.name} (UUID: {new_node.uuid})')
-
-    return new_nodes
+    start = time()
+
+    # Prepare context for LLM
+    context = {
+        "episode_content": episode.content,
+        "episode_timestamp": (
+            episode.valid_at.isoformat() if episode.valid_at else None
+        ),
+        "previous_episodes": [
+            {
+                "content": ep.content,
+                "timestamp": ep.valid_at.isoformat() if ep.valid_at else None,
+            }
+            for ep in previous_episodes
+        ],
+    }
+
+    llm_response = await llm_client.generate_response(
+        prompt_library.extract_nodes.v3(context)
+    )
+    new_nodes_data = llm_response.get("new_nodes", [])
+
+    end = time()
+    logger.info(f"Extracted new nodes: {new_nodes_data} in {(end - start) * 1000} ms")
+    # Convert the extracted data into EntityNode objects
+    new_nodes = []
+    for node_data in new_nodes_data:
+        new_node = EntityNode(
+            name=node_data["name"],
+            labels=node_data["labels"],
+            summary=node_data["summary"],
+            created_at=datetime.now(),
+        )
+        new_nodes.append(new_node)
+        logger.info(f"Created new node: {new_node.name} (UUID: {new_node.uuid})")
+
+    return new_nodes

 async def dedupe_extracted_nodes(
-    llm_client: LLMClient,
-    extracted_nodes: list[EntityNode],
-    existing_nodes: list[EntityNode],
+    llm_client: LLMClient,
+    extracted_nodes: list[EntityNode],
+    existing_nodes: list[EntityNode],
 ) -> tuple[list[EntityNode], dict[str, str]]:
-    start = time()
+    start = time()

-    # build existing node map
-    node_map = {}
-    for node in existing_nodes:
-        node_map[node.name] = node
+    # build existing node map
+    node_map = {}
+    for node in existing_nodes:
+        node_map[node.name] = node

-    # Temp hack
+    # Temp hack
     new_nodes_map = {}
     for node in extracted_nodes:
         new_nodes_map[node.name] = node

     # Prepare context for LLM
-    existing_nodes_context = [
-        {'name': node.name, 'summary': node.summary} for node in existing_nodes
-    ]
+    existing_nodes_context = [
+        {"name": node.name, "summary": node.summary} for node in existing_nodes
+    ]

-    extracted_nodes_context = [
-        {'name': node.name, 'summary': node.summary} for node in extracted_nodes
-    ]
+    extracted_nodes_context = [
+        {"name": node.name, "summary": node.summary} for node in extracted_nodes
+    ]

-    context = {
-        'existing_nodes': existing_nodes_context,
-        'extracted_nodes': extracted_nodes_context,
-    }
+    context = {
+        "existing_nodes": existing_nodes_context,
+        "extracted_nodes": extracted_nodes_context,
+    }

-    llm_response = await llm_client.generate_response(prompt_library.dedupe_nodes.v2(context))
+    llm_response = await llm_client.generate_response(
+        prompt_library.dedupe_nodes.v2(context)
+    )

-    duplicate_data = llm_response.get('duplicates', [])
+    duplicate_data = llm_response.get("duplicates", [])

-    end = time()
-    logger.info(f'Deduplicated nodes: {duplicate_data} in {(end - start) * 1000} ms')
+    end = time()
+    logger.info(f"Deduplicated nodes: {duplicate_data} in {(end - start) * 1000} ms")

-    uuid_map = {}
-    for duplicate in duplicate_data:
-        uuid = new_nodes_map[duplicate['name']].uuid
-        uuid_value = node_map[duplicate['duplicate_of']].uuid
-        uuid_map[uuid] = uuid_value
+    uuid_map = {}
+    for duplicate in duplicate_data:
+        uuid = new_nodes_map[duplicate["name"]].uuid
+        uuid_value = node_map[duplicate["duplicate_of"]].uuid
+        uuid_map[uuid] = uuid_value

-    nodes = []
-    brand_new_nodes = []
+    nodes = []
+    brand_new_nodes = []
     for node in extracted_nodes:
         if node.uuid in uuid_map:
             existing_uuid = uuid_map[node.uuid]
             # TODO(Preston): This is a bit of a hack I implemented because we were getting incorrect uuids for existing nodes,
             # can you revisit the node dedup function and make it somewhat cleaner and add more comments/tests please?
             # find an existing node by the uuid from the nodes_map (each key is name, so we need to iterate by uuid value)
-            existing_node = next(
+            existing_node = next(
                 (v for k, v in node_map.items() if v.uuid == existing_uuid), None
             )
             nodes.append(existing_node)
             continue
         brand_new_nodes.append(node)
-        nodes.append(node)
+        nodes.append(node)

-    return nodes, uuid_map, brand_new_nodes
+    return nodes, uuid_map, brand_new_nodes

 async def dedupe_node_list(
-    llm_client: LLMClient,
-    nodes: list[EntityNode],
+    llm_client: LLMClient,
+    nodes: list[EntityNode],
 ) -> tuple[list[EntityNode], dict[str, str]]:
-    start = time()
+    start = time()

-    # build node map
-    node_map = {}
-    for node in nodes:
-        node_map[node.name] = node
+    # build node map
+    node_map = {}
+    for node in nodes:
+        node_map[node.name] = node

-    # Prepare context for LLM
-    nodes_context = [{'name': node.name, 'summary': node.summary} for node in nodes]
+    # Prepare context for LLM
+    nodes_context = [{"name": node.name, "summary": node.summary} for node in nodes]

-    context = {
-        'nodes': nodes_context,
-    }
+    context = {
+        "nodes": nodes_context,
+    }

-    llm_response = await llm_client.generate_response(
-        prompt_library.dedupe_nodes.node_list(context)
-    )
+    llm_response = await llm_client.generate_response(
+        prompt_library.dedupe_nodes.node_list(context)
+    )

-    nodes_data = llm_response.get('nodes', [])
+    nodes_data = llm_response.get("nodes", [])

-    end = time()
-    logger.info(f'Deduplicated nodes: {nodes_data} in {(end - start) * 1000} ms')
+    end = time()
+    logger.info(f"Deduplicated nodes: {nodes_data} in {(end - start) * 1000} ms")

-    # Get full node data
-    unique_nodes = []
-    uuid_map: dict[str, str] = {}
-    for node_data in nodes_data:
-        node = node_map[node_data['names'][0]]
-        unique_nodes.append(node)
+    # Get full node data
+    unique_nodes = []
+    uuid_map: dict[str, str] = {}
+    for node_data in nodes_data:
+        node = node_map[node_data["names"][0]]
+        unique_nodes.append(node)

-        for name in node_data['names'][1:]:
-            uuid = node_map[name].uuid
-            uuid_value = node_map[node_data['names'][0]].uuid
-            uuid_map[uuid] = uuid_value
+        for name in node_data["names"][1:]:
+            uuid = node_map[name].uuid
+            uuid_value = node_map[node_data["names"][0]].uuid
+            uuid_map[uuid] = uuid_value

-    return unique_nodes, uuid_map
+    return unique_nodes, uuid_map
diff --git a/core/utils/maintenance/temporal_operations.py b/core/utils/maintenance/temporal_operations.py
index a89cb0b..87b2fd4 100644
--- a/core/utils/maintenance/temporal_operations.py
+++ b/core/utils/maintenance/temporal_operations.py
@@ -31,38 +31,42 @@ def extract_node_edge_node_triplet(

 def prepare_edges_for_invalidation(
-    existing_edges: list[EntityEdge],
-    new_edges: list[EntityEdge],
-    nodes: list[EntityNode],
+    existing_edges: list[EntityEdge],
+    new_edges: list[EntityEdge],
+    nodes: list[EntityNode],
 ) -> tuple[list[NodeEdgeNodeTriplet], list[NodeEdgeNodeTriplet]]:
-    existing_edges_pending_invalidation = []  # TODO: this is not yet used?
-    new_edges_with_nodes = []  # TODO: this is not yet used?
-
-    existing_edges_pending_invalidation = []
-    new_edges_with_nodes = []
-
-    for edge_list, result_list in [
-        (existing_edges, existing_edges_pending_invalidation),
-        (new_edges, new_edges_with_nodes),
-    ]:
-        for edge in edge_list:
-            source_node = next((node for node in nodes if node.uuid == edge.source_node_uuid), None)
-            target_node = next((node for node in nodes if node.uuid == edge.target_node_uuid), None)
+    existing_edges_pending_invalidation = []  # TODO: this is not yet used?
+    new_edges_with_nodes = []  # TODO: this is not yet used?
+
+    existing_edges_pending_invalidation = []
+    new_edges_with_nodes = []
+
+    for edge_list, result_list in [
+        (existing_edges, existing_edges_pending_invalidation),
+        (new_edges, new_edges_with_nodes),
+    ]:
+        for edge in edge_list:
+            source_node = next(
+                (node for node in nodes if node.uuid == edge.source_node_uuid), None
+            )
+            target_node = next(
+                (node for node in nodes if node.uuid == edge.target_node_uuid), None
+            )

-            if source_node and target_node:
-                result_list.append((source_node, edge, target_node))
+            if source_node and target_node:
+                result_list.append((source_node, edge, target_node))

-    return existing_edges_pending_invalidation, new_edges_with_nodes
+    return existing_edges_pending_invalidation, new_edges_with_nodes

 async def invalidate_edges(
-    llm_client: LLMClient,
-    existing_edges_pending_invalidation: list[NodeEdgeNodeTriplet],
-    new_edges: list[NodeEdgeNodeTriplet],
+    llm_client: LLMClient,
+    existing_edges_pending_invalidation: list[NodeEdgeNodeTriplet],
+    new_edges: list[NodeEdgeNodeTriplet],
     current_episode: EpisodicNode,
     previous_episodes: list[EpisodicNode],
 ) -> list[EntityEdge]:
-    invalidated_edges = []  # TODO: this is not yet used?
+    invalidated_edges = []  # TODO: this is not yet used?

     context = prepare_invalidation_context(
         existing_edges_pending_invalidation,
@@ -76,29 +80,29 @@ async def invalidate_edges(
     )
     logger.info(f"invalidate_edges LLM response: {llm_response}")

-    edges_to_invalidate = llm_response.get('invalidated_edges', [])
-    invalidated_edges = process_edge_invalidation_llm_response(
-        edges_to_invalidate, existing_edges_pending_invalidation
-    )
+    edges_to_invalidate = llm_response.get("invalidated_edges", [])
+    invalidated_edges = process_edge_invalidation_llm_response(
+        edges_to_invalidate, existing_edges_pending_invalidation
+    )

-    return invalidated_edges
+    return invalidated_edges

 def prepare_invalidation_context(
-    existing_edges: list[NodeEdgeNodeTriplet],
+    existing_edges: list[NodeEdgeNodeTriplet],
     new_edges: list[NodeEdgeNodeTriplet],
     current_episode: EpisodicNode,
     previous_episodes: list[EpisodicNode],
 ) -> dict:
-    return {
-        'existing_edges': [
-            f'{edge.uuid} | {source_node.name} - {edge.name} - {target_node.name} (Fact: {edge.fact}) ({edge.created_at.isoformat()})'
-            for source_node, edge, target_node in sorted(
-                existing_edges, key=lambda x: x[1].created_at, reverse=True
-            )
-        ],
-        'new_edges': [
-            f'{edge.uuid} | {source_node.name} - {edge.name} - {target_node.name} (Fact: {edge.fact}) ({edge.created_at.isoformat()})'
+    return {
+        "existing_edges": [
+            f"{edge.uuid} | {source_node.name} - {edge.name} - {target_node.name} (Fact: {edge.fact}) ({edge.created_at.isoformat()})"
+            for source_node, edge, target_node in sorted(
+                existing_edges, key=lambda x: x[1].created_at, reverse=True
+            )
+        ],
+        "new_edges": [
+            f"{edge.uuid} | {source_node.name} - {edge.name} - {target_node.name} (Fact: {edge.fact}) ({edge.created_at.isoformat()})"
             for source_node, edge, target_node in sorted(
                 new_edges, key=lambda x: x[1].created_at, reverse=True
             )
@@ -109,20 +113,20 @@ def process_edge_invalidation_llm_response(
-    edges_to_invalidate: List[dict], existing_edges: List[NodeEdgeNodeTriplet]
+    edges_to_invalidate: List[dict], existing_edges: List[NodeEdgeNodeTriplet]
 ) -> List[EntityEdge]:
-    invalidated_edges = []
-    for edge_to_invalidate in edges_to_invalidate:
-        edge_uuid = edge_to_invalidate['edge_uuid']
-        edge_to_update = next(
-            (edge for _, edge, _ in existing_edges if edge.uuid == edge_uuid),
-            None,
-        )
-        if edge_to_update:
-            edge_to_update.expired_at = datetime.now()
-            edge_to_update.fact = edge_to_invalidate["fact"]
+    invalidated_edges = []
+    for edge_to_invalidate in edges_to_invalidate:
+        edge_uuid = edge_to_invalidate["edge_uuid"]
+        edge_to_update = next(
+            (edge for _, edge, _ in existing_edges if edge.uuid == edge_uuid),
+            None,
+        )
+        if edge_to_update:
+            edge_to_update.expired_at = datetime.now()
+            edge_to_update.fact = edge_to_invalidate["fact"]
             invalidated_edges.append(edge_to_update)
             logger.info(
                 f"Invalidated edge: {edge_to_update.name} (UUID: {edge_to_update.uuid}). Updated Fact: {edge_to_invalidate['fact']}"
-            )
-    return invalidated_edges
+            )
+    return invalidated_edges
diff --git a/core/utils/search/search_utils.py b/core/utils/search/search_utils.py
index 15cc523..dcd3628 100644
--- a/core/utils/search/search_utils.py
+++ b/core/utils/search/search_utils.py
@@ -3,7 +3,8 @@
 from datetime import datetime
 from time import time

-from neo4j import AsyncDriver, time as neo4j_time
+from neo4j import AsyncDriver
+from neo4j import time as neo4j_time

 from core.edges import EntityEdge
 from core.nodes import EntityNode
@@ -42,7 +43,7 @@ async def bfs(node_ids: list[str], driver: AsyncDriver):
     for record in records:
         n_uuid = record["source_node_uuid"]
-        if n_uuid in context.keys():
+        if n_uuid in context:
             context[n_uuid]["facts"].append(record["fact"])
         else:
             context[n_uuid] = {
diff --git a/runner.py b/runner.py
index 6380401..5de65d0 100644
--- a/runner.py
+++ b/runner.py
@@ -11,36 +11,38 @@
 load_dotenv()

-neo4j_uri = os.environ.get('NEO4J_URI') or 'bolt://localhost:7687'
-neo4j_user = os.environ.get('NEO4J_USER') or 'neo4j'
-neo4j_password = os.environ.get('NEO4J_PASSWORD') or 'password'
+neo4j_uri = os.environ.get("NEO4J_URI") or "bolt://localhost:7687"
+neo4j_user = os.environ.get("NEO4J_USER") or "neo4j"
+neo4j_password = os.environ.get("NEO4J_PASSWORD") or "password"


 def setup_logging():
-    # Create a logger
-    logger = logging.getLogger()
-    logger.setLevel(logging.INFO)  # Set the logging level to INFO
+    # Create a logger
+    logger = logging.getLogger()
+    logger.setLevel(logging.INFO)  # Set the logging level to INFO

-    # Create console handler and set level to INFO
-    console_handler = logging.StreamHandler(sys.stdout)
-    console_handler.setLevel(logging.INFO)
+    # Create console handler and set level to INFO
+    console_handler = logging.StreamHandler(sys.stdout)
+    console_handler.setLevel(logging.INFO)

-    # Create formatter
-    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
+    # Create formatter
+    formatter = logging.Formatter(
+        "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
+    )

-    # Add formatter to console handler
-    console_handler.setFormatter(formatter)
+    # Add formatter to console handler
+    console_handler.setFormatter(formatter)

-    # Add console handler to logger
-    logger.addHandler(console_handler)
+    # Add console handler to logger
+    logger.addHandler(console_handler)

-    return logger
+    return logger


 async def main():
-    setup_logging()
-    client = Graphiti(neo4j_uri, neo4j_user, neo4j_password)
-    await clear_data(client.driver)
+    setup_logging()
+    client = Graphiti(neo4j_uri, neo4j_user, neo4j_password)
+    await clear_data(client.driver)
     # await client.build_indices()
     await client.add_episode(