diff --git a/core/edges.py b/core/edges.py index 76a516c..29978c9 100644 --- a/core/edges.py +++ b/core/edges.py @@ -14,41 +14,41 @@ class Edge(BaseModel, ABC): - uuid: str = Field(default_factory=lambda: uuid4().hex) - source_node_uuid: str - target_node_uuid: str - created_at: datetime + uuid: str = Field(default_factory=lambda: uuid4().hex) + source_node_uuid: str + target_node_uuid: str + created_at: datetime - @abstractmethod - async def save(self, driver: AsyncDriver): ... + @abstractmethod + async def save(self, driver: AsyncDriver): ... - def __hash__(self): - return hash(self.uuid) + def __hash__(self): + return hash(self.uuid) - def __eq__(self, other): - if isinstance(other, Node): - return self.uuid == other.uuid - return False + def __eq__(self, other): + if isinstance(other, Edge): + return self.uuid == other.uuid + return False class EpisodicEdge(Edge): - async def save(self, driver: AsyncDriver): - result = await driver.execute_query( - """ + async def save(self, driver: AsyncDriver): + result = await driver.execute_query( + """ MATCH (episode:Episodic {uuid: $episode_uuid}) MATCH (node:Entity {uuid: $entity_uuid}) MERGE (episode)-[r:MENTIONS {uuid: $uuid}]->(node) SET r = {uuid: $uuid, created_at: $created_at} RETURN r.uuid AS uuid""", - episode_uuid=self.source_node_uuid, - entity_uuid=self.target_node_uuid, - uuid=self.uuid, - created_at=self.created_at, - ) + episode_uuid=self.source_node_uuid, + entity_uuid=self.target_node_uuid, + uuid=self.uuid, + created_at=self.created_at, + ) - logger.info(f'Saved edge to neo4j: {self.uuid}') + logger.info(f'Saved edge to neo4j: {self.uuid}') - return result + return result # TODO: Neo4j doesn't support variables for edge types and labels. @@ -56,38 +56,38 @@ async def save(self, driver: AsyncDriver): class EntityEdge(Edge): - name: str = Field(description='name of the edge, relation name') - fact: str = Field(description='fact representing the edge and nodes that it connects') - fact_embedding: list[float] | None = Field(default=None, description='embedding of the fact') - episodes: list[str] | None = Field( - default=None, - description='list of episode ids that reference these entity edges', - ) - expired_at: datetime | None = Field( - default=None, description='datetime of when the node was invalidated' - ) - valid_at: datetime | None = Field( - default=None, description='datetime of when the fact became true' - ) - invalid_at: datetime | None = Field( - default=None, description='datetime of when the fact stopped being true' - ) - - async def generate_embedding(self, embedder, model='text-embedding-3-small'): - start = time() - - text = self.fact.replace('\n', ' ') - embedding = (await embedder.create(input=[text], model=model)).data[0].embedding - self.fact_embedding = embedding[:EMBEDDING_DIM] - - end = time() - logger.info(f'embedded {text} in {end-start} ms') - - return embedding - - async def save(self, driver: AsyncDriver): - result = await driver.execute_query( - """ + name: str = Field(description='name of the edge, relation name') + fact: str = Field(description='fact representing the edge and nodes that it connects') + fact_embedding: list[float] | None = Field(default=None, description='embedding of the fact') + episodes: list[str] | None = Field( + default=None, + description='list of episode ids that reference these entity edges', + ) + expired_at: datetime | None = Field( + default=None, description='datetime of when the edge was invalidated' + ) + valid_at: datetime | None = Field( + default=None, 
description='datetime of when the fact became true' + ) + invalid_at: datetime | None = Field( + default=None, description='datetime of when the fact stopped being true' + ) + + async def generate_embedding(self, embedder, model='text-embedding-3-small'): + start = time() + + text = self.fact.replace('\n', ' ') + embedding = (await embedder.create(input=[text], model=model)).data[0].embedding + self.fact_embedding = embedding[:EMBEDDING_DIM] + + end = time() + logger.info(f'embedded {text} in {(end - start) * 1000} ms') + + return embedding + + async def save(self, driver: AsyncDriver): + result = await driver.execute_query( + """ MATCH (source:Entity {uuid: $source_uuid}) MATCH (target:Entity {uuid: $target_uuid}) MERGE (source)-[r:RELATES_TO {uuid: $uuid}]->(target) @@ -95,19 +95,19 @@ async def save(self, driver: AsyncDriver): episodes: $episodes, created_at: $created_at, expired_at: $expired_at, valid_at: $valid_at, invalid_at: $invalid_at} RETURN r.uuid AS uuid""", - source_uuid=self.source_node_uuid, - target_uuid=self.target_node_uuid, - uuid=self.uuid, - name=self.name, - fact=self.fact, - fact_embedding=self.fact_embedding, - episodes=self.episodes, - created_at=self.created_at, - expired_at=self.expired_at, - valid_at=self.valid_at, - invalid_at=self.invalid_at, - ) - - logger.info(f'Saved edge to neo4j: {self.uuid}') - - return result + source_uuid=self.source_node_uuid, + target_uuid=self.target_node_uuid, + uuid=self.uuid, + name=self.name, + fact=self.fact, + fact_embedding=self.fact_embedding, + episodes=self.episodes, + created_at=self.created_at, + expired_at=self.expired_at, + valid_at=self.valid_at, + invalid_at=self.invalid_at, + ) + + logger.info(f'Saved edge to neo4j: {self.uuid}') + + return result diff --git a/core/graphiti.py b/core/graphiti.py index cc992d8..0c1a26c 100644 --- a/core/graphiti.py +++ b/core/graphiti.py @@ -13,33 +13,35 @@ from core.nodes import EntityNode, EpisodicNode from core.search.search import SearchConfig, hybrid_search from core.search.search_utils import ( - get_relevant_edges, - get_relevant_nodes, + get_relevant_edges, + get_relevant_nodes, ) from core.utils import ( - build_episodic_edges, - retrieve_episodes, + build_episodic_edges, + retrieve_episodes, ) from core.utils.bulk_utils import ( - BulkEpisode, - dedupe_edges_bulk, - dedupe_nodes_bulk, - extract_nodes_and_edges_bulk, - resolve_edge_pointers, - retrieve_previous_episodes_bulk, + BulkEpisode, + dedupe_edges_bulk, + dedupe_nodes_bulk, + extract_nodes_and_edges_bulk, + resolve_edge_pointers, + retrieve_previous_episodes_bulk, ) from core.utils.maintenance.edge_operations import ( - dedupe_extracted_edges, - extract_edges, + dedupe_extracted_edges, + extract_edges, ) from core.utils.maintenance.graph_data_operations import ( - EPISODE_WINDOW_LEN, - build_indices_and_constraints, + EPISODE_WINDOW_LEN, + build_indices_and_constraints, ) from core.utils.maintenance.node_operations import dedupe_extracted_nodes, extract_nodes from core.utils.maintenance.temporal_operations import ( - invalidate_edges, - prepare_edges_for_invalidation, + extract_edge_dates, + extract_node_edge_node_triplet, + invalidate_edges, + prepare_edges_for_invalidation, ) logger = logging.getLogger(__name__) @@ -48,268 +50,274 @@ class Graphiti: - def __init__(self, uri: str, user: str, password: str, llm_client: LLMClient | None = None): - self.driver = AsyncGraphDatabase.driver(uri, auth=(user, password)) - self.database = 'neo4j' - if llm_client: - self.llm_client = llm_client - else: - self.llm_client = OpenAIClient( 
- LLMConfig( - api_key=os.getenv('OPENAI_API_KEY', default=''), - model='gpt-4o-mini', - base_url='https://api.openai.com/v1', - ) - ) - - def close(self): - self.driver.close() - - async def build_indices_and_constraints(self): - await build_indices_and_constraints(self.driver) - - async def retrieve_episodes( - self, - reference_time: datetime, - last_n: int = EPISODE_WINDOW_LEN, - ) -> list[EpisodicNode]: - """Retrieve the last n episodic nodes from the graph""" - return await retrieve_episodes(self.driver, reference_time, last_n) - - async def add_episode( - self, - name: str, - episode_body: str, - source_description: str, - reference_time: datetime, - success_callback: Callable | None = None, - error_callback: Callable | None = None, - ): - """Process an episode and update the graph""" - try: - start = time() - - nodes: list[EntityNode] = [] - entity_edges: list[EntityEdge] = [] - episodic_edges: list[EpisodicEdge] = [] - embedder = self.llm_client.get_embedder() - now = datetime.now() - - previous_episodes = await self.retrieve_episodes(reference_time) - episode = EpisodicNode( - name=name, - labels=[], - source='messages', - content=episode_body, - source_description=source_description, - created_at=now, - valid_at=reference_time, - ) - - extracted_nodes = await extract_nodes(self.llm_client, episode, previous_episodes) - logger.info(f'Extracted nodes: {[(n.name, n.uuid) for n in extracted_nodes]}') - - # Calculate Embeddings - - await asyncio.gather( - *[node.generate_name_embedding(embedder) for node in extracted_nodes] - ) - existing_nodes = await get_relevant_nodes(extracted_nodes, self.driver) - logger.info(f'Extracted nodes: {[(n.name, n.uuid) for n in extracted_nodes]}') - touched_nodes, _, brand_new_nodes = await dedupe_extracted_nodes( - self.llm_client, extracted_nodes, existing_nodes - ) - logger.info(f'Adjusted touched nodes: {[(n.name, n.uuid) for n in touched_nodes]}') - nodes.extend(touched_nodes) - - extracted_edges = await extract_edges( - self.llm_client, episode, touched_nodes, previous_episodes - ) - - await asyncio.gather(*[edge.generate_embedding(embedder) for edge in extracted_edges]) - - existing_edges = await get_relevant_edges(extracted_edges, self.driver) - logger.info(f'Existing edges: {[(e.name, e.uuid) for e in existing_edges]}') - logger.info(f'Extracted edges: {[(e.name, e.uuid) for e in extracted_edges]}') - - # deduped_edges = await dedupe_extracted_edges_v2( - # self.llm_client, - # extract_node_and_edge_triplets(extracted_edges, nodes), - # extract_node_and_edge_triplets(existing_edges, nodes), - # ) - - deduped_edges = await dedupe_extracted_edges( - self.llm_client, - extracted_edges, - existing_edges, - ) - - edge_touched_node_uuids = [n.uuid for n in brand_new_nodes] - for edge in deduped_edges: - edge_touched_node_uuids.append(edge.source_node_uuid) - edge_touched_node_uuids.append(edge.target_node_uuid) - - ( - old_edges_with_nodes_pending_invalidation, - new_edges_with_nodes, - ) = prepare_edges_for_invalidation( - existing_edges=existing_edges, new_edges=deduped_edges, nodes=nodes - ) - - invalidated_edges = await invalidate_edges( - self.llm_client, - old_edges_with_nodes_pending_invalidation, - new_edges_with_nodes, - episode, - previous_episodes, - ) - - for edge in invalidated_edges: - edge_touched_node_uuids.append(edge.source_node_uuid) - edge_touched_node_uuids.append(edge.target_node_uuid) - - edges_to_save = invalidated_edges - - # There may be an overlap between deduped and invalidated edges, so we want to make sure to save the 
invalidated one - for deduped_edge in deduped_edges: - if deduped_edge.uuid not in [edge.uuid for edge in invalidated_edges]: - edges_to_save.append(deduped_edge) - - entity_edges.extend(edges_to_save) - - edge_touched_node_uuids = list(set(edge_touched_node_uuids)) - involved_nodes = [node for node in nodes if node.uuid in edge_touched_node_uuids] - - logger.info(f'Edge touched nodes: {[(n.name, n.uuid) for n in involved_nodes]}') - - logger.info(f'Invalidated edges: {[(e.name, e.uuid) for e in invalidated_edges]}') - - logger.info(f'Deduped edges: {[(e.name, e.uuid) for e in deduped_edges]}') - - episodic_edges.extend( - build_episodic_edges( - # There may be an overlap between new_nodes and affected_nodes, so we're deduplicating them - involved_nodes, - episode, - now, - ) - ) - # Important to append the episode to the nodes at the end so that self referencing episodic edges are not built - logger.info(f'Built episodic edges: {episodic_edges}') - - # Future optimization would be using batch operations to save nodes and edges - await episode.save(self.driver) - await asyncio.gather(*[node.save(self.driver) for node in nodes]) - await asyncio.gather(*[edge.save(self.driver) for edge in episodic_edges]) - await asyncio.gather(*[edge.save(self.driver) for edge in entity_edges]) - - end = time() - logger.info(f'Completed add_episode in {(end-start) * 1000} ms') - # for node in nodes: - # if isinstance(node, EntityNode): - # await node.update_summary(self.driver) - if success_callback: - await success_callback(episode) - except Exception as e: - if error_callback: - await error_callback(episode, e) - else: - raise e - - async def add_episode_bulk( - self, - bulk_episodes: list[BulkEpisode], - ): - try: - start = time() - embedder = self.llm_client.get_embedder() - now = datetime.now() - - episodes = [ - EpisodicNode( - name=episode.name, - labels=[], - source='messages', - content=episode.content, - source_description=episode.source_description, - created_at=now, - valid_at=episode.reference_time, - ) - for episode in bulk_episodes - ] - - # Save all the episodes - await asyncio.gather(*[episode.save(self.driver) for episode in episodes]) - - # Get previous episode context for each episode - episode_pairs = await retrieve_previous_episodes_bulk(self.driver, episodes) - - # Extract all nodes and edges - ( - extracted_nodes, - extracted_edges, - episodic_edges, - ) = await extract_nodes_and_edges_bulk(self.llm_client, episode_pairs) - - # Generate embeddings - await asyncio.gather( - *[node.generate_name_embedding(embedder) for node in extracted_nodes], - *[edge.generate_embedding(embedder) for edge in extracted_edges], - ) - - # Dedupe extracted nodes - nodes, uuid_map = await dedupe_nodes_bulk(self.driver, self.llm_client, extracted_nodes) - - # save nodes to KG - await asyncio.gather(*[node.save(self.driver) for node in nodes]) - - # re-map edge pointers so that they don't point to discard dupe nodes - extracted_edges_with_resolved_pointers: list[EntityEdge] = resolve_edge_pointers( - extracted_edges, uuid_map - ) - episodic_edges_with_resolved_pointers: list[EpisodicEdge] = resolve_edge_pointers( - episodic_edges, uuid_map - ) - - # save episodic edges to KG - await asyncio.gather( - *[edge.save(self.driver) for edge in episodic_edges_with_resolved_pointers] - ) - - # Dedupe extracted edges - edges = await dedupe_edges_bulk( - self.driver, self.llm_client, extracted_edges_with_resolved_pointers - ) - logger.info(f'extracted edge length: {len(edges)}') - - # invalidate edges - - # save edges 
to KG - await asyncio.gather(*[edge.save(self.driver) for edge in edges]) - - end = time() - logger.info(f'Completed add_episode_bulk in {(end-start) * 1000} ms') - - except Exception as e: - raise e - - async def search(self, query: str, num_results=10): - search_config = SearchConfig(num_episodes=0, num_results=num_results) - edges = ( - await hybrid_search( - self.driver, - self.llm_client.get_embedder(), - query, - datetime.now(), - search_config, - ) - ).edges - - facts = [edge.fact for edge in edges] - - return facts - - async def _search(self, query: str, timestamp: datetime, config: SearchConfig): - return await hybrid_search( - self.driver, self.llm_client.get_embedder(), query, timestamp, config - ) + def __init__(self, uri: str, user: str, password: str, llm_client: LLMClient | None = None): + self.driver = AsyncGraphDatabase.driver(uri, auth=(user, password)) + self.database = 'neo4j' + if llm_client: + self.llm_client = llm_client + else: + self.llm_client = OpenAIClient( + LLMConfig( + api_key=os.getenv('OPENAI_API_KEY', default=''), + model='gpt-4o-mini', + base_url='https://api.openai.com/v1', + ) + ) + + def close(self): + self.driver.close() + + async def build_indices_and_constraints(self): + await build_indices_and_constraints(self.driver) + + async def retrieve_episodes( + self, + reference_time: datetime, + last_n: int = EPISODE_WINDOW_LEN, + ) -> list[EpisodicNode]: + """Retrieve the last n episodic nodes from the graph""" + return await retrieve_episodes(self.driver, reference_time, last_n) + + async def add_episode( + self, + name: str, + episode_body: str, + source_description: str, + reference_time: datetime, + success_callback: Callable | None = None, + error_callback: Callable | None = None, + ): + """Process an episode and update the graph""" + try: + start = time() + + nodes: list[EntityNode] = [] + entity_edges: list[EntityEdge] = [] + episodic_edges: list[EpisodicEdge] = [] + embedder = self.llm_client.get_embedder() + now = datetime.now() + + previous_episodes = await self.retrieve_episodes(reference_time) + episode = EpisodicNode( + name=name, + labels=[], + source='messages', + content=episode_body, + source_description=source_description, + created_at=now, + valid_at=reference_time, + ) + + extracted_nodes = await extract_nodes(self.llm_client, episode, previous_episodes) + logger.info(f'Extracted nodes: {[(n.name, n.uuid) for n in extracted_nodes]}') + + # Calculate Embeddings + + await asyncio.gather( + *[node.generate_name_embedding(embedder) for node in extracted_nodes] + ) + existing_nodes = await get_relevant_nodes(extracted_nodes, self.driver) + logger.info(f'Extracted nodes: {[(n.name, n.uuid) for n in extracted_nodes]}') + touched_nodes, _, brand_new_nodes = await dedupe_extracted_nodes( + self.llm_client, extracted_nodes, existing_nodes + ) + logger.info(f'Adjusted touched nodes: {[(n.name, n.uuid) for n in touched_nodes]}') + nodes.extend(touched_nodes) + + extracted_edges = await extract_edges( + self.llm_client, episode, touched_nodes, previous_episodes + ) + + await asyncio.gather(*[edge.generate_embedding(embedder) for edge in extracted_edges]) + + existing_edges = await get_relevant_edges(extracted_edges, self.driver) + logger.info(f'Existing edges: {[(e.name, e.uuid) for e in existing_edges]}') + logger.info(f'Extracted edges: {[(e.name, e.uuid) for e in extracted_edges]}') + + # deduped_edges = await dedupe_extracted_edges_v2( + # self.llm_client, + # extract_node_and_edge_triplets(extracted_edges, nodes), + # 
extract_node_and_edge_triplets(existing_edges, nodes), + # ) + + deduped_edges = await dedupe_extracted_edges( + self.llm_client, + extracted_edges, + existing_edges, + ) + + edge_touched_node_uuids = [n.uuid for n in brand_new_nodes] + for edge in deduped_edges: + edge_touched_node_uuids.append(edge.source_node_uuid) + edge_touched_node_uuids.append(edge.target_node_uuid) + + ( + old_edges_with_nodes_pending_invalidation, + new_edges_with_nodes, + ) = prepare_edges_for_invalidation( + existing_edges=existing_edges, new_edges=deduped_edges, nodes=nodes + ) + + invalidated_edges = await invalidate_edges( + self.llm_client, + old_edges_with_nodes_pending_invalidation, + new_edges_with_nodes, + episode, + previous_episodes, + ) + + for edge in invalidated_edges: + edge_touched_node_uuids.append(edge.source_node_uuid) + edge_touched_node_uuids.append(edge.target_node_uuid) + + edges_to_save = invalidated_edges + + # There may be an overlap between deduped and invalidated edges, so we want to make sure to save the invalidated one + for deduped_edge in deduped_edges: + if deduped_edge.uuid not in [edge.uuid for edge in invalidated_edges]: + edges_to_save.append(deduped_edge) + for deduped_edge in deduped_edges: + triplet = extract_node_edge_node_triplet(deduped_edge, nodes) + valid_at, invalid_at, _ = await extract_edge_dates( + self.llm_client, triplet, episode.valid_at, episode, previous_episodes + ) + deduped_edge.valid_at = valid_at + deduped_edge.invalid_at = invalid_at + entity_edges.extend(edges_to_save) + + edge_touched_node_uuids = list(set(edge_touched_node_uuids)) + involved_nodes = [node for node in nodes if node.uuid in edge_touched_node_uuids] + + logger.info(f'Edge touched nodes: {[(n.name, n.uuid) for n in involved_nodes]}') + + logger.info(f'Invalidated edges: {[(e.name, e.uuid) for e in invalidated_edges]}') + + logger.info(f'Deduped edges: {[(e.name, e.uuid) for e in deduped_edges]}') + + episodic_edges.extend( + build_episodic_edges( + # There may be an overlap between new_nodes and affected_nodes, so we're deduplicating them + involved_nodes, + episode, + now, + ) + ) + # Important to append the episode to the nodes at the end so that self referencing episodic edges are not built + logger.info(f'Built episodic edges: {episodic_edges}') + + # Future optimization would be using batch operations to save nodes and edges + await episode.save(self.driver) + await asyncio.gather(*[node.save(self.driver) for node in nodes]) + await asyncio.gather(*[edge.save(self.driver) for edge in episodic_edges]) + await asyncio.gather(*[edge.save(self.driver) for edge in entity_edges]) + + end = time() + logger.info(f'Completed add_episode in {(end-start) * 1000} ms') + # for node in nodes: + # if isinstance(node, EntityNode): + # await node.update_summary(self.driver) + if success_callback: + await success_callback(episode) + except Exception as e: + if error_callback: + await error_callback(episode, e) + else: + raise e + + async def add_episode_bulk( + self, + bulk_episodes: list[BulkEpisode], + ): + try: + start = time() + embedder = self.llm_client.get_embedder() + now = datetime.now() + + episodes = [ + EpisodicNode( + name=episode.name, + labels=[], + source='messages', + content=episode.content, + source_description=episode.source_description, + created_at=now, + valid_at=episode.reference_time, + ) + for episode in bulk_episodes + ] + + # Save all the episodes + await asyncio.gather(*[episode.save(self.driver) for episode in episodes]) + + # Get previous episode context for each episode 
+ episode_pairs = await retrieve_previous_episodes_bulk(self.driver, episodes) + + # Extract all nodes and edges + ( + extracted_nodes, + extracted_edges, + episodic_edges, + ) = await extract_nodes_and_edges_bulk(self.llm_client, episode_pairs) + + # Generate embeddings + await asyncio.gather( + *[node.generate_name_embedding(embedder) for node in extracted_nodes], + *[edge.generate_embedding(embedder) for edge in extracted_edges], + ) + + # Dedupe extracted nodes + nodes, uuid_map = await dedupe_nodes_bulk(self.driver, self.llm_client, extracted_nodes) + + # save nodes to KG + await asyncio.gather(*[node.save(self.driver) for node in nodes]) + + # re-map edge pointers so that they don't point to discarded duplicate nodes + extracted_edges_with_resolved_pointers: list[EntityEdge] = resolve_edge_pointers( + extracted_edges, uuid_map + ) + episodic_edges_with_resolved_pointers: list[EpisodicEdge] = resolve_edge_pointers( + episodic_edges, uuid_map + ) + + # save episodic edges to KG + await asyncio.gather( + *[edge.save(self.driver) for edge in episodic_edges_with_resolved_pointers] + ) + + # Dedupe extracted edges + edges = await dedupe_edges_bulk( + self.driver, self.llm_client, extracted_edges_with_resolved_pointers + ) + logger.info(f'extracted edge length: {len(edges)}') + + # invalidate edges + + # save edges to KG + await asyncio.gather(*[edge.save(self.driver) for edge in edges]) + + end = time() + logger.info(f'Completed add_episode_bulk in {(end-start) * 1000} ms') + + except Exception as e: + raise e + + async def search(self, query: str, num_results=10): + search_config = SearchConfig(num_episodes=0, num_results=num_results) + edges = ( + await hybrid_search( + self.driver, + self.llm_client.get_embedder(), + query, + datetime.now(), + search_config, + ) + ).edges + + facts = [edge.fact for edge in edges] + + return facts + + async def _search(self, query: str, timestamp: datetime, config: SearchConfig): + return await hybrid_search( + self.driver, self.llm_client.get_embedder(), query, timestamp, config + ) diff --git a/core/llm_client/client.py b/core/llm_client/client.py index 911c9fd..0c5c7f7 100644 --- a/core/llm_client/client.py +++ b/core/llm_client/client.py @@ -6,14 +6,14 @@ class LLMClient(ABC): - @abstractmethod - def __init__(self, config: LLMConfig): - pass + @abstractmethod + def __init__(self, config: LLMConfig): + pass - @abstractmethod - def get_embedder(self) -> typing.Any: - pass + @abstractmethod + def get_embedder(self) -> typing.Any: + pass - @abstractmethod - async def generate_response(self, messages: list[Message]) -> dict[str, typing.Any]: - pass + @abstractmethod + async def generate_response(self, messages: list[Message]) -> dict[str, typing.Any]: + pass diff --git a/core/llm_client/config.py b/core/llm_client/config.py index 3de29df..a591ecf 100644 --- a/core/llm_client/config.py +++ b/core/llm_client/config.py @@ -2,35 +2,35 @@ class LLMConfig: - """ - Configuration class for the Language Learning Model (LLM). + """ + Configuration class for the Large Language Model (LLM). - This class encapsulates the necessary parameters to interact with an LLM API, - such as OpenAI's GPT models. It stores the API key, model name, and base URL - for making requests to the LLM service. - """ + This class encapsulates the necessary parameters to interact with an LLM API, + such as OpenAI's GPT models. It stores the API key, model name, and base URL + for making requests to the LLM service. 
+ """ - def __init__( - self, - api_key: str, - model: str = 'gpt-4o-mini', - base_url: str = 'https://api.openai.com', - ): - """ - Initialize the LLMConfig with the provided parameters. + def __init__( + self, + api_key: str, + model: str = 'gpt-4o-mini', + base_url: str = 'https://api.openai.com', + ): + """ + Initialize the LLMConfig with the provided parameters. - Args: - api_key (str): The authentication key for accessing the LLM API. - This is required for making authorized requests. + Args: + api_key (str): The authentication key for accessing the LLM API. + This is required for making authorized requests. - model (str, optional): The specific LLM model to use for generating responses. - Defaults to "gpt-4o-mini", which appears to be a custom model name. - Common values might include "gpt-3.5-turbo" or "gpt-4". + model (str, optional): The specific LLM model to use for generating responses. + Defaults to "gpt-4o-mini", which appears to be a custom model name. + Common values might include "gpt-3.5-turbo" or "gpt-4". - base_url (str, optional): The base URL of the LLM API service. - Defaults to "https://api.openai.com", which is OpenAI's standard API endpoint. - This can be changed if using a different provider or a custom endpoint. - """ - self.base_url = base_url - self.api_key = api_key - self.model = model + base_url (str, optional): The base URL of the LLM API service. + Defaults to "https://api.openai.com", which is OpenAI's standard API endpoint. + This can be changed if using a different provider or a custom endpoint. + """ + self.base_url = base_url + self.api_key = api_key + self.model = model diff --git a/core/llm_client/openai_client.py b/core/llm_client/openai_client.py index feb096f..e5f4ed8 100644 --- a/core/llm_client/openai_client.py +++ b/core/llm_client/openai_client.py @@ -13,30 +13,30 @@ class OpenAIClient(LLMClient): - def __init__(self, config: LLMConfig): - self.client = AsyncOpenAI(api_key=config.api_key, base_url=config.base_url) - self.model = config.model + def __init__(self, config: LLMConfig): + self.client = AsyncOpenAI(api_key=config.api_key, base_url=config.base_url) + self.model = config.model - def get_embedder(self) -> typing.Any: - return self.client.embeddings + def get_embedder(self) -> typing.Any: + return self.client.embeddings - async def generate_response(self, messages: list[Message]) -> dict[str, typing.Any]: - openai_messages: list[ChatCompletionMessageParam] = [] - for m in messages: - if m.role == 'user': - openai_messages.append({'role': 'user', 'content': m.content}) - elif m.role == 'system': - openai_messages.append({'role': 'system', 'content': m.content}) - try: - response = await self.client.chat.completions.create( - model=self.model, - messages=openai_messages, - temperature=0.1, - max_tokens=3000, - response_format={'type': 'json_object'}, - ) - result = response.choices[0].message.content or '' - return json.loads(result) - except Exception as e: - logger.error(f'Error in generating LLM response: {e}') - raise + async def generate_response(self, messages: list[Message]) -> dict[str, typing.Any]: + openai_messages: list[ChatCompletionMessageParam] = [] + for m in messages: + if m.role == 'user': + openai_messages.append({'role': 'user', 'content': m.content}) + elif m.role == 'system': + openai_messages.append({'role': 'system', 'content': m.content}) + try: + response = await self.client.chat.completions.create( + model=self.model, + messages=openai_messages, + temperature=0.1, + max_tokens=3000, + response_format={'type': 
'json_object'}, + ) + result = response.choices[0].message.content or '' + return json.loads(result) + except Exception as e: + logger.error(f'Error in generating LLM response: {e}') + raise diff --git a/core/nodes.py b/core/nodes.py index f60aedd..3015a42 100644 --- a/core/nodes.py +++ b/core/nodes.py @@ -14,89 +14,89 @@ class Node(BaseModel, ABC): - uuid: str = Field(default_factory=lambda: uuid4().hex) - name: str - labels: list[str] = Field(default_factory=list) - created_at: datetime + uuid: str = Field(default_factory=lambda: uuid4().hex) + name: str + labels: list[str] = Field(default_factory=list) + created_at: datetime - @abstractmethod - async def save(self, driver: AsyncDriver): ... + @abstractmethod + async def save(self, driver: AsyncDriver): ... - def __hash__(self): - return hash(self.uuid) + def __hash__(self): + return hash(self.uuid) - def __eq__(self, other): - if isinstance(other, Node): - return self.uuid == other.uuid - return False + def __eq__(self, other): + if isinstance(other, Node): + return self.uuid == other.uuid + return False class EpisodicNode(Node): - source: str = Field(description='source type') - source_description: str = Field(description='description of the data source') - content: str = Field(description='raw episode data') - valid_at: datetime = Field( - description='datetime of when the original document was created', - ) - entity_edges: list[str] = Field( - description='list of entity edges referenced in this episode', - default_factory=list, - ) - - async def save(self, driver: AsyncDriver): - result = await driver.execute_query( - """ + source: str = Field(description='source type') + source_description: str = Field(description='description of the data source') + content: str = Field(description='raw episode data') + valid_at: datetime = Field( + description='datetime of when the original document was created', + ) + entity_edges: list[str] = Field( + description='list of entity edges referenced in this episode', + default_factory=list, + ) + + async def save(self, driver: AsyncDriver): + result = await driver.execute_query( + """ MERGE (n:Episodic {uuid: $uuid}) SET n = {uuid: $uuid, name: $name, source_description: $source_description, source: $source, content: $content, entity_edges: $entity_edges, created_at: $created_at, valid_at: $valid_at} RETURN n.uuid AS uuid""", - uuid=self.uuid, - name=self.name, - source_description=self.source_description, - content=self.content, - entity_edges=self.entity_edges, - created_at=self.created_at, - valid_at=self.valid_at, - source=self.source, - _database='neo4j', - ) + uuid=self.uuid, + name=self.name, + source_description=self.source_description, + content=self.content, + entity_edges=self.entity_edges, + created_at=self.created_at, + valid_at=self.valid_at, + source=self.source, + _database='neo4j', + ) - logger.info(f'Saved Node to neo4j: {self.uuid}') + logger.info(f'Saved Node to neo4j: {self.uuid}') - return result + return result class EntityNode(Node): - name_embedding: list[float] | None = Field(default=None, description='embedding of the name') - summary: str = Field(description='regional summary of surrounding edges', default_factory=str) + name_embedding: list[float] | None = Field(default=None, description='embedding of the name') + summary: str = Field(description='regional summary of surrounding edges', default_factory=str) - async def update_summary(self, driver: AsyncDriver): ... + async def update_summary(self, driver: AsyncDriver): ... 
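# ---------------------------------------------------------------------------
# Editor's note — an illustrative sketch, not part of the diff: how the
# EntityNode model above is exercised by the pipeline (embed the name, then
# persist). Assumes the async `neo4j` driver and an AsyncOpenAI-style
# embeddings client; the URI, credentials, and API key are placeholders.
import asyncio
from datetime import datetime, timezone

from neo4j import AsyncGraphDatabase
from openai import AsyncOpenAI

from core.nodes import EntityNode


async def demo_entity_node() -> None:
    driver = AsyncGraphDatabase.driver('bolt://localhost:7687', auth=('neo4j', 'password'))
    embedder = AsyncOpenAI(api_key='sk-...').embeddings  # same object get_embedder() returns
    node = EntityNode(
        name='Alice',
        labels=['Entity'],
        created_at=datetime.now(timezone.utc),
        summary='Speaker in the first episode',
    )
    await node.generate_name_embedding(embedder)  # truncates to EMBEDDING_DIM
    await node.save(driver)  # MERGE on uuid, so repeated saves are idempotent
    await driver.close()


asyncio.run(demo_entity_node())
# ---------------------------------------------------------------------------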
- async def refresh_summary(self, driver: AsyncDriver, llm_client: OpenAI): ... + async def refresh_summary(self, driver: AsyncDriver, llm_client: OpenAI): ... - async def generate_name_embedding(self, embedder, model='text-embedding-3-small'): - start = time() - text = self.name.replace('\n', ' ') - embedding = (await embedder.create(input=[text], model=model)).data[0].embedding - self.name_embedding = embedding[:EMBEDDING_DIM] - end = time() - logger.info(f'embedded {text} in {end-start} ms') + async def generate_name_embedding(self, embedder, model='text-embedding-3-small'): + start = time() + text = self.name.replace('\n', ' ') + embedding = (await embedder.create(input=[text], model=model)).data[0].embedding + self.name_embedding = embedding[:EMBEDDING_DIM] + end = time() + logger.info(f'embedded {text} in {(end - start) * 1000} ms') - return embedding + return embedding - async def save(self, driver: AsyncDriver): - result = await driver.execute_query( - """ + async def save(self, driver: AsyncDriver): + result = await driver.execute_query( + """ MERGE (n:Entity {uuid: $uuid}) SET n = {uuid: $uuid, name: $name, name_embedding: $name_embedding, summary: $summary, created_at: $created_at} RETURN n.uuid AS uuid""", - uuid=self.uuid, - name=self.name, - summary=self.summary, - name_embedding=self.name_embedding, - created_at=self.created_at, - ) + uuid=self.uuid, + name=self.name, + summary=self.summary, + name_embedding=self.name_embedding, + created_at=self.created_at, + ) - logger.info(f'Saved Node to neo4j: {self.uuid}') + logger.info(f'Saved Node to neo4j: {self.uuid}') - return result + return result diff --git a/core/prompts/dedupe_nodes.py b/core/prompts/dedupe_nodes.py index 6c3f459..c3e2088 100644 --- a/core/prompts/dedupe_nodes.py +++ b/core/prompts/dedupe_nodes.py @@ -5,26 +5,26 @@ class Prompt(Protocol): - v1: PromptVersion - v2: PromptVersion - node_list: PromptVersion + v1: PromptVersion + v2: PromptVersion + node_list: PromptVersion class Versions(TypedDict): - v1: PromptFunction - v2: PromptFunction - node_list: PromptVersion + v1: PromptFunction + v2: PromptFunction + node_list: PromptFunction def v1(context: dict[str, Any]) -> list[Message]: - return [ - Message( - role='system', - content='You are a helpful assistant that de-duplicates nodes from node lists.', - ), - Message( - role='user', - content=f""" + return [ + Message( + role='system', + content='You are a helpful assistant that de-duplicates nodes from node lists.', + ), + Message( + role='user', + content=f""" Given the following context, deduplicate nodes from a list of new nodes given a list of existing nodes: Existing Nodes: @@ -52,19 +52,19 @@ def v1(context: dict[str, Any]) -> list[Message]: ] }} """, - ), - ] + ), + ] def v2(context: dict[str, Any]) -> list[Message]: - return [ - Message( - role='system', - content='You are a helpful assistant that de-duplicates nodes from node lists.', - ), - Message( - role='user', - content=f""" + return [ + Message( + role='system', + content='You are a helpful assistant that de-duplicates nodes from node lists.', + ), + Message( + role='user', + content=f""" Given the following context, deduplicate nodes from a list of new nodes given a list of existing nodes: Existing Nodes: @@ -92,19 +92,19 @@ def v2(context: dict[str, Any]) -> list[Message]: ] }} """, - ), - ] + ), + ] def node_list(context: dict[str, Any]) -> list[Message]: - return [ - Message( - role='system', - content='You are a helpful assistant that de-duplicates nodes from node lists.', - ), - Message( - 
role='user', - content=f""" + return [ + Message( + role='system', + content='You are a helpful assistant that de-duplicates nodes from node lists.', + ), + Message( + role='user', + content=f""" Given the following context, deduplicate a list of nodes: Nodes: @@ -127,8 +127,8 @@ def node_list(context: dict[str, Any]) -> list[Message]: ] }} """, - ), - ] + ), + ] versions: Versions = {'v1': v1, 'v2': v2, 'node_list': node_list} diff --git a/core/prompts/extract_edge_dates.py b/core/prompts/extract_edge_dates.py new file mode 100644 index 0000000..cae639a --- /dev/null +++ b/core/prompts/extract_edge_dates.py @@ -0,0 +1,62 @@ +from typing import Any, Protocol, TypedDict + +from .models import Message, PromptFunction, PromptVersion + + +class Prompt(Protocol): + v1: PromptVersion + + +class Versions(TypedDict): + v1: PromptFunction + + +def v1(context: dict[str, Any]) -> list[Message]: + return [ + Message( + role='system', + content='You are an AI assistant that extracts datetime information for graph edges, focusing only on dates directly related to the establishment or change of the relationship described in the edge fact.', + ), + Message( + role='user', + content=f""" + Edge: + Source Node: {context['source_node']} + Edge Name: {context['edge_name']} + Target Node: {context['target_node']} + Fact: {context['edge_fact']} + + Current Episode: {context['current_episode']} + Previous Episodes: {context['previous_episodes']} + Reference Timestamp: {context['reference_timestamp']} + + IMPORTANT: Only extract time information if it is part of the provided fact. Otherwise ignore the time mentioned. Make sure to do your best to determine the dates if only relative time is mentioned (e.g. 10 years ago, 2 mins ago), based on the provided reference timestamp. + + Definitions: + - valid_at: The date and time when the relationship described by the edge fact became true or was established. + - invalid_at: The date and time when the relationship described by the edge fact stopped being true or ended. + + Task: + Analyze the conversation and determine if there are dates that are part of the edge fact. Only set dates if they explicitly relate to the formation or alteration of the relationship itself. + + Guidelines: + 1. Use ISO 8601 format (YYYY-MM-DDTHH:MM:SSZ) for datetimes. + 2. Use the reference timestamp as the current time when determining the valid_at and invalid_at dates. + 3. If no temporal information is found that establishes or changes the relationship, leave the fields as null. + 4. Do not infer dates from related events. Only use dates that are directly stated to establish or change the relationship. + 5. For relative time mentions directly related to the relationship, calculate the actual datetime based on the reference timestamp. + 6. If only a date is mentioned without a specific time, use 00:00:00 (midnight) for that date. + 7. If only a year is mentioned, use January 1st of that year at 00:00:00. + 8. Always include the time zone offset (use Z for UTC if no specific time zone is mentioned). 
+ Respond with a JSON object: + {{ + "valid_at": "YYYY-MM-DDTHH:MM:SSZ or null", + "invalid_at": "YYYY-MM-DDTHH:MM:SSZ or null", + "explanation": "Brief explanation of why these dates were chosen or why they were set to null" + }} + """, + ), + ] + + +versions: Versions = {'v1': v1} diff --git a/core/prompts/extract_edges.py b/core/prompts/extract_edges.py index c339f63..a57d202 100644 --- a/core/prompts/extract_edges.py +++ b/core/prompts/extract_edges.py @@ -5,24 +5,24 @@ class Prompt(Protocol): - v1: PromptVersion - v2: PromptVersion + v1: PromptVersion + v2: PromptVersion class Versions(TypedDict): - v1: PromptFunction - v2: PromptFunction + v1: PromptFunction + v2: PromptFunction def v1(context: dict[str, Any]) -> list[Message]: - return [ - Message( - role='system', - content='You are a helpful assistant that extracts graph edges from provided context.', - ), - Message( - role='user', - content=f""" + return [ + Message( + role='system', + content='You are a helpful assistant that extracts graph edges from provided context.', + ), + Message( + role='user', + content=f""" Given the following context, extract new semantic edges (relationships) that need to be added to the knowledge graph: Current Graph Structure: @@ -66,19 +66,19 @@ def v1(context: dict[str, Any]) -> list[Message]: If no new edges need to be added, return an empty list for "new_edges". """, - ), - ] + ), + ] def v2(context: dict[str, Any]) -> list[Message]: - return [ - Message( - role='system', - content='You are a helpful assistant that extracts graph edges from provided context.', - ), - Message( - role='user', - content=f""" + return [ + Message( + role='system', + content='You are a helpful assistant that extracts graph edges from provided context.', + ), + Message( + role='user', + content=f""" Given the following context, extract edges (relationships) that need to be added to the knowledge graph: Nodes: {json.dumps(context['nodes'], indent=2)} @@ -115,8 +115,8 @@ def v2(context: dict[str, Any]) -> list[Message]: If no edges need to be added, return an empty list for "edges". """, - ), - ] + ), + ] versions: Versions = {'v1': v1, 'v2': v2} diff --git a/core/prompts/extract_nodes.py b/core/prompts/extract_nodes.py index 7278568..8b75aff 100644 --- a/core/prompts/extract_nodes.py +++ b/core/prompts/extract_nodes.py @@ -5,26 +5,26 @@ class Prompt(Protocol): - v1: PromptVersion - v2: PromptVersion - v3: PromptVersion + v1: PromptVersion + v2: PromptVersion + v3: PromptVersion class Versions(TypedDict): - v1: PromptFunction - v2: PromptFunction - v3: PromptFunction + v1: PromptFunction + v2: PromptFunction + v3: PromptFunction def v1(context: dict[str, Any]) -> list[Message]: - return [ - Message( - role='system', - content='You are a helpful assistant that extracts graph nodes from provided context.', - ), - Message( - role='user', - content=f""" + return [ + Message( + role='system', + content='You are a helpful assistant that extracts graph nodes from provided context.', + ), + Message( + role='user', + content=f""" Given the following context, extract new semantic nodes that need to be added to the knowledge graph: Existing Nodes: @@ -60,19 +60,19 @@ def v1(context: dict[str, Any]) -> list[Message]: If no new nodes need to be added, return an empty list for "new_nodes". 
""", - ), - ] + ), + ] def v2(context: dict[str, Any]) -> list[Message]: - return [ - Message( - role='system', - content='You are a helpful assistant that extracts graph nodes from provided context.', - ), - Message( - role='user', - content=f""" + return [ + Message( + role='system', + content='You are a helpful assistant that extracts graph nodes from provided context.', + ), + Message( + role='user', + content=f""" Given the following context, extract new entity nodes that need to be added to the knowledge graph: Previous Episodes: @@ -101,14 +101,14 @@ def v2(context: dict[str, Any]) -> list[Message]: If no new nodes need to be added, return an empty list for "new_nodes". """, - ), - ] + ), + ] def v3(context: dict[str, Any]) -> list[Message]: - sys_prompt = """You are an AI assistant that extracts entity nodes from conversational text. Your primary task is to identify and extract the speaker and other significant entities mentioned in the conversation.""" + sys_prompt = """You are an AI assistant that extracts entity nodes from conversational text. Your primary task is to identify and extract the speaker and other significant entities mentioned in the conversation.""" - user_prompt = f""" + user_prompt = f""" Given the following conversation, extract entity nodes that are explicitly or implicitly mentioned: Conversation: @@ -120,6 +120,7 @@ def v3(context: dict[str, Any]) -> list[Message]: 2. Extract other significant entities, concepts, or actors mentioned in the conversation. 3. Provide concise but informative summaries for each extracted node. 4. Avoid creating nodes for relationships or actions. +5. Avoid creating nodes for temporal information like dates, times or years (these will be added to edges later). Respond with a JSON object in the following format: {{ @@ -132,10 +133,10 @@ def v3(context: dict[str, Any]) -> list[Message]: ] }} """ - return [ - Message(role='system', content=sys_prompt), - Message(role='user', content=user_prompt), - ] + return [ + Message(role='system', content=sys_prompt), + Message(role='user', content=user_prompt), + ] versions: Versions = {'v1': v1, 'v2': v2, 'v3': v3} diff --git a/core/prompts/invalidate_edges.py b/core/prompts/invalidate_edges.py index 6b5667e..23dd967 100644 --- a/core/prompts/invalidate_edges.py +++ b/core/prompts/invalidate_edges.py @@ -4,22 +4,22 @@ class Prompt(Protocol): - v1: PromptVersion + v1: PromptVersion class Versions(TypedDict): - v1: PromptFunction + v1: PromptFunction def v1(context: dict[str, Any]) -> list[Message]: - return [ - Message( - role='system', - content='You are an AI assistant that helps determine which relationships in a knowledge graph should be invalidated based solely on explicit contradictions in newer information.', - ), - Message( - role='user', - content=f""" + return [ + Message( + role='system', + content='You are an AI assistant that helps determine which relationships in a knowledge graph should be invalidated based solely on explicit contradictions in newer information.', + ), + Message( + role='user', + content=f""" Based on the provided existing edges and new edges with their timestamps, determine which existing relationships, if any, should be invalidated due to contradictions or updates in the new edges. Only mark a relationship as invalid if there is clear evidence from new edges that the relationship is no longer true. Do not invalidate relationships merely because they weren't mentioned in new edges. 
You may use the current episode and previous episodes as well as the facts of each edge to understand the context of the relationships. @@ -50,8 +50,8 @@ def v1(context: dict[str, Any]) -> list[Message]: If no relationships need to be invalidated based on these strict criteria, return an empty list for "invalidated_edges". """, - ), - ] + ), + ] versions: Versions = {'v1': v1} diff --git a/core/prompts/lib.py b/core/prompts/lib.py index d42914c..69e73bc 100644 --- a/core/prompts/lib.py +++ b/core/prompts/lib.py @@ -1,94 +1,106 @@ from typing import Any, Protocol, TypedDict from .dedupe_edges import ( - Prompt as DedupeEdgesPrompt, + Prompt as DedupeEdgesPrompt, ) from .dedupe_edges import ( - Versions as DedupeEdgesVersions, + Versions as DedupeEdgesVersions, ) from .dedupe_edges import ( - versions as dedupe_edges_versions, + versions as dedupe_edges_versions, ) from .dedupe_nodes import ( - Prompt as DedupeNodesPrompt, + Prompt as DedupeNodesPrompt, ) from .dedupe_nodes import ( - Versions as DedupeNodesVersions, + Versions as DedupeNodesVersions, ) from .dedupe_nodes import ( - versions as dedupe_nodes_versions, + versions as dedupe_nodes_versions, +) +from .extract_edge_dates import ( + Prompt as ExtractEdgeDatesPrompt, +) +from .extract_edge_dates import ( + Versions as ExtractEdgeDatesVersions, +) +from .extract_edge_dates import ( + versions as extract_edge_dates_versions, ) from .extract_edges import ( - Prompt as ExtractEdgesPrompt, + Prompt as ExtractEdgesPrompt, ) from .extract_edges import ( - Versions as ExtractEdgesVersions, + Versions as ExtractEdgesVersions, ) from .extract_edges import ( - versions as extract_edges_versions, + versions as extract_edges_versions, ) from .extract_nodes import ( - Prompt as ExtractNodesPrompt, + Prompt as ExtractNodesPrompt, ) from .extract_nodes import ( - Versions as ExtractNodesVersions, + Versions as ExtractNodesVersions, ) from .extract_nodes import ( - versions as extract_nodes_versions, + versions as extract_nodes_versions, ) from .invalidate_edges import ( - Prompt as InvalidateEdgesPrompt, + Prompt as InvalidateEdgesPrompt, ) from .invalidate_edges import ( - Versions as InvalidateEdgesVersions, + Versions as InvalidateEdgesVersions, ) from .invalidate_edges import ( - versions as invalidate_edges_versions, + versions as invalidate_edges_versions, ) from .models import Message, PromptFunction class PromptLibrary(Protocol): - extract_nodes: ExtractNodesPrompt - dedupe_nodes: DedupeNodesPrompt - extract_edges: ExtractEdgesPrompt - dedupe_edges: DedupeEdgesPrompt - invalidate_edges: InvalidateEdgesPrompt + extract_nodes: ExtractNodesPrompt + dedupe_nodes: DedupeNodesPrompt + extract_edges: ExtractEdgesPrompt + dedupe_edges: DedupeEdgesPrompt + invalidate_edges: InvalidateEdgesPrompt + extract_edge_dates: ExtractEdgeDatesPrompt class PromptLibraryImpl(TypedDict): - extract_nodes: ExtractNodesVersions - dedupe_nodes: DedupeNodesVersions - extract_edges: ExtractEdgesVersions - dedupe_edges: DedupeEdgesVersions - invalidate_edges: InvalidateEdgesVersions + extract_nodes: ExtractNodesVersions + dedupe_nodes: DedupeNodesVersions + extract_edges: ExtractEdgesVersions + dedupe_edges: DedupeEdgesVersions + invalidate_edges: InvalidateEdgesVersions + extract_edge_dates: ExtractEdgeDatesVersions class VersionWrapper: - def __init__(self, func: PromptFunction): - self.func = func + def __init__(self, func: PromptFunction): + self.func = func - def __call__(self, context: dict[str, Any]) -> list[Message]: - return self.func(context) + def 
__call__(self, context: dict[str, Any]) -> list[Message]: + return self.func(context) class PromptTypeWrapper: - def __init__(self, versions: dict[str, PromptFunction]): - for version, func in versions.items(): - setattr(self, version, VersionWrapper(func)) + def __init__(self, versions: dict[str, PromptFunction]): + for version, func in versions.items(): + setattr(self, version, VersionWrapper(func)) class PromptLibraryWrapper: - def __init__(self, library: PromptLibraryImpl): - for prompt_type, versions in library.items(): - setattr(self, prompt_type, PromptTypeWrapper(versions)) # type: ignore[arg-type] + def __init__(self, library: PromptLibraryImpl): + for prompt_type, versions in library.items(): + setattr(self, prompt_type, PromptTypeWrapper(versions)) # type: ignore[arg-type] PROMPT_LIBRARY_IMPL: PromptLibraryImpl = { - 'extract_nodes': extract_nodes_versions, - 'dedupe_nodes': dedupe_nodes_versions, - 'extract_edges': extract_edges_versions, - 'dedupe_edges': dedupe_edges_versions, - 'invalidate_edges': invalidate_edges_versions, + 'extract_nodes': extract_nodes_versions, + 'dedupe_nodes': dedupe_nodes_versions, + 'extract_edges': extract_edges_versions, + 'dedupe_edges': dedupe_edges_versions, + 'invalidate_edges': invalidate_edges_versions, + 'extract_edge_dates': extract_edge_dates_versions, } prompt_library: PromptLibrary = PromptLibraryWrapper(PROMPT_LIBRARY_IMPL) # type: ignore[assignment] diff --git a/core/prompts/models.py b/core/prompts/models.py index 708a3ea..b1dc041 100644 --- a/core/prompts/models.py +++ b/core/prompts/models.py @@ -4,12 +4,12 @@ class Message(BaseModel): - role: str - content: str + role: str + content: str class PromptVersion(Protocol): - def __call__(self, context: dict[str, Any]) -> list[Message]: ... + def __call__(self, context: dict[str, Any]) -> list[Message]: ... 
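# ---------------------------------------------------------------------------
# Editor's note — an illustrative sketch, not part of the diff: how the
# wrapped prompt library defined in lib.py above is consumed. Attribute access
# walks PromptLibraryWrapper -> PromptTypeWrapper -> VersionWrapper down to
# the underlying PromptFunction. The context keys mirror the ones read by
# extract_edge_dates.v1; the values here are made-up sample data.
from core.prompts.lib import prompt_library

context = {
    'source_node': 'Alice',
    'edge_name': 'WORKS_AT',
    'target_node': 'Acme Corp',
    'edge_fact': 'Alice started working at Acme Corp two years ago',
    'current_episode': 'Alice: I have been at Acme Corp for two years now.',
    'previous_episodes': [],
    'reference_timestamp': '2024-08-20T00:00:00Z',
}

# Returns list[Message], ready to hand to LLMClient.generate_response(...)
messages = prompt_library.extract_edge_dates.v1(context)
# ---------------------------------------------------------------------------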
PromptFunction = Callable[[dict[str, Any]], list[Message]] diff --git a/core/search/search.py b/core/search/search.py index a773d8f..c897fe3 100644 --- a/core/search/search.py +++ b/core/search/search.py @@ -9,10 +9,10 @@ from core.llm_client.config import EMBEDDING_DIM from core.nodes import EntityNode, EpisodicNode from core.search.search_utils import ( - edge_fulltext_search, - edge_similarity_search, - get_mentioned_nodes, - rrf, + edge_fulltext_search, + edge_similarity_search, + get_mentioned_nodes, + rrf, ) from core.utils import retrieve_episodes from core.utils.maintenance.graph_data_operations import EPISODE_WINDOW_LEN @@ -21,81 +21,81 @@ class SearchConfig(BaseModel): - num_results: int = 10 - num_episodes: int = EPISODE_WINDOW_LEN - similarity_search: str = 'cosine' - text_search: str = 'BM25' - reranker: str = 'rrf' + num_results: int = 10 + num_episodes: int = EPISODE_WINDOW_LEN + similarity_search: str = 'cosine' + text_search: str = 'BM25' + reranker: str = 'rrf' class SearchResults(BaseModel): - episodes: list[EpisodicNode] - nodes: list[EntityNode] - edges: list[EntityEdge] + episodes: list[EpisodicNode] + nodes: list[EntityNode] + edges: list[EntityEdge] async def hybrid_search( - driver: AsyncDriver, embedder, query: str, timestamp: datetime, config: SearchConfig + driver: AsyncDriver, embedder, query: str, timestamp: datetime, config: SearchConfig ) -> SearchResults: - start = time() + start = time() - episodes = [] - nodes = [] - edges = [] + episodes = [] + nodes = [] + edges = [] - search_results = [] + search_results = [] - if config.num_episodes > 0: - episodes.extend(await retrieve_episodes(driver, timestamp)) - nodes.extend(await get_mentioned_nodes(driver, episodes)) + if config.num_episodes > 0: + episodes.extend(await retrieve_episodes(driver, timestamp)) + nodes.extend(await get_mentioned_nodes(driver, episodes)) - if config.text_search == 'BM25': - text_search = await edge_fulltext_search(query, driver) - search_results.append(text_search) + if config.text_search == 'BM25': + text_search = await edge_fulltext_search(query, driver) + search_results.append(text_search) - if config.similarity_search == 'cosine': - query_text = query.replace('\n', ' ') - search_vector = ( - (await embedder.create(input=[query_text], model='text-embedding-3-small')) - .data[0] - .embedding[:EMBEDDING_DIM] - ) + if config.similarity_search == 'cosine': + query_text = query.replace('\n', ' ') + search_vector = ( + (await embedder.create(input=[query_text], model='text-embedding-3-small')) + .data[0] + .embedding[:EMBEDDING_DIM] + ) - similarity_search = await edge_similarity_search(search_vector, driver) - search_results.append(similarity_search) + similarity_search = await edge_similarity_search(search_vector, driver) + search_results.append(similarity_search) - if len(search_results) == 1: - edges = search_results[0] + if len(search_results) == 1: + edges = search_results[0] - elif len(search_results) > 1 and config.reranker != 'rrf': - logger.exception('Multiple searches enabled without a reranker') - raise Exception('Multiple searches enabled without a reranker') + elif len(search_results) > 1 and config.reranker != 'rrf': + logger.exception('Multiple searches enabled without a reranker') + raise Exception('Multiple searches enabled without a reranker') - elif config.reranker == 'rrf': - edge_uuid_map = {} - search_result_uuids = [] + elif config.reranker == 'rrf': + edge_uuid_map = {} + search_result_uuids = [] - logger.info([[edge.fact for edge in result] for result in 
search_results]) + logger.info([[edge.fact for edge in result] for result in search_results]) - for result in search_results: - result_uuids = [] - for edge in result: - result_uuids.append(edge.uuid) - edge_uuid_map[edge.uuid] = edge + for result in search_results: + result_uuids = [] + for edge in result: + result_uuids.append(edge.uuid) + edge_uuid_map[edge.uuid] = edge - search_result_uuids.append(result_uuids) + search_result_uuids.append(result_uuids) - search_result_uuids = [[edge.uuid for edge in result] for result in search_results] + search_result_uuids = [[edge.uuid for edge in result] for result in search_results] - reranked_uuids = rrf(search_result_uuids) + reranked_uuids = rrf(search_result_uuids) - reranked_edges = [edge_uuid_map[uuid] for uuid in reranked_uuids] - edges.extend(reranked_edges) + reranked_edges = [edge_uuid_map[uuid] for uuid in reranked_uuids] + edges.extend(reranked_edges) - context = SearchResults(episodes=episodes, nodes=nodes, edges=edges) + context = SearchResults(episodes=episodes, nodes=nodes, edges=edges) - end = time() + end = time() - logger.info(f'search returned context for query {query} in {(end - start) * 1000} ms') + logger.info(f'search returned context for query {query} in {(end - start) * 1000} ms') - return context + return context diff --git a/core/search/search_utils.py b/core/search/search_utils.py index 6e4b443..97046aa 100644 --- a/core/search/search_utils.py +++ b/core/search/search_utils.py @@ -17,13 +17,13 @@ def parse_db_date(neo_date: neo4j_time.DateTime | None) -> datetime | None: - return neo_date.to_native() if neo_date else None + return neo_date.to_native() if neo_date else None async def get_mentioned_nodes(driver: AsyncDriver, episodes: list[EpisodicNode]): - episode_uuids = [episode.uuid for episode in episodes] - records, _, _ = await driver.execute_query( - """ + episode_uuids = [episode.uuid for episode in episodes] + records, _, _ = await driver.execute_query( + """ MATCH (episode:Episodic)-[:MENTIONS]->(n:Entity) WHERE episode.uuid IN $uuids RETURN DISTINCT n.uuid As uuid, @@ -31,28 +31,28 @@ async def get_mentioned_nodes(driver: AsyncDriver, episodes: list[EpisodicNode]) n.created_at AS created_at, n.summary AS summary """, - uuids=episode_uuids, - ) + uuids=episode_uuids, + ) - nodes: list[EntityNode] = [] + nodes: list[EntityNode] = [] - for record in records: - nodes.append( - EntityNode( - uuid=record['uuid'], - name=record['name'], - labels=['Entity'], - created_at=record['created_at'].to_native(), - summary=record['summary'], - ) - ) + for record in records: + nodes.append( + EntityNode( + uuid=record['uuid'], + name=record['name'], + labels=['Entity'], + created_at=record['created_at'].to_native(), + summary=record['summary'], + ) + ) - return nodes + return nodes async def bfs(node_ids: list[str], driver: AsyncDriver): - records, _, _ = await driver.execute_query( - """ + records, _, _ = await driver.execute_query( + """ MATCH (n WHERE n.uuid in $node_ids)-[r]->(m) RETURN DISTINCT n.uuid AS source_node_uuid, @@ -72,39 +72,39 @@ async def bfs(node_ids: list[str], driver: AsyncDriver): r.invalid_at AS invalid_at """, - node_ids=node_ids, - ) - - context: dict[str, typing.Any] = {} - - for record in records: - n_uuid = record['source_node_uuid'] - if n_uuid in context: - context[n_uuid]['facts'].append(record['fact']) - else: - context[n_uuid] = { - 'name': record['source_name'], - 'summary': record['source_summary'], - 'facts': [record['fact']], - } - - m_uuid = record['target_node_uuid'] - if m_uuid not in 
context: - context[m_uuid] = { - 'name': record['target_name'], - 'summary': record['target_summary'], - 'facts': [], - } - logger.info(f'bfs search returned context: {context}') - return context + node_ids=node_ids, + ) + + context: dict[str, typing.Any] = {} + + for record in records: + n_uuid = record['source_node_uuid'] + if n_uuid in context: + context[n_uuid]['facts'].append(record['fact']) + else: + context[n_uuid] = { + 'name': record['source_name'], + 'summary': record['source_summary'], + 'facts': [record['fact']], + } + + m_uuid = record['target_node_uuid'] + if m_uuid not in context: + context[m_uuid] = { + 'name': record['target_name'], + 'summary': record['target_summary'], + 'facts': [], + } + logger.info(f'bfs search returned context: {context}') + return context async def edge_similarity_search( - search_vector: list[float], driver: AsyncDriver, limit=RELEVANT_SCHEMA_LIMIT + search_vector: list[float], driver: AsyncDriver, limit=RELEVANT_SCHEMA_LIMIT ) -> list[EntityEdge]: - # vector similarity search over embedded facts - records, _, _ = await driver.execute_query( - """ + # vector similarity search over embedded facts + records, _, _ = await driver.execute_query( + """ CALL db.index.vector.queryRelationships("fact_embedding", 5, $search_vector) YIELD relationship AS r, score MATCH (n)-[r:RELATES_TO]->(m) @@ -122,38 +122,38 @@ async def edge_similarity_search( r.invalid_at AS invalid_at ORDER BY score DESC LIMIT $limit """, - search_vector=search_vector, - limit=limit, - ) + search_vector=search_vector, + limit=limit, + ) - edges: list[EntityEdge] = [] + edges: list[EntityEdge] = [] - for record in records: - edge = EntityEdge( - uuid=record['uuid'], - source_node_uuid=record['source_node_uuid'], - target_node_uuid=record['target_node_uuid'], - fact=record['fact'], - name=record['name'], - episodes=record['episodes'], - fact_embedding=record['fact_embedding'], - created_at=record['created_at'].to_native(), - expired_at=parse_db_date(record['expired_at']), - valid_at=parse_db_date(record['valid_at']), - invalid_at=parse_db_date(record['invalid_at']), - ) + for record in records: + edge = EntityEdge( + uuid=record['uuid'], + source_node_uuid=record['source_node_uuid'], + target_node_uuid=record['target_node_uuid'], + fact=record['fact'], + name=record['name'], + episodes=record['episodes'], + fact_embedding=record['fact_embedding'], + created_at=record['created_at'].to_native(), + expired_at=parse_db_date(record['expired_at']), + valid_at=parse_db_date(record['valid_at']), + invalid_at=parse_db_date(record['invalid_at']), + ) - edges.append(edge) + edges.append(edge) - return edges + return edges async def entity_similarity_search( - search_vector: list[float], driver: AsyncDriver, limit=RELEVANT_SCHEMA_LIMIT + search_vector: list[float], driver: AsyncDriver, limit=RELEVANT_SCHEMA_LIMIT ) -> list[EntityNode]: - # vector similarity search over entity names - records, _, _ = await driver.execute_query( - """ + # vector similarity search over entity names + records, _, _ = await driver.execute_query( + """ CALL db.index.vector.queryNodes("name_embedding", $limit, $search_vector) YIELD node AS n, score RETURN @@ -163,32 +163,32 @@ async def entity_similarity_search( n.summary AS summary ORDER BY score DESC """, - search_vector=search_vector, - limit=limit, - ) - nodes: list[EntityNode] = [] + search_vector=search_vector, + limit=limit, + ) + nodes: list[EntityNode] = [] - for record in records: - nodes.append( - EntityNode( - uuid=record['uuid'], - name=record['name'], - 
labels=['Entity'], - created_at=record['created_at'].to_native(), - summary=record['summary'], - ) - ) + for record in records: + nodes.append( + EntityNode( + uuid=record['uuid'], + name=record['name'], + labels=['Entity'], + created_at=record['created_at'].to_native(), + summary=record['summary'], + ) + ) - return nodes + return nodes async def entity_fulltext_search( - query: str, driver: AsyncDriver, limit=RELEVANT_SCHEMA_LIMIT + query: str, driver: AsyncDriver, limit=RELEVANT_SCHEMA_LIMIT ) -> list[EntityNode]: - # BM25 search to get top nodes - fuzzy_query = query + '~' - records, _, _ = await driver.execute_query( - """ + # BM25 search to get top nodes + fuzzy_query = query + '~' + records, _, _ = await driver.execute_query( + """ CALL db.index.fulltext.queryNodes("name_and_summary", $query) YIELD node, score RETURN node.uuid As uuid, @@ -198,33 +198,33 @@ async def entity_fulltext_search( ORDER BY score DESC LIMIT $limit """, - query=fuzzy_query, - limit=limit, - ) - nodes: list[EntityNode] = [] + query=fuzzy_query, + limit=limit, + ) + nodes: list[EntityNode] = [] - for record in records: - nodes.append( - EntityNode( - uuid=record['uuid'], - name=record['name'], - labels=['Entity'], - created_at=record['created_at'].to_native(), - summary=record['summary'], - ) - ) + for record in records: + nodes.append( + EntityNode( + uuid=record['uuid'], + name=record['name'], + labels=['Entity'], + created_at=record['created_at'].to_native(), + summary=record['summary'], + ) + ) - return nodes + return nodes async def edge_fulltext_search( - query: str, driver: AsyncDriver, limit=RELEVANT_SCHEMA_LIMIT + query: str, driver: AsyncDriver, limit=RELEVANT_SCHEMA_LIMIT ) -> list[EntityEdge]: - # fulltext search over facts - fuzzy_query = query + '~' + # fulltext search over facts + fuzzy_query = query + '~' - records, _, _ = await driver.execute_query( - """ + records, _, _ = await driver.execute_query( + """ CALL db.index.fulltext.queryRelationships("name_and_fact", $query) YIELD relationship AS r, score MATCH (n:Entity)-[r]->(m:Entity) @@ -242,104 +242,104 @@ async def edge_fulltext_search( r.invalid_at AS invalid_at ORDER BY score DESC LIMIT $limit """, - query=fuzzy_query, - limit=limit, - ) + query=fuzzy_query, + limit=limit, + ) - edges: list[EntityEdge] = [] + edges: list[EntityEdge] = [] - for record in records: - edge = EntityEdge( - uuid=record['uuid'], - source_node_uuid=record['source_node_uuid'], - target_node_uuid=record['target_node_uuid'], - fact=record['fact'], - name=record['name'], - episodes=record['episodes'], - fact_embedding=record['fact_embedding'], - created_at=record['created_at'].to_native(), - expired_at=parse_db_date(record['expired_at']), - valid_at=parse_db_date(record['valid_at']), - invalid_at=parse_db_date(record['invalid_at']), - ) + for record in records: + edge = EntityEdge( + uuid=record['uuid'], + source_node_uuid=record['source_node_uuid'], + target_node_uuid=record['target_node_uuid'], + fact=record['fact'], + name=record['name'], + episodes=record['episodes'], + fact_embedding=record['fact_embedding'], + created_at=record['created_at'].to_native(), + expired_at=parse_db_date(record['expired_at']), + valid_at=parse_db_date(record['valid_at']), + invalid_at=parse_db_date(record['invalid_at']), + ) - edges.append(edge) + edges.append(edge) - return edges + return edges async def get_relevant_nodes( - nodes: list[EntityNode], - driver: AsyncDriver, + nodes: list[EntityNode], + driver: AsyncDriver, ) -> list[EntityNode]: - start = time() - relevant_nodes: 
list[EntityNode] = [] - relevant_node_uuids = set() + start = time() + relevant_nodes: list[EntityNode] = [] + relevant_node_uuids = set() - results = await asyncio.gather( - *[entity_fulltext_search(node.name, driver) for node in nodes], - *[ - entity_similarity_search(node.name_embedding, driver) - for node in nodes - if node.name_embedding is not None - ], - ) + results = await asyncio.gather( + *[entity_fulltext_search(node.name, driver) for node in nodes], + *[ + entity_similarity_search(node.name_embedding, driver) + for node in nodes + if node.name_embedding is not None + ], + ) - for result in results: - for node in result: - if node.uuid in relevant_node_uuids: - continue + for result in results: + for node in result: + if node.uuid in relevant_node_uuids: + continue - relevant_node_uuids.add(node.uuid) - relevant_nodes.append(node) + relevant_node_uuids.add(node.uuid) + relevant_nodes.append(node) - end = time() - logger.info(f'Found relevant nodes: {relevant_node_uuids} in {(end - start) * 1000} ms') + end = time() + logger.info(f'Found relevant nodes: {relevant_node_uuids} in {(end - start) * 1000} ms') - return relevant_nodes + return relevant_nodes async def get_relevant_edges( - edges: list[EntityEdge], - driver: AsyncDriver, + edges: list[EntityEdge], + driver: AsyncDriver, ) -> list[EntityEdge]: - start = time() - relevant_edges: list[EntityEdge] = [] - relevant_edge_uuids = set() + start = time() + relevant_edges: list[EntityEdge] = [] + relevant_edge_uuids = set() - results = await asyncio.gather( - *[ - edge_similarity_search(edge.fact_embedding, driver) - for edge in edges - if edge.fact_embedding is not None - ], - *[edge_fulltext_search(edge.fact, driver) for edge in edges], - ) + results = await asyncio.gather( + *[ + edge_similarity_search(edge.fact_embedding, driver) + for edge in edges + if edge.fact_embedding is not None + ], + *[edge_fulltext_search(edge.fact, driver) for edge in edges], + ) - for result in results: - for edge in result: - if edge.uuid in relevant_edge_uuids: - continue + for result in results: + for edge in result: + if edge.uuid in relevant_edge_uuids: + continue - relevant_edge_uuids.add(edge.uuid) - relevant_edges.append(edge) + relevant_edge_uuids.add(edge.uuid) + relevant_edges.append(edge) - end = time() - logger.info(f'Found relevant edges: {relevant_edge_uuids} in {(end - start) * 1000} ms') + end = time() + logger.info(f'Found relevant edges: {relevant_edge_uuids} in {(end - start) * 1000} ms') - return relevant_edges + return relevant_edges # takes in a list of rankings of uuids def rrf(results: list[list[str]], rank_const=1) -> list[str]: - scores: dict[str, int] = defaultdict(int) - for result in results: - for i, uuid in enumerate(result): - scores[uuid] += 1 / (i + rank_const) + scores: dict[str, float] = defaultdict(float) + for result in results: + for i, uuid in enumerate(result): + scores[uuid] += 1 / (i + rank_const) - scored_uuids = [term for term in scores.items()] - scored_uuids.sort(reverse=True, key=lambda term: term[1]) + scored_uuids = [term for term in scores.items()] + scored_uuids.sort(reverse=True, key=lambda term: term[1]) - sorted_uuids = [term[0] for term in scored_uuids] + sorted_uuids = [term[0] for term in scored_uuids] - return sorted_uuids + return sorted_uuids diff --git a/core/utils/__init__.py b/core/utils/__init__.py index 7978529..5464234 100644 --- a/core/utils/__init__.py +++ b/core/utils/__init__.py @@ -1,15 +1,15 @@ from .maintenance import ( - build_episodic_edges, - clear_data, - extract_edges, -
extract_nodes, - retrieve_episodes, + build_episodic_edges, + clear_data, + extract_edges, + extract_nodes, + retrieve_episodes, ) __all__ = [ - 'extract_edges', - 'build_episodic_edges', - 'extract_nodes', - 'clear_data', - 'retrieve_episodes', + 'extract_edges', + 'build_episodic_edges', + 'extract_nodes', + 'clear_data', + 'retrieve_episodes', ] diff --git a/core/utils/bulk_utils.py b/core/utils/bulk_utils.py index 197d4aa..9b2202f 100644 --- a/core/utils/bulk_utils.py +++ b/core/utils/bulk_utils.py @@ -36,7 +36,7 @@ class BulkEpisode(BaseModel): async def retrieve_previous_episodes_bulk( - driver: AsyncDriver, episodes: list[EpisodicNode] + driver: AsyncDriver, episodes: list[EpisodicNode] ) -> list[tuple[EpisodicNode, list[EpisodicNode]]]: previous_episodes_list = await asyncio.gather( *[ @@ -52,7 +52,7 @@ async def retrieve_previous_episodes_bulk( async def extract_nodes_and_edges_bulk( - llm_client: LLMClient, episode_tuples: list[tuple[EpisodicNode, list[EpisodicNode]]] + llm_client: LLMClient, episode_tuples: list[tuple[EpisodicNode, list[EpisodicNode]]] ) -> tuple[list[EntityNode], list[EntityEdge], list[EpisodicEdge]]: extracted_nodes_bulk = await asyncio.gather( *[ @@ -89,9 +89,9 @@ async def extract_nodes_and_edges_bulk( async def dedupe_nodes_bulk( - driver: AsyncDriver, - llm_client: LLMClient, - extracted_nodes: list[EntityNode], + driver: AsyncDriver, + llm_client: LLMClient, + extracted_nodes: list[EntityNode], ) -> tuple[list[EntityNode], dict[str, str]]: # Compress nodes nodes, uuid_map = node_name_match(extracted_nodes) @@ -110,7 +110,7 @@ async def dedupe_nodes_bulk( async def dedupe_edges_bulk( - driver: AsyncDriver, llm_client: LLMClient, extracted_edges: list[EntityEdge] + driver: AsyncDriver, llm_client: LLMClient, extracted_edges: list[EntityEdge] ) -> list[EntityEdge]: # Compress edges compressed_edges = await compress_edges(llm_client, extracted_edges) @@ -136,7 +136,7 @@ def node_name_match(nodes: list[EntityNode]) -> tuple[list[EntityNode], dict[str async def compress_nodes( - llm_client: LLMClient, nodes: list[EntityNode], uuid_map: dict[str, str] + llm_client: LLMClient, nodes: list[EntityNode], uuid_map: dict[str, str] ) -> tuple[list[EntityNode], dict[str, str]]: if len(nodes) == 0: return nodes, uuid_map @@ -144,7 +144,7 @@ async def compress_nodes( anchor = nodes[0] nodes.sort(key=lambda node: dot(anchor.name_embedding or [], node.name_embedding or [])) - node_chunks = [nodes[i: i + CHUNK_SIZE] for i in range(0, len(nodes), CHUNK_SIZE)] + node_chunks = [nodes[i : i + CHUNK_SIZE] for i in range(0, len(nodes), CHUNK_SIZE)] results = await asyncio.gather(*[dedupe_node_list(llm_client, chunk) for chunk in node_chunks]) @@ -167,9 +167,11 @@ async def compress_edges(llm_client: LLMClient, edges: list[EntityEdge]) -> list return edges anchor = edges[0] - edges.sort(key=lambda embedding: dot(anchor.fact_embedding or [], embedding.fact_embedding or [])) + edges.sort( + key=lambda embedding: dot(anchor.fact_embedding or [], embedding.fact_embedding or []) + ) - edge_chunks = [edges[i: i + CHUNK_SIZE] for i in range(0, len(edges), CHUNK_SIZE)] + edge_chunks = [edges[i : i + CHUNK_SIZE] for i in range(0, len(edges), CHUNK_SIZE)] results = await asyncio.gather(*[dedupe_edge_list(llm_client, chunk) for chunk in edge_chunks]) diff --git a/core/utils/maintenance/__init__.py b/core/utils/maintenance/__init__.py index d552dbb..553a203 100644 --- a/core/utils/maintenance/__init__.py +++ b/core/utils/maintenance/__init__.py @@ -1,16 +1,16 @@ from .edge_operations import 
build_episodic_edges, extract_edges from .graph_data_operations import ( - clear_data, - retrieve_episodes, + clear_data, + retrieve_episodes, ) from .node_operations import extract_nodes from .temporal_operations import invalidate_edges __all__ = [ - 'extract_edges', - 'build_episodic_edges', - 'extract_nodes', - 'clear_data', - 'retrieve_episodes', - 'invalidate_edges', + 'extract_edges', + 'build_episodic_edges', + 'extract_nodes', + 'clear_data', + 'retrieve_episodes', + 'invalidate_edges', ] diff --git a/core/utils/maintenance/edge_operations.py b/core/utils/maintenance/edge_operations.py index bdcfef6..ec436ab 100644 --- a/core/utils/maintenance/edge_operations.py +++ b/core/utils/maintenance/edge_operations.py @@ -12,142 +12,142 @@ def build_episodic_edges( - entity_nodes: List[EntityNode], - episode: EpisodicNode, - created_at: datetime, + entity_nodes: List[EntityNode], + episode: EpisodicNode, + created_at: datetime, ) -> List[EpisodicEdge]: - edges: List[EpisodicEdge] = [] + edges: List[EpisodicEdge] = [] - for node in entity_nodes: - edge = EpisodicEdge( - source_node_uuid=episode.uuid, - target_node_uuid=node.uuid, - created_at=created_at, - ) - edges.append(edge) + for node in entity_nodes: + edge = EpisodicEdge( + source_node_uuid=episode.uuid, + target_node_uuid=node.uuid, + created_at=created_at, + ) + edges.append(edge) - return edges + return edges async def extract_edges( - llm_client: LLMClient, - episode: EpisodicNode, - nodes: list[EntityNode], - previous_episodes: list[EpisodicNode], + llm_client: LLMClient, + episode: EpisodicNode, + nodes: list[EntityNode], + previous_episodes: list[EpisodicNode], ) -> list[EntityEdge]: - start = time() - - # Prepare context for LLM - context = { - 'episode_content': episode.content, - 'episode_timestamp': (episode.valid_at.isoformat() if episode.valid_at else None), - 'nodes': [ - {'uuid': node.uuid, 'name': node.name, 'summary': node.summary} for node in nodes - ], - 'previous_episodes': [ - { - 'content': ep.content, - 'timestamp': ep.valid_at.isoformat() if ep.valid_at else None, - } - for ep in previous_episodes - ], - } - - llm_response = await llm_client.generate_response(prompt_library.extract_edges.v2(context)) - edges_data = llm_response.get('edges', []) - - end = time() - logger.info(f'Extracted new edges: {edges_data} in {(end - start) * 1000} ms') - - # Convert the extracted data into EntityEdge objects - edges = [] - for edge_data in edges_data: - if edge_data['target_node_uuid'] and edge_data['source_node_uuid']: - edge = EntityEdge( - source_node_uuid=edge_data['source_node_uuid'], - target_node_uuid=edge_data['target_node_uuid'], - name=edge_data['relation_type'], - fact=edge_data['fact'], - episodes=[episode.uuid], - created_at=datetime.now(), - valid_at=None, - invalid_at=None, - ) - edges.append(edge) - logger.info( - f'Created new edge: {edge.name} from (UUID: {edge.source_node_uuid}) to (UUID: {edge.target_node_uuid})' - ) - - return edges + start = time() + + # Prepare context for LLM + context = { + 'episode_content': episode.content, + 'episode_timestamp': (episode.valid_at.isoformat() if episode.valid_at else None), + 'nodes': [ + {'uuid': node.uuid, 'name': node.name, 'summary': node.summary} for node in nodes + ], + 'previous_episodes': [ + { + 'content': ep.content, + 'timestamp': ep.valid_at.isoformat() if ep.valid_at else None, + } + for ep in previous_episodes + ], + } + + llm_response = await llm_client.generate_response(prompt_library.extract_edges.v2(context)) + edges_data = 
llm_response.get('edges', []) + + end = time() + logger.info(f'Extracted new edges: {edges_data} in {(end - start) * 1000} ms') + + # Convert the extracted data into EntityEdge objects + edges = [] + for edge_data in edges_data: + if edge_data['target_node_uuid'] and edge_data['source_node_uuid']: + edge = EntityEdge( + source_node_uuid=edge_data['source_node_uuid'], + target_node_uuid=edge_data['target_node_uuid'], + name=edge_data['relation_type'], + fact=edge_data['fact'], + episodes=[episode.uuid], + created_at=datetime.now(), + valid_at=None, + invalid_at=None, + ) + edges.append(edge) + logger.info( + f'Created new edge: {edge.name} from (UUID: {edge.source_node_uuid}) to (UUID: {edge.target_node_uuid})' + ) + + return edges def create_edge_identifier( - source_node: EntityNode, edge: EntityEdge, target_node: EntityNode + source_node: EntityNode, edge: EntityEdge, target_node: EntityNode ) -> str: - return f'{source_node.name}-{edge.name}-{target_node.name}' + return f'{source_node.name}-{edge.name}-{target_node.name}' async def dedupe_extracted_edges( - llm_client: LLMClient, - extracted_edges: list[EntityEdge], - existing_edges: list[EntityEdge], + llm_client: LLMClient, + extracted_edges: list[EntityEdge], + existing_edges: list[EntityEdge], ) -> list[EntityEdge]: - # Create edge map - edge_map = {} - for edge in extracted_edges: - edge_map[edge.uuid] = edge - - # Prepare context for LLM - context = { - 'extracted_edges': [ - {'uuid': edge.uuid, 'name': edge.name, 'fact': edge.fact} for edge in extracted_edges - ], - 'existing_edges': [ - {'uuid': edge.uuid, 'name': edge.name, 'fact': edge.fact} for edge in existing_edges - ], - } - - llm_response = await llm_client.generate_response(prompt_library.dedupe_edges.v1(context)) - unique_edge_data = llm_response.get('unique_facts', []) - logger.info(f'Extracted unique edges: {unique_edge_data}') - - # Get full edge data - edges = [] - for unique_edge in unique_edge_data: - edge = edge_map[unique_edge['uuid']] - edges.append(edge) - - return edges + # Create edge map + edge_map = {} + for edge in extracted_edges: + edge_map[edge.uuid] = edge + + # Prepare context for LLM + context = { + 'extracted_edges': [ + {'uuid': edge.uuid, 'name': edge.name, 'fact': edge.fact} for edge in extracted_edges + ], + 'existing_edges': [ + {'uuid': edge.uuid, 'name': edge.name, 'fact': edge.fact} for edge in existing_edges + ], + } + + llm_response = await llm_client.generate_response(prompt_library.dedupe_edges.v1(context)) + unique_edge_data = llm_response.get('unique_facts', []) + logger.info(f'Extracted unique edges: {unique_edge_data}') + + # Get full edge data + edges = [] + for unique_edge in unique_edge_data: + edge = edge_map[unique_edge['uuid']] + edges.append(edge) + + return edges async def dedupe_edge_list( - llm_client: LLMClient, - edges: list[EntityEdge], + llm_client: LLMClient, + edges: list[EntityEdge], ) -> list[EntityEdge]: - start = time() + start = time() - # Create edge map - edge_map = {} - for edge in edges: - edge_map[edge.uuid] = edge + # Create edge map + edge_map = {} + for edge in edges: + edge_map[edge.uuid] = edge - # Prepare context for LLM - context = {'edges': [{'uuid': edge.uuid, 'fact': edge.fact} for edge in edges]} + # Prepare context for LLM + context = {'edges': [{'uuid': edge.uuid, 'fact': edge.fact} for edge in edges]} - llm_response = await llm_client.generate_response( - prompt_library.dedupe_edges.edge_list(context) - ) - unique_edges_data = llm_response.get('unique_facts', []) + llm_response = await 
llm_client.generate_response( + prompt_library.dedupe_edges.edge_list(context) + ) + unique_edges_data = llm_response.get('unique_facts', []) - end = time() - logger.info(f'Extracted edge duplicates: {unique_edges_data} in {(end - start) * 1000} ms ') + end = time() + logger.info(f'Extracted edge duplicates: {unique_edges_data} in {(end - start) * 1000} ms ') - # Get full edge data - unique_edges = [] - for edge_data in unique_edges_data: - uuid = edge_data['uuid'] - edge = edge_map[uuid] - edge.fact = edge_data['fact'] - unique_edges.append(edge) + # Get full edge data + unique_edges = [] + for edge_data in unique_edges_data: + uuid = edge_data['uuid'] + edge = edge_map[uuid] + edge.fact = edge_data['fact'] + unique_edges.append(edge) - return unique_edges + return unique_edges diff --git a/core/utils/maintenance/graph_data_operations.py b/core/utils/maintenance/graph_data_operations.py index 67579ca..4849b23 100644 --- a/core/utils/maintenance/graph_data_operations.py +++ b/core/utils/maintenance/graph_data_operations.py @@ -13,29 +13,29 @@ async def build_indices_and_constraints(driver: AsyncDriver): - range_indices: list[LiteralString] = [ - 'CREATE INDEX entity_uuid IF NOT EXISTS FOR (n:Entity) ON (n.uuid)', - 'CREATE INDEX episode_uuid IF NOT EXISTS FOR (n:Episodic) ON (n.uuid)', - 'CREATE INDEX relation_uuid IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.uuid)', - 'CREATE INDEX mention_uuid IF NOT EXISTS FOR ()-[e:MENTIONS]-() ON (e.uuid)', - 'CREATE INDEX name_entity_index IF NOT EXISTS FOR (n:Entity) ON (n.name)', - 'CREATE INDEX created_at_entity_index IF NOT EXISTS FOR (n:Entity) ON (n.created_at)', - 'CREATE INDEX created_at_episodic_index IF NOT EXISTS FOR (n:Episodic) ON (n.created_at)', - 'CREATE INDEX valid_at_episodic_index IF NOT EXISTS FOR (n:Episodic) ON (n.valid_at)', - 'CREATE INDEX name_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.name)', - 'CREATE INDEX created_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.created_at)', - 'CREATE INDEX expired_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.expired_at)', - 'CREATE INDEX valid_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.valid_at)', - 'CREATE INDEX invalid_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.invalid_at)', - ] + range_indices: list[LiteralString] = [ + 'CREATE INDEX entity_uuid IF NOT EXISTS FOR (n:Entity) ON (n.uuid)', + 'CREATE INDEX episode_uuid IF NOT EXISTS FOR (n:Episodic) ON (n.uuid)', + 'CREATE INDEX relation_uuid IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.uuid)', + 'CREATE INDEX mention_uuid IF NOT EXISTS FOR ()-[e:MENTIONS]-() ON (e.uuid)', + 'CREATE INDEX name_entity_index IF NOT EXISTS FOR (n:Entity) ON (n.name)', + 'CREATE INDEX created_at_entity_index IF NOT EXISTS FOR (n:Entity) ON (n.created_at)', + 'CREATE INDEX created_at_episodic_index IF NOT EXISTS FOR (n:Episodic) ON (n.created_at)', + 'CREATE INDEX valid_at_episodic_index IF NOT EXISTS FOR (n:Episodic) ON (n.valid_at)', + 'CREATE INDEX name_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.name)', + 'CREATE INDEX created_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.created_at)', + 'CREATE INDEX expired_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.expired_at)', + 'CREATE INDEX valid_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.valid_at)', + 'CREATE INDEX invalid_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.invalid_at)', + ] - fulltext_indices: list[LiteralString] = [ - 'CREATE FULLTEXT INDEX name_and_summary IF NOT 
EXISTS FOR (n:Entity) ON EACH [n.name, n.summary]', - 'CREATE FULLTEXT INDEX name_and_fact IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON EACH [e.name, e.fact]', - ] + fulltext_indices: list[LiteralString] = [ + 'CREATE FULLTEXT INDEX name_and_summary IF NOT EXISTS FOR (n:Entity) ON EACH [n.name, n.summary]', + 'CREATE FULLTEXT INDEX name_and_fact IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON EACH [e.name, e.fact]', + ] - vector_indices: list[LiteralString] = [ - """ + vector_indices: list[LiteralString] = [ + """ CREATE VECTOR INDEX fact_embedding IF NOT EXISTS FOR ()-[r:RELATES_TO]-() ON (r.fact_embedding) OPTIONS {indexConfig: { @@ -43,7 +43,7 @@ async def build_indices_and_constraints(driver: AsyncDriver): `vector.similarity_function`: 'cosine' }} """, - """ + """ CREATE VECTOR INDEX name_embedding IF NOT EXISTS FOR (n:Entity) ON (n.name_embedding) OPTIONS {indexConfig: { @@ -51,29 +51,29 @@ async def build_indices_and_constraints(driver: AsyncDriver): `vector.similarity_function`: 'cosine' }} """, - ] - index_queries: list[LiteralString] = range_indices + fulltext_indices + vector_indices + ] + index_queries: list[LiteralString] = range_indices + fulltext_indices + vector_indices - await asyncio.gather(*[driver.execute_query(query) for query in index_queries]) + await asyncio.gather(*[driver.execute_query(query) for query in index_queries]) async def clear_data(driver: AsyncDriver): - async with driver.session() as session: + async with driver.session() as session: - async def delete_all(tx): - await tx.run('MATCH (n) DETACH DELETE n') + async def delete_all(tx): + await tx.run('MATCH (n) DETACH DELETE n') - await session.execute_write(delete_all) + await session.execute_write(delete_all) async def retrieve_episodes( - driver: AsyncDriver, - reference_time: datetime, - last_n: int = EPISODE_WINDOW_LEN, + driver: AsyncDriver, + reference_time: datetime, + last_n: int = EPISODE_WINDOW_LEN, ) -> list[EpisodicNode]: - """Retrieve the last n episodic nodes from the graph""" - result = await driver.execute_query( - """ + """Retrieve the last n episodic nodes from the graph""" + result = await driver.execute_query( + """ MATCH (e:Episodic) WHERE e.valid_at <= $reference_time RETURN e.content as content, e.created_at as created_at, @@ -85,21 +85,21 @@ async def retrieve_episodes( ORDER BY e.created_at DESC LIMIT $num_episodes """, - reference_time=reference_time, - num_episodes=last_n, - ) - episodes = [ - EpisodicNode( - content=record['content'], - created_at=datetime.fromtimestamp( - record['created_at'].to_native().timestamp(), timezone.utc - ), - valid_at=(record['valid_at'].to_native()), - uuid=record['uuid'], - source=record['source'], - name=record['name'], - source_description=record['source_description'], - ) - for record in result.records - ] - return list(reversed(episodes)) # Return in chronological order + reference_time=reference_time, + num_episodes=last_n, + ) + episodes = [ + EpisodicNode( + content=record['content'], + created_at=datetime.fromtimestamp( + record['created_at'].to_native().timestamp(), timezone.utc + ), + valid_at=(record['valid_at'].to_native()), + uuid=record['uuid'], + source=record['source'], + name=record['name'], + source_description=record['source_description'], + ) + for record in result.records + ] + return list(reversed(episodes)) # Return in chronological order diff --git a/core/utils/maintenance/node_operations.py b/core/utils/maintenance/node_operations.py index e7c3e72..e13a5a1 100644 --- a/core/utils/maintenance/node_operations.py +++ 
b/core/utils/maintenance/node_operations.py @@ -10,145 +10,145 @@ async def extract_nodes( - llm_client: LLMClient, - episode: EpisodicNode, - previous_episodes: list[EpisodicNode], + llm_client: LLMClient, + episode: EpisodicNode, + previous_episodes: list[EpisodicNode], ) -> list[EntityNode]: - start = time() - - # Prepare context for LLM - context = { - 'episode_content': episode.content, - 'episode_timestamp': (episode.valid_at.isoformat() if episode.valid_at else None), - 'previous_episodes': [ - { - 'content': ep.content, - 'timestamp': ep.valid_at.isoformat() if ep.valid_at else None, - } - for ep in previous_episodes - ], - } - - llm_response = await llm_client.generate_response(prompt_library.extract_nodes.v3(context)) - new_nodes_data = llm_response.get('new_nodes', []) - - end = time() - logger.info(f'Extracted new nodes: {new_nodes_data} in {(end - start) * 1000} ms') - # Convert the extracted data into EntityNode objects - new_nodes = [] - for node_data in new_nodes_data: - new_node = EntityNode( - name=node_data['name'], - labels=node_data['labels'], - summary=node_data['summary'], - created_at=datetime.now(), - ) - new_nodes.append(new_node) - logger.info(f'Created new node: {new_node.name} (UUID: {new_node.uuid})') - - return new_nodes + start = time() + + # Prepare context for LLM + context = { + 'episode_content': episode.content, + 'episode_timestamp': (episode.valid_at.isoformat() if episode.valid_at else None), + 'previous_episodes': [ + { + 'content': ep.content, + 'timestamp': ep.valid_at.isoformat() if ep.valid_at else None, + } + for ep in previous_episodes + ], + } + + llm_response = await llm_client.generate_response(prompt_library.extract_nodes.v3(context)) + new_nodes_data = llm_response.get('new_nodes', []) + + end = time() + logger.info(f'Extracted new nodes: {new_nodes_data} in {(end - start) * 1000} ms') + # Convert the extracted data into EntityNode objects + new_nodes = [] + for node_data in new_nodes_data: + new_node = EntityNode( + name=node_data['name'], + labels=node_data['labels'], + summary=node_data['summary'], + created_at=datetime.now(), + ) + new_nodes.append(new_node) + logger.info(f'Created new node: {new_node.name} (UUID: {new_node.uuid})') + + return new_nodes async def dedupe_extracted_nodes( - llm_client: LLMClient, - extracted_nodes: list[EntityNode], - existing_nodes: list[EntityNode], + llm_client: LLMClient, + extracted_nodes: list[EntityNode], + existing_nodes: list[EntityNode], ) -> tuple[list[EntityNode], dict[str, str], list[EntityNode]]: - start = time() + start = time() - # build existing node map - node_map: dict[str, EntityNode] = {} - for node in existing_nodes: - node_map[node.name] = node + # build existing node map + node_map: dict[str, EntityNode] = {} + for node in existing_nodes: + node_map[node.name] = node - # Temp hack - new_nodes_map: dict[str, EntityNode] = {} - for node in extracted_nodes: - new_nodes_map[node.name] = node + # Temp hack + new_nodes_map: dict[str, EntityNode] = {} + for node in extracted_nodes: + new_nodes_map[node.name] = node - # Prepare context for LLM - existing_nodes_context = [ - {'name': node.name, 'summary': node.summary} for node in existing_nodes - ] + # Prepare context for LLM + existing_nodes_context = [ + {'name': node.name, 'summary': node.summary} for node in existing_nodes + ] - extracted_nodes_context = [ - {'name': node.name, 'summary': node.summary} for node in extracted_nodes - ] + extracted_nodes_context = [ + {'name': node.name, 'summary': node.summary} for node in 
extracted_nodes + ] - context = { - 'existing_nodes': existing_nodes_context, - 'extracted_nodes': extracted_nodes_context, - } + context = { + 'existing_nodes': existing_nodes_context, + 'extracted_nodes': extracted_nodes_context, + } - llm_response = await llm_client.generate_response(prompt_library.dedupe_nodes.v2(context)) + llm_response = await llm_client.generate_response(prompt_library.dedupe_nodes.v2(context)) - duplicate_data = llm_response.get('duplicates', []) + duplicate_data = llm_response.get('duplicates', []) - end = time() - logger.info(f'Deduplicated nodes: {duplicate_data} in {(end - start) * 1000} ms') + end = time() + logger.info(f'Deduplicated nodes: {duplicate_data} in {(end - start) * 1000} ms') - uuid_map: dict[str, str] = {} - for duplicate in duplicate_data: - uuid = new_nodes_map[duplicate['name']].uuid - uuid_value = node_map[duplicate['duplicate_of']].uuid - uuid_map[uuid] = uuid_value + uuid_map: dict[str, str] = {} + for duplicate in duplicate_data: + uuid = new_nodes_map[duplicate['name']].uuid + uuid_value = node_map[duplicate['duplicate_of']].uuid + uuid_map[uuid] = uuid_value - nodes: list[EntityNode] = [] - brand_new_nodes: list[EntityNode] = [] - for node in extracted_nodes: - if node.uuid in uuid_map: - existing_uuid = uuid_map[node.uuid] - # TODO(Preston): This is a bit of a hack I implemented because we were getting incorrect uuids for existing nodes, - # can you revisit the node dedup function and make it somewhat cleaner and add more comments/tests please? - # find an existing node by the uuid from the nodes_map (each key is name, so we need to iterate by uuid value) - existing_node = next((v for k, v in node_map.items() if v.uuid == existing_uuid), None) - if existing_node: - nodes.append(existing_node) + nodes: list[EntityNode] = [] + brand_new_nodes: list[EntityNode] = [] + for node in extracted_nodes: + if node.uuid in uuid_map: + existing_uuid = uuid_map[node.uuid] + # TODO(Preston): This is a bit of a hack I implemented because we were getting incorrect uuids for existing nodes, + # can you revisit the node dedup function and make it somewhat cleaner and add more comments/tests please? 
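# NOTE: a minimal sketch (not part of this diff) of the cleanup the TODO above asks for.
# Assuming node_map keeps its name keys, a second uuid-keyed view turns the reverse
# lookup below into a single dict hit instead of a linear `next(...)` scan:
#
#     by_uuid = {node.uuid: node for node in node_map.values()}
#     existing_node = by_uuid.get(existing_uuid)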
+ # find an existing node by the uuid from the nodes_map (each key is name, so we need to iterate by uuid value) + existing_node = next((v for k, v in node_map.items() if v.uuid == existing_uuid), None) + if existing_node: + nodes.append(existing_node) - continue - brand_new_nodes.append(node) - nodes.append(node) + continue + brand_new_nodes.append(node) + nodes.append(node) - return nodes, uuid_map, brand_new_nodes + return nodes, uuid_map, brand_new_nodes async def dedupe_node_list( - llm_client: LLMClient, - nodes: list[EntityNode], + llm_client: LLMClient, + nodes: list[EntityNode], ) -> tuple[list[EntityNode], dict[str, str]]: - start = time() + start = time() - # build node map - node_map = {} - for node in nodes: - node_map[node.name] = node + # build node map + node_map = {} + for node in nodes: + node_map[node.name] = node - # Prepare context for LLM - nodes_context = [{'name': node.name, 'summary': node.summary} for node in nodes] + # Prepare context for LLM + nodes_context = [{'name': node.name, 'summary': node.summary} for node in nodes] - context = { - 'nodes': nodes_context, - } + context = { + 'nodes': nodes_context, + } - llm_response = await llm_client.generate_response( - prompt_library.dedupe_nodes.node_list(context) - ) + llm_response = await llm_client.generate_response( + prompt_library.dedupe_nodes.node_list(context) + ) - nodes_data = llm_response.get('nodes', []) + nodes_data = llm_response.get('nodes', []) - end = time() - logger.info(f'Deduplicated nodes: {nodes_data} in {(end - start) * 1000} ms') + end = time() + logger.info(f'Deduplicated nodes: {nodes_data} in {(end - start) * 1000} ms') - # Get full node data - unique_nodes = [] - uuid_map: dict[str, str] = {} - for node_data in nodes_data: - node = node_map[node_data['names'][0]] - unique_nodes.append(node) + # Get full node data + unique_nodes = [] + uuid_map: dict[str, str] = {} + for node_data in nodes_data: + node = node_map[node_data['names'][0]] + unique_nodes.append(node) - for name in node_data['names'][1:]: - uuid = node_map[name].uuid - uuid_value = node_map[node_data['names'][0]].uuid - uuid_map[uuid] = uuid_value + for name in node_data['names'][1:]: + uuid = node_map[name].uuid + uuid_value = node_map[node_data['names'][0]].uuid + uuid_map[uuid] = uuid_value - return unique_nodes, uuid_map + return unique_nodes, uuid_map diff --git a/core/utils/maintenance/temporal_operations.py b/core/utils/maintenance/temporal_operations.py index 8634f66..a13002b 100644 --- a/core/utils/maintenance/temporal_operations.py +++ b/core/utils/maintenance/temporal_operations.py @@ -13,109 +13,153 @@ def extract_node_and_edge_triplets( - edges: list[EntityEdge], nodes: list[EntityNode] + edges: list[EntityEdge], nodes: list[EntityNode] ) -> list[NodeEdgeNodeTriplet]: - return [extract_node_edge_node_triplet(edge, nodes) for edge in edges] + return [extract_node_edge_node_triplet(edge, nodes) for edge in edges] def extract_node_edge_node_triplet( - edge: EntityEdge, nodes: list[EntityNode] + edge: EntityEdge, nodes: list[EntityNode] ) -> NodeEdgeNodeTriplet: - source_node = next((node for node in nodes if node.uuid == edge.source_node_uuid), None) - target_node = next((node for node in nodes if node.uuid == edge.target_node_uuid), None) - if not source_node or not target_node: - raise ValueError(f'Source or target node not found for edge {edge.uuid}') - return (source_node, edge, target_node) + source_node = next((node for node in nodes if node.uuid == edge.source_node_uuid), None) + target_node = next((node for 
node in nodes if node.uuid == edge.target_node_uuid), None) + if not source_node or not target_node: + raise ValueError(f'Source or target node not found for edge {edge.uuid}') + return (source_node, edge, target_node) def prepare_edges_for_invalidation( - existing_edges: list[EntityEdge], - new_edges: list[EntityEdge], - nodes: list[EntityNode], + existing_edges: list[EntityEdge], + new_edges: list[EntityEdge], + nodes: list[EntityNode], ) -> tuple[list[NodeEdgeNodeTriplet], list[NodeEdgeNodeTriplet]]: - existing_edges_pending_invalidation: list[NodeEdgeNodeTriplet] = [] - new_edges_with_nodes: list[NodeEdgeNodeTriplet] = [] + existing_edges_pending_invalidation: list[NodeEdgeNodeTriplet] = [] + new_edges_with_nodes: list[NodeEdgeNodeTriplet] = [] - for edge_list, result_list in [ - (existing_edges, existing_edges_pending_invalidation), - (new_edges, new_edges_with_nodes), - ]: - for edge in edge_list: - source_node = next((node for node in nodes if node.uuid == edge.source_node_uuid), None) - target_node = next((node for node in nodes if node.uuid == edge.target_node_uuid), None) + for edge_list, result_list in [ + (existing_edges, existing_edges_pending_invalidation), + (new_edges, new_edges_with_nodes), + ]: + for edge in edge_list: + source_node = next((node for node in nodes if node.uuid == edge.source_node_uuid), None) + target_node = next((node for node in nodes if node.uuid == edge.target_node_uuid), None) - if source_node and target_node: - result_list.append((source_node, edge, target_node)) + if source_node and target_node: + result_list.append((source_node, edge, target_node)) - return existing_edges_pending_invalidation, new_edges_with_nodes + return existing_edges_pending_invalidation, new_edges_with_nodes async def invalidate_edges( - llm_client: LLMClient, - existing_edges_pending_invalidation: list[NodeEdgeNodeTriplet], - new_edges: list[NodeEdgeNodeTriplet], - current_episode: EpisodicNode, - previous_episodes: list[EpisodicNode], + llm_client: LLMClient, + existing_edges_pending_invalidation: list[NodeEdgeNodeTriplet], + new_edges: list[NodeEdgeNodeTriplet], + current_episode: EpisodicNode, + previous_episodes: list[EpisodicNode], ) -> list[EntityEdge]: - invalidated_edges = [] # TODO: this is not yet used? + invalidated_edges = [] # TODO: this is not yet used? 
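# NOTE: an illustrative sketch (not part of this diff) of the response shape this
# function expects back from the LLM; process_edge_invalidation_llm_response below
# consumes it. The uuid and fact here are made up:
#
#     {'invalidated_edges': [{'edge_uuid': 'abc123', 'fact': 'Jane is no longer married to Paul'}]}
#
# Each matched edge then gets expired_at stamped with datetime.now() and its fact
# replaced with the updated wording.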
- context = prepare_invalidation_context( - existing_edges_pending_invalidation, - new_edges, - current_episode, - previous_episodes, - ) - logger.info(prompt_library.invalidate_edges.v1(context)) - llm_response = await llm_client.generate_response(prompt_library.invalidate_edges.v1(context)) - logger.info(f'invalidate_edges LLM response: {llm_response}') + context = prepare_invalidation_context( + existing_edges_pending_invalidation, + new_edges, + current_episode, + previous_episodes, + ) + logger.info(prompt_library.invalidate_edges.v1(context)) + llm_response = await llm_client.generate_response(prompt_library.invalidate_edges.v1(context)) + logger.info(f'invalidate_edges LLM response: {llm_response}') - edges_to_invalidate = llm_response.get('invalidated_edges', []) - invalidated_edges = process_edge_invalidation_llm_response( - edges_to_invalidate, existing_edges_pending_invalidation - ) + edges_to_invalidate = llm_response.get('invalidated_edges', []) + invalidated_edges = process_edge_invalidation_llm_response( + edges_to_invalidate, existing_edges_pending_invalidation + ) - return invalidated_edges + return invalidated_edges def prepare_invalidation_context( - existing_edges: list[NodeEdgeNodeTriplet], - new_edges: list[NodeEdgeNodeTriplet], - current_episode: EpisodicNode, - previous_episodes: list[EpisodicNode], + existing_edges: list[NodeEdgeNodeTriplet], + new_edges: list[NodeEdgeNodeTriplet], + current_episode: EpisodicNode, + previous_episodes: list[EpisodicNode], ) -> dict: - return { - 'existing_edges': [ - f'{edge.uuid} | {source_node.name} - {edge.name} - {target_node.name} (Fact: {edge.fact}) ({edge.created_at.isoformat()})' - for source_node, edge, target_node in sorted( - existing_edges, key=lambda x: x[1].created_at, reverse=True - ) - ], - 'new_edges': [ - f'{edge.uuid} | {source_node.name} - {edge.name} - {target_node.name} (Fact: {edge.fact}) ({edge.created_at.isoformat()})' - for source_node, edge, target_node in sorted( - new_edges, key=lambda x: x[1].created_at, reverse=True - ) - ], - 'current_episode': current_episode.content, - 'previous_episodes': [episode.content for episode in previous_episodes], - } + return { + 'existing_edges': [ + f'{edge.uuid} | {source_node.name} - {edge.name} - {target_node.name} (Fact: {edge.fact}) ({edge.created_at.isoformat()})' + for source_node, edge, target_node in sorted( + existing_edges, key=lambda x: x[1].created_at, reverse=True + ) + ], + 'new_edges': [ + f'{edge.uuid} | {source_node.name} - {edge.name} - {target_node.name} (Fact: {edge.fact}) ({edge.created_at.isoformat()})' + for source_node, edge, target_node in sorted( + new_edges, key=lambda x: x[1].created_at, reverse=True + ) + ], + 'current_episode': current_episode.content, + 'previous_episodes': [episode.content for episode in previous_episodes], + } def process_edge_invalidation_llm_response( - edges_to_invalidate: List[dict], existing_edges: List[NodeEdgeNodeTriplet] + edges_to_invalidate: List[dict], existing_edges: List[NodeEdgeNodeTriplet] ) -> List[EntityEdge]: - invalidated_edges = [] - for edge_to_invalidate in edges_to_invalidate: - edge_uuid = edge_to_invalidate['edge_uuid'] - edge_to_update = next( - (edge for _, edge, _ in existing_edges if edge.uuid == edge_uuid), - None, - ) - if edge_to_update: - edge_to_update.expired_at = datetime.now() - edge_to_update.fact = edge_to_invalidate['fact'] - invalidated_edges.append(edge_to_update) - logger.info( - f"Invalidated edge: {edge_to_update.name} (UUID: {edge_to_update.uuid}). 
Updated Fact: {edge_to_invalidate['fact']}" - ) - return invalidated_edges + invalidated_edges = [] + for edge_to_invalidate in edges_to_invalidate: + edge_uuid = edge_to_invalidate['edge_uuid'] + edge_to_update = next( + (edge for _, edge, _ in existing_edges if edge.uuid == edge_uuid), + None, + ) + if edge_to_update: + edge_to_update.expired_at = datetime.now() + edge_to_update.fact = edge_to_invalidate['fact'] + invalidated_edges.append(edge_to_update) + logger.info( + f"Invalidated edge: {edge_to_update.name} (UUID: {edge_to_update.uuid}). Updated Fact: {edge_to_invalidate['fact']}" + ) + return invalidated_edges + + +async def extract_edge_dates( + llm_client: LLMClient, + edge_triplet: NodeEdgeNodeTriplet, + reference_time: datetime, + current_episode: EpisodicNode, + previous_episodes: List[EpisodicNode], +) -> tuple[datetime | None, datetime | None, str]: + source_node, edge, target_node = edge_triplet + + context = { + 'source_node': source_node.name, + 'edge_name': edge.name, + 'target_node': target_node.name, + 'edge_fact': edge.fact, + 'current_episode': current_episode.content, + 'previous_episodes': [ep.content for ep in previous_episodes], + 'reference_timestamp': reference_time.isoformat(), + } + llm_response = await llm_client.generate_response(prompt_library.extract_edge_dates.v1(context)) + + valid_at = llm_response.get('valid_at') + invalid_at = llm_response.get('invalid_at') + explanation = llm_response.get('explanation', '') + + valid_at_datetime = None + invalid_at_datetime = None + + if valid_at and valid_at != '': + try: + valid_at_datetime = datetime.fromisoformat(valid_at.replace('Z', '+00:00')) + except ValueError as e: + logger.error(f'Error parsing valid_at date: {e}. Input: {valid_at}') + + if invalid_at and invalid_at != '': + try: + invalid_at_datetime = datetime.fromisoformat(invalid_at.replace('Z', '+00:00')) + except ValueError as e: + logger.error(f'Error parsing invalid_at date: {e}. 
Input: {invalid_at}') + + logger.info(f'Edge date extraction explanation: {explanation}') + + return valid_at_datetime, invalid_at_datetime, explanation diff --git a/core/utils/utils.py b/core/utils/utils.py index 8777fa6..62a4ea9 100644 --- a/core/utils/utils.py +++ b/core/utils/utils.py @@ -7,17 +7,17 @@ def build_episodic_edges( - entity_nodes: list[EntityNode], episode: EpisodicNode + entity_nodes: list[EntityNode], episode: EpisodicNode ) -> list[EpisodicEdge]: - edges: list[EpisodicEdge] = [] + edges: list[EpisodicEdge] = [] - for node in entity_nodes: - edges.append( - EpisodicEdge( - source_node_uuid=episode.uuid, - target_node_uuid=node.uuid, - created_at=episode.created_at, - ) - ) + for node in entity_nodes: + edges.append( + EpisodicEdge( + source_node_uuid=episode.uuid, + target_node_uuid=node.uuid, + created_at=episode.created_at, + ) + ) - return edges + return edges diff --git a/examples/podcast/podcast_runner.py b/examples/podcast/podcast_runner.py index e8f9999..08ec32b 100644 --- a/examples/podcast/podcast_runner.py +++ b/examples/podcast/podcast_runner.py @@ -18,54 +18,54 @@ def setup_logging(): - # Create a logger - logger = logging.getLogger() - logger.setLevel(logging.INFO) # Set the logging level to INFO + # Create a logger + logger = logging.getLogger() + logger.setLevel(logging.INFO) # Set the logging level to INFO - # Create console handler and set level to INFO - console_handler = logging.StreamHandler(sys.stdout) - console_handler.setLevel(logging.INFO) + # Create console handler and set level to INFO + console_handler = logging.StreamHandler(sys.stdout) + console_handler.setLevel(logging.INFO) - # Create formatter - formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') + # Create formatter + formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') - # Add formatter to console handler - console_handler.setFormatter(formatter) + # Add formatter to console handler + console_handler.setFormatter(formatter) - # Add console handler to logger - logger.addHandler(console_handler) + # Add console handler to logger + logger.addHandler(console_handler) - return logger + return logger async def main(use_bulk: bool = True): - setup_logging() - client = Graphiti(neo4j_uri, neo4j_user, neo4j_password) - await clear_data(client.driver) - await client.build_indices_and_constraints() - messages = parse_podcast_messages() - - if not use_bulk: - for i, message in enumerate(messages[3:14]): - await client.add_episode( - name=f'Message {i}', - episode_body=f'{message.speaker_name} ({message.role}): {message.content}', - reference_time=message.actual_timestamp, - source_description='Podcast Transcript', - ) - - episodes: list[BulkEpisode] = [ - BulkEpisode( - name=f'Message {i}', - content=f'{message.speaker_name} ({message.role}): {message.content}', - source_description='Podcast Transcript', - episode_type='string', - reference_time=message.actual_timestamp, - ) - for i, message in enumerate(messages[3:14]) - ] - - await client.add_episode_bulk(episodes) + setup_logging() + client = Graphiti(neo4j_uri, neo4j_user, neo4j_password) + await clear_data(client.driver) + await client.build_indices_and_constraints() + messages = parse_podcast_messages() + + if not use_bulk: + for i, message in enumerate(messages[3:14]): + await client.add_episode( + name=f'Message {i}', + episode_body=f'{message.speaker_name} ({message.role}): {message.content}', + reference_time=message.actual_timestamp, + source_description='Podcast 
Transcript', + ) + + episodes: list[BulkEpisode] = [ + BulkEpisode( + name=f'Message {i}', + content=f'{message.speaker_name} ({message.role}): {message.content}', + source_description='Podcast Transcript', + episode_type='string', + reference_time=message.actual_timestamp, + ) + for i, message in enumerate(messages[3:14]) + ] + + await client.add_episode_bulk(episodes) asyncio.run(main(True)) diff --git a/examples/podcast/transcript_parser.py b/examples/podcast/transcript_parser.py index 199dc0a..10595e5 100644 --- a/examples/podcast/transcript_parser.py +++ b/examples/podcast/transcript_parser.py @@ -7,119 +7,119 @@ class Speaker(BaseModel): - index: int - name: str - role: str + index: int + name: str + role: str class ParsedMessage(BaseModel): - speaker_index: int - speaker_name: str - role: str - relative_timestamp: str - actual_timestamp: datetime - content: str + speaker_index: int + speaker_name: str + role: str + relative_timestamp: str + actual_timestamp: datetime + content: str def parse_timestamp(timestamp: str) -> timedelta: - if 'm' in timestamp: - match = re.match(r'(\d+)m(?:\s*(\d+)s)?', timestamp) - if match: - minutes = int(match.group(1)) - seconds = int(match.group(2)) if match.group(2) else 0 - return timedelta(minutes=minutes, seconds=seconds) - elif 's' in timestamp: - match = re.match(r'(\d+)s', timestamp) - if match: - seconds = int(match.group(1)) - return timedelta(seconds=seconds) - return timedelta() # Return 0 duration if parsing fails + if 'm' in timestamp: + match = re.match(r'(\d+)m(?:\s*(\d+)s)?', timestamp) + if match: + minutes = int(match.group(1)) + seconds = int(match.group(2)) if match.group(2) else 0 + return timedelta(minutes=minutes, seconds=seconds) + elif 's' in timestamp: + match = re.match(r'(\d+)s', timestamp) + if match: + seconds = int(match.group(1)) + return timedelta(seconds=seconds) + return timedelta() # Return 0 duration if parsing fails def parse_conversation_file(file_path: str, speakers: List[Speaker]) -> list[ParsedMessage]: - with open(file_path) as file: - content = file.read() - - messages = content.split('\n\n') - speaker_dict = {speaker.index: speaker for speaker in speakers} - - parsed_messages: list[ParsedMessage] = [] - - # Find the last timestamp to determine podcast duration - last_timestamp = timedelta() - for message in reversed(messages): - lines = message.strip().split('\n') - if lines: - first_line = lines[0] - parts = first_line.split(':', 1) - if len(parts) == 2: - header = parts[0] - header_parts = header.split() - if len(header_parts) >= 2: - timestamp = header_parts[1].strip('()') - last_timestamp = parse_timestamp(timestamp) - break - - # Calculate the start time - now = datetime.now() - podcast_start_time = now - last_timestamp - - for message in messages: - lines = message.strip().split('\n') - if lines: - first_line = lines[0] - parts = first_line.split(':', 1) - if len(parts) == 2: - header, content = parts - header_parts = header.split() - if len(header_parts) >= 2: - speaker_index = int(header_parts[0]) - timestamp = header_parts[1].strip('()') - - if len(lines) > 1: - content += '\n' + '\n'.join(lines[1:]) - - delta = parse_timestamp(timestamp) - actual_time = podcast_start_time + delta - - speaker = speaker_dict.get(speaker_index) - if speaker: - speaker_name = speaker.name - role = speaker.role - else: - speaker_name = f'Unknown Speaker {speaker_index}' - role = 'Unknown' - - parsed_messages.append( - ParsedMessage( - speaker_index=speaker_index, - speaker_name=speaker_name, - role=role, - 
relative_timestamp=timestamp, - actual_timestamp=actual_time, - content=content.strip(), - ) - ) - - return parsed_messages + with open(file_path) as file: + content = file.read() + + messages = content.split('\n\n') + speaker_dict = {speaker.index: speaker for speaker in speakers} + + parsed_messages: list[ParsedMessage] = [] + + # Find the last timestamp to determine podcast duration + last_timestamp = timedelta() + for message in reversed(messages): + lines = message.strip().split('\n') + if lines: + first_line = lines[0] + parts = first_line.split(':', 1) + if len(parts) == 2: + header = parts[0] + header_parts = header.split() + if len(header_parts) >= 2: + timestamp = header_parts[1].strip('()') + last_timestamp = parse_timestamp(timestamp) + break + + # Calculate the start time + now = datetime.now() + podcast_start_time = now - last_timestamp + + for message in messages: + lines = message.strip().split('\n') + if lines: + first_line = lines[0] + parts = first_line.split(':', 1) + if len(parts) == 2: + header, content = parts + header_parts = header.split() + if len(header_parts) >= 2: + speaker_index = int(header_parts[0]) + timestamp = header_parts[1].strip('()') + + if len(lines) > 1: + content += '\n' + '\n'.join(lines[1:]) + + delta = parse_timestamp(timestamp) + actual_time = podcast_start_time + delta + + speaker = speaker_dict.get(speaker_index) + if speaker: + speaker_name = speaker.name + role = speaker.role + else: + speaker_name = f'Unknown Speaker {speaker_index}' + role = 'Unknown' + + parsed_messages.append( + ParsedMessage( + speaker_index=speaker_index, + speaker_name=speaker_name, + role=role, + relative_timestamp=timestamp, + actual_timestamp=actual_time, + content=content.strip(), + ) + ) + + return parsed_messages def parse_podcast_messages(): - file_path = 'podcast_transcript.txt' - script_dir = os.path.dirname(__file__) - relative_path = os.path.join(script_dir, file_path) - - speakers = [ - Speaker(index=0, name='Stephen DUBNER', role='Host'), - Speaker(index=1, name='Tania Tetlow', role='Guest'), - Speaker(index=4, name='Narrator', role='Narrator'), - Speaker(index=5, name='Kamala Harris', role='Quoted'), - Speaker(index=6, name='Unknown Speaker', role='Unknown'), - Speaker(index=7, name='Unknown Speaker', role='Unknown'), - Speaker(index=8, name='Unknown Speaker', role='Unknown'), - Speaker(index=10, name='Unknown Speaker', role='Unknown'), - ] - - parsed_conversation = parse_conversation_file(relative_path, speakers) - print(f'Number of messages: {len(parsed_conversation)}') - return parsed_conversation + file_path = 'podcast_transcript.txt' + script_dir = os.path.dirname(__file__) + relative_path = os.path.join(script_dir, file_path) + + speakers = [ + Speaker(index=0, name='Stephen DUBNER', role='Host'), + Speaker(index=1, name='Tania Tetlow', role='Guest'), + Speaker(index=4, name='Narrator', role='Narrator'), + Speaker(index=5, name='Kamala Harris', role='Quoted'), + Speaker(index=6, name='Unknown Speaker', role='Unknown'), + Speaker(index=7, name='Unknown Speaker', role='Unknown'), + Speaker(index=8, name='Unknown Speaker', role='Unknown'), + Speaker(index=10, name='Unknown Speaker', role='Unknown'), + ] + + parsed_conversation = parse_conversation_file(relative_path, speakers) + print(f'Number of messages: {len(parsed_conversation)}') + return parsed_conversation diff --git a/pyproject.toml b/pyproject.toml index 17196cf..142ef38 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -58,5 +58,5 @@ ignore = ["E501"] [tool.ruff.format] quote-style = 
"single" -indent-style = "tab" +indent-style = "space" docstring-code-format = true diff --git a/runner.py b/runner.py index ce150ff..a1338dc 100644 --- a/runner.py +++ b/runner.py @@ -17,67 +17,116 @@ def setup_logging(): - # Create a logger - logger = logging.getLogger() - logger.setLevel(logging.INFO) # Set the logging level to INFO - - # Create console handler and set level to INFO - console_handler = logging.StreamHandler(sys.stdout) - console_handler.setLevel(logging.INFO) - - # Create formatter - formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') - - # Add formatter to console handler - console_handler.setFormatter(formatter) - - # Add console handler to logger - logger.addHandler(console_handler) - - return logger + # Create a logger + logger = logging.getLogger() + logger.setLevel(logging.INFO) # Set the logging level to INFO + + # Create console handler and set level to INFO + console_handler = logging.StreamHandler(sys.stdout) + console_handler.setLevel(logging.INFO) + + # Create formatter + formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') + + # Add formatter to console handler + console_handler.setFormatter(formatter) + + # Add console handler to logger + logger.addHandler(console_handler) + + return logger + + +bmw_sales = [ + { + 'episode_body': 'Paul (buyer): Hi, I would like to buy a new car', + }, + { + 'episode_body': 'Dan The Salesman (salesman): Sure, I can help you with that. What kind of car are you looking for?', + }, + { + 'episode_body': 'Paul (buyer): I am looking for a new BMW', + }, + { + 'episode_body': 'Dan The Salesman (salesman): Great choice! What kind of BMW are you looking for?', + }, + { + 'episode_body': 'Paul (buyer): I am considering a BMW 3 series', + }, + { + 'episode_body': 'Dan The Salesman (salesman): Great choice, we currently have a 2024 BMW 3 series in stock, it is a great car and costs $50,000', + }, + { + 'episode_body': "Paul (buyer): Actually I am interested in something cheaper, I won't consider anything over $30,000", + }, +] + +dates_mentioned = [ + { + 'episode_body': 'Paul (user): I have graduated from Univerity of Toronto in 2022', + }, + { + 'episode_body': 'Jane (user): How cool, I graduated from the same school in 1999', + }, +] + +times_mentioned = [ + { + 'episode_body': 'Paul (user): 15 minutes ago we put a deposit on our new house', + }, +] + +time_range_mentioned = [ + { + 'episode_body': 'Paul (user): I served as a US Marine in 2015-2019', + }, +] + +relative_time_range_mentioned = [ + { + 'episode_body': 'Paul (user): I lived in Toronto for 10 years, until moving to Vancouver yesterday', + }, +] async def main(): - setup_logging() - client = Graphiti(neo4j_uri, neo4j_user, neo4j_password) - await clear_data(client.driver) - - # await client.build_indices() - await client.add_episode( - name='Message 3', - episode_body='Jane: I am married to Paul', - source_description='WhatsApp Message', - reference_time=datetime.now(), - ) - await client.add_episode( - name='Message 4', - episode_body='Paul: I have divorced Jane', - source_description='WhatsApp Message', - reference_time=datetime.now(), - ) - await client.add_episode( - name='Message 5', - episode_body='Jane: I miss Paul', - source_description='WhatsApp Message', - reference_time=datetime.now(), - ) - await client.add_episode( - name='Message 6', - episode_body='Jane: I dont miss Paul anymore, I hate him', - source_description='WhatsApp Message', - reference_time=datetime.now(), - ) - - # await 
client.add_episode( - # name="Message 3", - # episode_body="Assistant: The best type of apples available are Fuji apples", - # source_description="WhatsApp Message", - # ) - # await client.add_episode( - # name="Message 4", - # episode_body="Paul: Oh, I actually hate those", - # source_description="WhatsApp Message", - # ) + setup_logging() + client = Graphiti(neo4j_uri, neo4j_user, neo4j_password) + await clear_data(client.driver) + await client.build_indices_and_constraints() + + # await client.build_indices() + for i, message in enumerate(bmw_sales): + await client.add_episode( + name=f'Message {i}', + episode_body=message['episode_body'], + source_description='', + # reference_time=datetime.now() - timedelta(days=365 * 3), + reference_time=datetime.now(), + ) + # await client.add_episode( + # name='Message 5', + # episode_body='Jane: I miss Paul', + # source_description='WhatsApp Message', + # reference_time=datetime.now(), + # ) + # await client.add_episode( + # name='Message 6', + # episode_body='Jane: I dont miss Paul anymore, I hate him', + # source_description='WhatsApp Message', + # reference_time=datetime.now(), + # ) + + # await client.add_episode( + # name="Message 3", + # episode_body="Assistant: The best type of apples available are Fuji apples", + # source_description="WhatsApp Message", + # ) + # await client.add_episode( + # name="Message 4", + # episode_body="Paul: Oh, I actually hate those", + # source_description="WhatsApp Message", + # ) asyncio.run(main()) diff --git a/tests/tests_int_graphiti.py b/tests/tests_int_graphiti.py index 65c818b..145b547 100644 --- a/tests/tests_int_graphiti.py +++ b/tests/tests_int_graphiti.py @@ -25,107 +25,107 @@ def setup_logging(): - # Create a logger - logger = logging.getLogger() - logger.setLevel(logging.INFO) # Set the logging level to INFO + # Create a logger + logger = logging.getLogger() + logger.setLevel(logging.INFO) # Set the logging level to INFO - # Create console handler and set level to INFO - console_handler = logging.StreamHandler(sys.stdout) - console_handler.setLevel(logging.INFO) + # Create console handler and set level to INFO + console_handler = logging.StreamHandler(sys.stdout) + console_handler.setLevel(logging.INFO) - # Create formatter - formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') + # Create formatter + formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') - # Add formatter to console handler - console_handler.setFormatter(formatter) + # Add formatter to console handler + console_handler.setFormatter(formatter) - # Add console handler to logger - logger.addHandler(console_handler) + # Add console handler to logger + logger.addHandler(console_handler) - return logger + return logger def format_context(facts): - formatted_string = '' - formatted_string += 'FACTS:\n' - for fact in facts: - formatted_string += f' - {fact}\n' - formatted_string += '\n' + formatted_string = '' + formatted_string += 'FACTS:\n' + for fact in facts: + formatted_string += f' - {fact}\n' + formatted_string += '\n' - return formatted_string.strip() + return formatted_string.strip() @pytest.mark.asyncio async def test_graphiti_init(): - logger = setup_logging() - graphiti = Graphiti(NEO4J_URI, NEO4j_USER, NEO4j_PASSWORD, None) + logger = setup_logging() + graphiti = Graphiti(NEO4J_URI, NEO4j_USER, NEO4j_PASSWORD, None) - facts = await graphiti.search('Freakenomics guest') + facts = await graphiti.search('Freakenomics guest') - logger.info('\nQUERY: Freakenomics 
guest\n' + format_context(facts))
+ logger.info('\nQUERY: Freakenomics guest\n' + format_context(facts))
- facts = await graphiti.search('tania tetlow\n')
+ facts = await graphiti.search('tania tetlow\n')
- logger.info('\nQUERY: Tania Tetlow\n' + format_context(facts))
+ logger.info('\nQUERY: Tania Tetlow\n' + format_context(facts))
- facts = await graphiti.search('issues with higher ed')
+ facts = await graphiti.search('issues with higher ed')
- logger.info('\nQUERY: issues with higher ed\n' + format_context(facts))
- graphiti.close()
+ logger.info('\nQUERY: issues with higher ed\n' + format_context(facts))
+ graphiti.close()

@pytest.mark.asyncio
async def test_graph_integration():
- driver = AsyncGraphDatabase.driver(
- NEO4J_URI,
- auth=(NEO4j_USER, NEO4j_PASSWORD),
- )
- embedder = OpenAI().embeddings
-
- now = datetime.now()
- episode = EpisodicNode(
- name='test_episode',
- labels=[],
- created_at=now,
- source='message',
- source_description='conversation message',
- content='Alice likes Bob',
- entity_edges=[],
- )
-
- alice_node = EntityNode(
- name='Alice',
- labels=[],
- created_at=now,
- summary='Alice summary',
- )
-
- bob_node = EntityNode(name='Bob', labels=[], created_at=now, summary='Bob summary')
-
- episodic_edge_1 = EpisodicEdge(
- source_node_uuid=episode.uuid, target_node_uuid=alice_node.uuid, created_at=now
- )
-
- episodic_edge_2 = EpisodicEdge(
- source_node_uuid=episode.uuid, target_node_uuid=bob_node.uuid, created_at=now
- )
-
- entity_edge = EntityEdge(
- source_node_uuid=alice_node.uuid,
- target_node_uuid=bob_node.uuid,
- created_at=now,
- name='likes',
- fact='Alice likes Bob',
- episodes=[],
- expired_at=now,
- valid_at=now,
- invalid_at=now,
- )
-
- entity_edge.generate_embedding(embedder)
-
- nodes = [episode, alice_node, bob_node]
- edges = [episodic_edge_1, episodic_edge_2, entity_edge]
-
- await asyncio.gather(*[node.save(driver) for node in nodes])
- await asyncio.gather(*[edge.save(driver) for edge in edges])
+ driver = AsyncGraphDatabase.driver(
+ NEO4J_URI,
+ auth=(NEO4j_USER, NEO4j_PASSWORD),
+ )
+ embedder = OpenAI().embeddings
+
+ now = datetime.now()
+ episode = EpisodicNode(
+ name='test_episode',
+ labels=[],
+ created_at=now,
+ source='message',
+ source_description='conversation message',
+ content='Alice likes Bob',
+ entity_edges=[],
+ )
+
+ alice_node = EntityNode(
+ name='Alice',
+ labels=[],
+ created_at=now,
+ summary='Alice summary',
+ )
+
+ bob_node = EntityNode(name='Bob', labels=[], created_at=now, summary='Bob summary')
+
+ episodic_edge_1 = EpisodicEdge(
+ source_node_uuid=episode.uuid, target_node_uuid=alice_node.uuid, created_at=now
+ )
+
+ episodic_edge_2 = EpisodicEdge(
+ source_node_uuid=episode.uuid, target_node_uuid=bob_node.uuid, created_at=now
+ )
+
+ entity_edge = EntityEdge(
+ source_node_uuid=alice_node.uuid,
+ target_node_uuid=bob_node.uuid,
+ created_at=now,
+ name='likes',
+ fact='Alice likes Bob',
+ episodes=[],
+ expired_at=now,
+ valid_at=now,
+ invalid_at=now,
+ )
+
+ await entity_edge.generate_embedding(embedder)
+
+ nodes = [episode, alice_node, bob_node]
+ edges = [episodic_edge_1, episodic_edge_2, entity_edge]
+
+ await asyncio.gather(*[node.save(driver) for node in nodes])
+ await asyncio.gather(*[edge.save(driver) for edge in edges])

diff --git a/tests/utils/maintenance/test_temporal_operations.py b/tests/utils/maintenance/test_temporal_operations.py
index 7a3b94c..88e6b7c 100644
--- a/tests/utils/maintenance/test_temporal_operations.py
+++ b/tests/utils/maintenance/test_temporal_operations.py
@@ -5,294 +5,294
@@ from core.edges import EntityEdge from core.nodes import EntityNode, EpisodicNode from core.utils.maintenance.temporal_operations import ( - prepare_edges_for_invalidation, - prepare_invalidation_context, + prepare_edges_for_invalidation, + prepare_invalidation_context, ) # Helper function to create test data def create_test_data(): - now = datetime.now() - - # Create nodes - node1 = EntityNode(uuid='1', name='Node1', labels=['Person'], created_at=now) - node2 = EntityNode(uuid='2', name='Node2', labels=['Person'], created_at=now) - node3 = EntityNode(uuid='3', name='Node3', labels=['Person'], created_at=now) - - # Create edges - existing_edge1 = EntityEdge( - uuid='e1', - source_node_uuid='1', - target_node_uuid='2', - name='KNOWS', - fact='Node1 knows Node2', - created_at=now, - ) - existing_edge2 = EntityEdge( - uuid='e2', - source_node_uuid='2', - target_node_uuid='3', - name='LIKES', - fact='Node2 likes Node3', - created_at=now, - ) - new_edge1 = EntityEdge( - uuid='e3', - source_node_uuid='1', - target_node_uuid='3', - name='WORKS_WITH', - fact='Node1 works with Node3', - created_at=now, - ) - new_edge2 = EntityEdge( - uuid='e4', - source_node_uuid='1', - target_node_uuid='2', - name='DISLIKES', - fact='Node1 dislikes Node2', - created_at=now, - ) - - return { - 'nodes': [node1, node2, node3], - 'existing_edges': [existing_edge1, existing_edge2], - 'new_edges': [new_edge1, new_edge2], - } + now = datetime.now() + + # Create nodes + node1 = EntityNode(uuid='1', name='Node1', labels=['Person'], created_at=now) + node2 = EntityNode(uuid='2', name='Node2', labels=['Person'], created_at=now) + node3 = EntityNode(uuid='3', name='Node3', labels=['Person'], created_at=now) + + # Create edges + existing_edge1 = EntityEdge( + uuid='e1', + source_node_uuid='1', + target_node_uuid='2', + name='KNOWS', + fact='Node1 knows Node2', + created_at=now, + ) + existing_edge2 = EntityEdge( + uuid='e2', + source_node_uuid='2', + target_node_uuid='3', + name='LIKES', + fact='Node2 likes Node3', + created_at=now, + ) + new_edge1 = EntityEdge( + uuid='e3', + source_node_uuid='1', + target_node_uuid='3', + name='WORKS_WITH', + fact='Node1 works with Node3', + created_at=now, + ) + new_edge2 = EntityEdge( + uuid='e4', + source_node_uuid='1', + target_node_uuid='2', + name='DISLIKES', + fact='Node1 dislikes Node2', + created_at=now, + ) + + return { + 'nodes': [node1, node2, node3], + 'existing_edges': [existing_edge1, existing_edge2], + 'new_edges': [new_edge1, new_edge2], + } def test_prepare_edges_for_invalidation_basic(): - test_data = create_test_data() + test_data = create_test_data() - existing_edges_pending_invalidation, new_edges_with_nodes = prepare_edges_for_invalidation( - test_data['existing_edges'], test_data['new_edges'], test_data['nodes'] - ) + existing_edges_pending_invalidation, new_edges_with_nodes = prepare_edges_for_invalidation( + test_data['existing_edges'], test_data['new_edges'], test_data['nodes'] + ) - assert len(existing_edges_pending_invalidation) == 2 - assert len(new_edges_with_nodes) == 2 + assert len(existing_edges_pending_invalidation) == 2 + assert len(new_edges_with_nodes) == 2 - # Check if the edges are correctly associated with nodes - for edge_with_nodes in existing_edges_pending_invalidation + new_edges_with_nodes: - assert isinstance(edge_with_nodes[0], EntityNode) - assert isinstance(edge_with_nodes[1], EntityEdge) - assert isinstance(edge_with_nodes[2], EntityNode) + # Check if the edges are correctly associated with nodes + for edge_with_nodes in 
existing_edges_pending_invalidation + new_edges_with_nodes: + assert isinstance(edge_with_nodes[0], EntityNode) + assert isinstance(edge_with_nodes[1], EntityEdge) + assert isinstance(edge_with_nodes[2], EntityNode) def test_prepare_edges_for_invalidation_no_existing_edges(): - test_data = create_test_data() + test_data = create_test_data() - existing_edges_pending_invalidation, new_edges_with_nodes = prepare_edges_for_invalidation( - [], test_data['new_edges'], test_data['nodes'] - ) + existing_edges_pending_invalidation, new_edges_with_nodes = prepare_edges_for_invalidation( + [], test_data['new_edges'], test_data['nodes'] + ) - assert len(existing_edges_pending_invalidation) == 0 - assert len(new_edges_with_nodes) == 2 + assert len(existing_edges_pending_invalidation) == 0 + assert len(new_edges_with_nodes) == 2 def test_prepare_edges_for_invalidation_no_new_edges(): - test_data = create_test_data() + test_data = create_test_data() - existing_edges_pending_invalidation, new_edges_with_nodes = prepare_edges_for_invalidation( - test_data['existing_edges'], [], test_data['nodes'] - ) + existing_edges_pending_invalidation, new_edges_with_nodes = prepare_edges_for_invalidation( + test_data['existing_edges'], [], test_data['nodes'] + ) - assert len(existing_edges_pending_invalidation) == 2 - assert len(new_edges_with_nodes) == 0 + assert len(existing_edges_pending_invalidation) == 2 + assert len(new_edges_with_nodes) == 0 def test_prepare_edges_for_invalidation_missing_nodes(): - test_data = create_test_data() + test_data = create_test_data() - # Remove one node to simulate a missing node scenario - nodes = test_data['nodes'][:-1] + # Remove one node to simulate a missing node scenario + nodes = test_data['nodes'][:-1] - existing_edges_pending_invalidation, new_edges_with_nodes = prepare_edges_for_invalidation( - test_data['existing_edges'], test_data['new_edges'], nodes - ) + existing_edges_pending_invalidation, new_edges_with_nodes = prepare_edges_for_invalidation( + test_data['existing_edges'], test_data['new_edges'], nodes + ) - assert len(existing_edges_pending_invalidation) == 1 - assert len(new_edges_with_nodes) == 1 + assert len(existing_edges_pending_invalidation) == 1 + assert len(new_edges_with_nodes) == 1 def test_prepare_invalidation_context(): - now = datetime.now() - - # Create nodes - node1 = EntityNode(uuid='1', name='Node1', labels=['Person'], created_at=now) - node2 = EntityNode(uuid='2', name='Node2', labels=['Person'], created_at=now) - node3 = EntityNode(uuid='3', name='Node3', labels=['Person'], created_at=now) - - # Create edges - edge1 = EntityEdge( - uuid='e1', - source_node_uuid='1', - target_node_uuid='2', - name='KNOWS', - fact='Node1 knows Node2', - created_at=now, - ) - edge2 = EntityEdge( - uuid='e2', - source_node_uuid='2', - target_node_uuid='3', - name='LIKES', - fact='Node2 likes Node3', - created_at=now, - ) - - # Create NodeEdgeNodeTriplet objects - existing_edge = (node1, edge1, node2) - new_edge = (node2, edge2, node3) - - # Prepare test input - existing_edges = [existing_edge] - new_edges = [new_edge] - - # Create a current episode and previous episodes - current_episode = EpisodicNode( - name='Current Episode', - content='This is the current episode content.', - created_at=now, - valid_at=now, - source='test', - source_description='Test episode for unit testing', - ) - previous_episodes = [ - EpisodicNode( - name='Previous Episode 1', - content='This is the content of previous episode 1.', - created_at=now - timedelta(days=1), - valid_at=now - 
timedelta(days=1), - source='test', - source_description='Test previous episode 1 for unit testing', - ), - EpisodicNode( - name='Previous Episode 2', - content='This is the content of previous episode 2.', - created_at=now - timedelta(days=2), - valid_at=now - timedelta(days=2), - source='test', - source_description='Test previous episode 2 for unit testing', - ), - ] - - # Call the function - result = prepare_invalidation_context( - existing_edges, new_edges, current_episode, previous_episodes - ) - - # Assert the result - assert isinstance(result, dict) - assert 'existing_edges' in result - assert 'new_edges' in result - assert 'current_episode' in result - assert 'previous_episodes' in result - assert len(result['existing_edges']) == 1 - assert len(result['new_edges']) == 1 - assert result['current_episode'] == current_episode.content - assert len(result['previous_episodes']) == 2 - - # Check the format of the existing edge - existing_edge_str = result['existing_edges'][0] - assert edge1.uuid in existing_edge_str - assert node1.name in existing_edge_str - assert edge1.name in existing_edge_str - assert node2.name in existing_edge_str - assert edge1.created_at.isoformat() in existing_edge_str - - # Check the format of the new edge - new_edge_str = result['new_edges'][0] - assert edge2.uuid in new_edge_str - assert node2.name in new_edge_str - assert edge2.name in new_edge_str - assert node3.name in new_edge_str - assert edge2.created_at.isoformat() in new_edge_str + now = datetime.now() + + # Create nodes + node1 = EntityNode(uuid='1', name='Node1', labels=['Person'], created_at=now) + node2 = EntityNode(uuid='2', name='Node2', labels=['Person'], created_at=now) + node3 = EntityNode(uuid='3', name='Node3', labels=['Person'], created_at=now) + + # Create edges + edge1 = EntityEdge( + uuid='e1', + source_node_uuid='1', + target_node_uuid='2', + name='KNOWS', + fact='Node1 knows Node2', + created_at=now, + ) + edge2 = EntityEdge( + uuid='e2', + source_node_uuid='2', + target_node_uuid='3', + name='LIKES', + fact='Node2 likes Node3', + created_at=now, + ) + + # Create NodeEdgeNodeTriplet objects + existing_edge = (node1, edge1, node2) + new_edge = (node2, edge2, node3) + + # Prepare test input + existing_edges = [existing_edge] + new_edges = [new_edge] + + # Create a current episode and previous episodes + current_episode = EpisodicNode( + name='Current Episode', + content='This is the current episode content.', + created_at=now, + valid_at=now, + source='test', + source_description='Test episode for unit testing', + ) + previous_episodes = [ + EpisodicNode( + name='Previous Episode 1', + content='This is the content of previous episode 1.', + created_at=now - timedelta(days=1), + valid_at=now - timedelta(days=1), + source='test', + source_description='Test previous episode 1 for unit testing', + ), + EpisodicNode( + name='Previous Episode 2', + content='This is the content of previous episode 2.', + created_at=now - timedelta(days=2), + valid_at=now - timedelta(days=2), + source='test', + source_description='Test previous episode 2 for unit testing', + ), + ] + + # Call the function + result = prepare_invalidation_context( + existing_edges, new_edges, current_episode, previous_episodes + ) + + # Assert the result + assert isinstance(result, dict) + assert 'existing_edges' in result + assert 'new_edges' in result + assert 'current_episode' in result + assert 'previous_episodes' in result + assert len(result['existing_edges']) == 1 + assert len(result['new_edges']) == 1 + assert 
result['current_episode'] == current_episode.content + assert len(result['previous_episodes']) == 2 + + # Check the format of the existing edge + existing_edge_str = result['existing_edges'][0] + assert edge1.uuid in existing_edge_str + assert node1.name in existing_edge_str + assert edge1.name in existing_edge_str + assert node2.name in existing_edge_str + assert edge1.created_at.isoformat() in existing_edge_str + + # Check the format of the new edge + new_edge_str = result['new_edges'][0] + assert edge2.uuid in new_edge_str + assert node2.name in new_edge_str + assert edge2.name in new_edge_str + assert node3.name in new_edge_str + assert edge2.created_at.isoformat() in new_edge_str def test_prepare_invalidation_context_empty_input(): - now = datetime.now() - current_episode = EpisodicNode( - name='Current Episode', - content='Empty episode', - created_at=now, - valid_at=now, - source='test', - source_description='Test empty episode for unit testing', - ) - result = prepare_invalidation_context([], [], current_episode, []) - assert isinstance(result, dict) - assert 'existing_edges' in result - assert 'new_edges' in result - assert 'current_episode' in result - assert 'previous_episodes' in result - assert len(result['existing_edges']) == 0 - assert len(result['new_edges']) == 0 - assert result['current_episode'] == current_episode.content - assert len(result['previous_episodes']) == 0 + now = datetime.now() + current_episode = EpisodicNode( + name='Current Episode', + content='Empty episode', + created_at=now, + valid_at=now, + source='test', + source_description='Test empty episode for unit testing', + ) + result = prepare_invalidation_context([], [], current_episode, []) + assert isinstance(result, dict) + assert 'existing_edges' in result + assert 'new_edges' in result + assert 'current_episode' in result + assert 'previous_episodes' in result + assert len(result['existing_edges']) == 0 + assert len(result['new_edges']) == 0 + assert result['current_episode'] == current_episode.content + assert len(result['previous_episodes']) == 0 def test_prepare_invalidation_context_sorting(): - now = datetime.now() - - # Create nodes - node1 = EntityNode(uuid='1', name='Node1', labels=['Person'], created_at=now) - node2 = EntityNode(uuid='2', name='Node2', labels=['Person'], created_at=now) - - # Create edges with different timestamps - edge1 = EntityEdge( - uuid='e1', - source_node_uuid='1', - target_node_uuid='2', - name='KNOWS', - fact='Node1 knows Node2', - created_at=now, - ) - edge2 = EntityEdge( - uuid='e2', - source_node_uuid='2', - target_node_uuid='1', - name='LIKES', - fact='Node2 likes Node1', - created_at=now + timedelta(hours=1), - ) - - edge_with_nodes1 = (node1, edge1, node2) - edge_with_nodes2 = (node2, edge2, node1) - - # Prepare test input - existing_edges = [edge_with_nodes1, edge_with_nodes2] - - # Create a current episode and previous episodes - current_episode = EpisodicNode( - name='Current Episode', - content='This is the current episode content.', - created_at=now, - valid_at=now, - source='test', - source_description='Test episode for unit testing', - ) - previous_episodes = [ - EpisodicNode( - name='Previous Episode', - content='This is the content of a previous episode.', - created_at=now - timedelta(days=1), - valid_at=now - timedelta(days=1), - source='test', - source_description='Test previous episode for unit testing', - ), - ] - - # Call the function - result = prepare_invalidation_context(existing_edges, [], current_episode, previous_episodes) - - # Assert the 
result - assert len(result['existing_edges']) == 2 - assert edge2.uuid in result['existing_edges'][0] # The newer edge should be first - assert edge1.uuid in result['existing_edges'][1] # The older edge should be second - assert result['current_episode'] == current_episode.content - assert len(result['previous_episodes']) == 1 - assert result['previous_episodes'][0] == previous_episodes[0].content + now = datetime.now() + + # Create nodes + node1 = EntityNode(uuid='1', name='Node1', labels=['Person'], created_at=now) + node2 = EntityNode(uuid='2', name='Node2', labels=['Person'], created_at=now) + + # Create edges with different timestamps + edge1 = EntityEdge( + uuid='e1', + source_node_uuid='1', + target_node_uuid='2', + name='KNOWS', + fact='Node1 knows Node2', + created_at=now, + ) + edge2 = EntityEdge( + uuid='e2', + source_node_uuid='2', + target_node_uuid='1', + name='LIKES', + fact='Node2 likes Node1', + created_at=now + timedelta(hours=1), + ) + + edge_with_nodes1 = (node1, edge1, node2) + edge_with_nodes2 = (node2, edge2, node1) + + # Prepare test input + existing_edges = [edge_with_nodes1, edge_with_nodes2] + + # Create a current episode and previous episodes + current_episode = EpisodicNode( + name='Current Episode', + content='This is the current episode content.', + created_at=now, + valid_at=now, + source='test', + source_description='Test episode for unit testing', + ) + previous_episodes = [ + EpisodicNode( + name='Previous Episode', + content='This is the content of a previous episode.', + created_at=now - timedelta(days=1), + valid_at=now - timedelta(days=1), + source='test', + source_description='Test previous episode for unit testing', + ), + ] + + # Call the function + result = prepare_invalidation_context(existing_edges, [], current_episode, previous_episodes) + + # Assert the result + assert len(result['existing_edges']) == 2 + assert edge2.uuid in result['existing_edges'][0] # The newer edge should be first + assert edge1.uuid in result['existing_edges'][1] # The older edge should be second + assert result['current_episode'] == current_episode.content + assert len(result['previous_episodes']) == 1 + assert result['previous_episodes'][0] == previous_episodes[0].content # Run the tests if __name__ == '__main__': - pytest.main([__file__]) + pytest.main([__file__]) diff --git a/tests/utils/maintenance/test_temporal_operations_int.py b/tests/utils/maintenance/test_temporal_operations_int.py index 37baf0c..419e1bd 100644 --- a/tests/utils/maintenance/test_temporal_operations_int.py +++ b/tests/utils/maintenance/test_temporal_operations_int.py @@ -8,311 +8,311 @@ from core.llm_client import LLMConfig, OpenAIClient from core.nodes import EntityNode, EpisodicNode from core.utils.maintenance.temporal_operations import ( - invalidate_edges, + invalidate_edges, ) load_dotenv() def setup_llm_client(): - return OpenAIClient( - LLMConfig( - api_key=os.getenv('TEST_OPENAI_API_KEY'), - model=os.getenv('TEST_OPENAI_MODEL'), - base_url='https://api.openai.com/v1', - ) - ) + return OpenAIClient( + LLMConfig( + api_key=os.getenv('TEST_OPENAI_API_KEY'), + model=os.getenv('TEST_OPENAI_MODEL'), + base_url='https://api.openai.com/v1', + ) + ) def create_test_data(): - now = datetime.now() - - # Create nodes - node1 = EntityNode(uuid='1', name='Alice', labels=['Person'], created_at=now) - node2 = EntityNode(uuid='2', name='Bob', labels=['Person'], created_at=now) - - # Create edges - edge1 = EntityEdge( - uuid='e1', - source_node_uuid='1', - target_node_uuid='2', - name='LIKES', - 
fact='Alice likes Bob', - created_at=now - timedelta(days=1), - ) - edge2 = EntityEdge( - uuid='e2', - source_node_uuid='1', - target_node_uuid='2', - name='DISLIKES', - fact='Alice dislikes Bob', - created_at=now, - ) - - existing_edge = (node1, edge1, node2) - new_edge = (node1, edge2, node2) - - # Create current episode - current_episode = EpisodicNode( - name='Current Episode', - content='Alice now dislikes Bob', - created_at=now, - valid_at=now, - source='test', - source_description='Test episode for unit testing', - ) - - # Create previous episodes - previous_episodes = [ - EpisodicNode( - name='Previous Episode', - content='Alice liked Bob', - created_at=now - timedelta(days=1), - valid_at=now - timedelta(days=1), - source='test', - source_description='Test previous episode for unit testing', - ) - ] - - return existing_edge, new_edge, current_episode, previous_episodes + now = datetime.now() + + # Create nodes + node1 = EntityNode(uuid='1', name='Alice', labels=['Person'], created_at=now) + node2 = EntityNode(uuid='2', name='Bob', labels=['Person'], created_at=now) + + # Create edges + edge1 = EntityEdge( + uuid='e1', + source_node_uuid='1', + target_node_uuid='2', + name='LIKES', + fact='Alice likes Bob', + created_at=now - timedelta(days=1), + ) + edge2 = EntityEdge( + uuid='e2', + source_node_uuid='1', + target_node_uuid='2', + name='DISLIKES', + fact='Alice dislikes Bob', + created_at=now, + ) + + existing_edge = (node1, edge1, node2) + new_edge = (node1, edge2, node2) + + # Create current episode + current_episode = EpisodicNode( + name='Current Episode', + content='Alice now dislikes Bob', + created_at=now, + valid_at=now, + source='test', + source_description='Test episode for unit testing', + ) + + # Create previous episodes + previous_episodes = [ + EpisodicNode( + name='Previous Episode', + content='Alice liked Bob', + created_at=now - timedelta(days=1), + valid_at=now - timedelta(days=1), + source='test', + source_description='Test previous episode for unit testing', + ) + ] + + return existing_edge, new_edge, current_episode, previous_episodes @pytest.mark.asyncio @pytest.mark.integration async def test_invalidate_edges(): - existing_edge, new_edge, current_episode, previous_episodes = create_test_data() + existing_edge, new_edge, current_episode, previous_episodes = create_test_data() - invalidated_edges = await invalidate_edges( - setup_llm_client(), [existing_edge], [new_edge], current_episode, previous_episodes - ) + invalidated_edges = await invalidate_edges( + setup_llm_client(), [existing_edge], [new_edge], current_episode, previous_episodes + ) - assert len(invalidated_edges) == 1 - assert invalidated_edges[0].uuid == existing_edge[1].uuid - assert invalidated_edges[0].expired_at is not None + assert len(invalidated_edges) == 1 + assert invalidated_edges[0].uuid == existing_edge[1].uuid + assert invalidated_edges[0].expired_at is not None @pytest.mark.asyncio @pytest.mark.integration async def test_invalidate_edges_no_invalidation(): - existing_edge, _, current_episode, previous_episodes = create_test_data() + existing_edge, _, current_episode, previous_episodes = create_test_data() - invalidated_edges = await invalidate_edges( - setup_llm_client(), [existing_edge], [], current_episode, previous_episodes - ) + invalidated_edges = await invalidate_edges( + setup_llm_client(), [existing_edge], [], current_episode, previous_episodes + ) - assert len(invalidated_edges) == 0 + assert len(invalidated_edges) == 0 @pytest.mark.asyncio @pytest.mark.integration async def 
test_invalidate_edges_multiple_existing():
- existing_edge1, new_edge = create_test_data()
- existing_edge2, _ = create_test_data()
- existing_edge2[1].uuid = 'e3'
- existing_edge2[1].name = 'KNOWS'
- existing_edge2[1].fact = 'Alice knows Bob'
+ existing_edge1, new_edge, current_episode, previous_episodes = create_test_data()
+ existing_edge2, _, _, _ = create_test_data()
+ existing_edge2[1].uuid = 'e3'
+ existing_edge2[1].name = 'KNOWS'
+ existing_edge2[1].fact = 'Alice knows Bob'
- invalidated_edges = await invalidate_edges(
- setup_llm_client(), [existing_edge1, existing_edge2], [new_edge]
- )
+ invalidated_edges = await invalidate_edges(
+ setup_llm_client(), [existing_edge1, existing_edge2], [new_edge], current_episode, previous_episodes
+ )
- assert len(invalidated_edges) == 1
- assert invalidated_edges[0].uuid == existing_edge1[1].uuid
- assert invalidated_edges[0].expired_at is not None
+ assert len(invalidated_edges) == 1
+ assert invalidated_edges[0].uuid == existing_edge1[1].uuid
+ assert invalidated_edges[0].expired_at is not None

# Helper function to create more complex test data
def create_complex_test_data():
- now = datetime.now()
-
- # Create nodes
- node1 = EntityNode(uuid='1', name='Alice', labels=['Person'], created_at=now)
- node2 = EntityNode(uuid='2', name='Bob', labels=['Person'], created_at=now)
- node3 = EntityNode(uuid='3', name='Charlie', labels=['Person'], created_at=now)
- node4 = EntityNode(uuid='4', name='Company XYZ', labels=['Organization'], created_at=now)
-
- # Create edges
- edge1 = EntityEdge(
- uuid='e1',
- source_node_uuid='1',
- target_node_uuid='2',
- name='LIKES',
- fact='Alice likes Bob',
- created_at=now - timedelta(days=5),
- )
- edge2 = EntityEdge(
- uuid='e2',
- source_node_uuid='1',
- target_node_uuid='3',
- name='FRIENDS_WITH',
- fact='Alice is friends with Charlie',
- created_at=now - timedelta(days=3),
- )
- edge3 = EntityEdge(
- uuid='e3',
- source_node_uuid='2',
- target_node_uuid='4',
- name='WORKS_FOR',
- fact='Bob works for Company XYZ',
- created_at=now - timedelta(days=2),
- )
-
- existing_edge1 = (node1, edge1, node2)
- existing_edge2 = (node1, edge2, node3)
- existing_edge3 = (node2, edge3, node4)
-
- return [existing_edge1, existing_edge2, existing_edge3], [
- node1,
- node2,
- node3,
- node4,
- ]
+ now = datetime.now()
+
+ # Create nodes
+ node1 = EntityNode(uuid='1', name='Alice', labels=['Person'], created_at=now)
+ node2 = EntityNode(uuid='2', name='Bob', labels=['Person'], created_at=now)
+ node3 = EntityNode(uuid='3', name='Charlie', labels=['Person'], created_at=now)
+ node4 = EntityNode(uuid='4', name='Company XYZ', labels=['Organization'], created_at=now)
+
+ # Create edges
+ edge1 = EntityEdge(
+ uuid='e1',
+ source_node_uuid='1',
+ target_node_uuid='2',
+ name='LIKES',
+ fact='Alice likes Bob',
+ created_at=now - timedelta(days=5),
+ )
+ edge2 = EntityEdge(
+ uuid='e2',
+ source_node_uuid='1',
+ target_node_uuid='3',
+ name='FRIENDS_WITH',
+ fact='Alice is friends with Charlie',
+ created_at=now - timedelta(days=3),
+ )
+ edge3 = EntityEdge(
+ uuid='e3',
+ source_node_uuid='2',
+ target_node_uuid='4',
+ name='WORKS_FOR',
+ fact='Bob works for Company XYZ',
+ created_at=now - timedelta(days=2),
+ )
+
+ existing_edge1 = (node1, edge1, node2)
+ existing_edge2 = (node1, edge2, node3)
+ existing_edge3 = (node2, edge3, node4)
+
+ return [existing_edge1, existing_edge2, existing_edge3], [
+ node1,
+ node2,
+ node3,
+ node4,
+ ]

@pytest.mark.asyncio
@pytest.mark.integration
async def test_invalidate_edges_complex():
- existing_edges, nodes = create_complex_test_data()
+ existing_edges, nodes =
create_complex_test_data() - # Create a new edge that contradicts an existing one - new_edge = ( - nodes[0], - EntityEdge( - uuid='e4', - source_node_uuid='1', - target_node_uuid='2', - name='DISLIKES', - fact='Alice dislikes Bob', - created_at=datetime.now(), - ), - nodes[1], - ) + # Create a new edge that contradicts an existing one + new_edge = ( + nodes[0], + EntityEdge( + uuid='e4', + source_node_uuid='1', + target_node_uuid='2', + name='DISLIKES', + fact='Alice dislikes Bob', + created_at=datetime.now(), + ), + nodes[1], + ) - invalidated_edges = await invalidate_edges(setup_llm_client(), existing_edges, [new_edge]) + invalidated_edges = await invalidate_edges(setup_llm_client(), existing_edges, [new_edge]) - assert len(invalidated_edges) == 1 - assert invalidated_edges[0].uuid == 'e1' - assert invalidated_edges[0].expired_at is not None + assert len(invalidated_edges) == 1 + assert invalidated_edges[0].uuid == 'e1' + assert invalidated_edges[0].expired_at is not None @pytest.mark.asyncio @pytest.mark.integration async def test_invalidate_edges_temporal_update(): - existing_edges, nodes = create_complex_test_data() + existing_edges, nodes = create_complex_test_data() - # Create a new edge that updates an existing one with new information - new_edge = ( - nodes[1], - EntityEdge( - uuid='e5', - source_node_uuid='2', - target_node_uuid='4', - name='LEFT_JOB', - fact='Bob left his job at Company XYZ', - created_at=datetime.now(), - ), - nodes[3], - ) + # Create a new edge that updates an existing one with new information + new_edge = ( + nodes[1], + EntityEdge( + uuid='e5', + source_node_uuid='2', + target_node_uuid='4', + name='LEFT_JOB', + fact='Bob left his job at Company XYZ', + created_at=datetime.now(), + ), + nodes[3], + ) - invalidated_edges = await invalidate_edges(setup_llm_client(), existing_edges, [new_edge]) + invalidated_edges = await invalidate_edges(setup_llm_client(), existing_edges, [new_edge]) - assert len(invalidated_edges) == 1 - assert invalidated_edges[0].uuid == 'e3' - assert invalidated_edges[0].expired_at is not None + assert len(invalidated_edges) == 1 + assert invalidated_edges[0].uuid == 'e3' + assert invalidated_edges[0].expired_at is not None @pytest.mark.asyncio @pytest.mark.integration async def test_invalidate_edges_multiple_invalidations(): - existing_edges, nodes = create_complex_test_data() - - # Create new edges that invalidate multiple existing edges - new_edge1 = ( - nodes[0], - EntityEdge( - uuid='e6', - source_node_uuid='1', - target_node_uuid='2', - name='ENEMIES_WITH', - fact='Alice and Bob are now enemies', - created_at=datetime.now(), - ), - nodes[1], - ) - new_edge2 = ( - nodes[0], - EntityEdge( - uuid='e7', - source_node_uuid='1', - target_node_uuid='3', - name='ENDED_FRIENDSHIP', - fact='Alice ended her friendship with Charlie', - created_at=datetime.now(), - ), - nodes[2], - ) - - invalidated_edges = await invalidate_edges( - setup_llm_client(), existing_edges, [new_edge1, new_edge2] - ) - - assert len(invalidated_edges) == 2 - assert set(edge.uuid for edge in invalidated_edges) == {'e1', 'e2'} - for edge in invalidated_edges: - assert edge.expired_at is not None + existing_edges, nodes = create_complex_test_data() + + # Create new edges that invalidate multiple existing edges + new_edge1 = ( + nodes[0], + EntityEdge( + uuid='e6', + source_node_uuid='1', + target_node_uuid='2', + name='ENEMIES_WITH', + fact='Alice and Bob are now enemies', + created_at=datetime.now(), + ), + nodes[1], + ) + new_edge2 = ( + nodes[0], + EntityEdge( + 
uuid='e7', + source_node_uuid='1', + target_node_uuid='3', + name='ENDED_FRIENDSHIP', + fact='Alice ended her friendship with Charlie', + created_at=datetime.now(), + ), + nodes[2], + ) + + invalidated_edges = await invalidate_edges( + setup_llm_client(), existing_edges, [new_edge1, new_edge2] + ) + + assert len(invalidated_edges) == 2 + assert set(edge.uuid for edge in invalidated_edges) == {'e1', 'e2'} + for edge in invalidated_edges: + assert edge.expired_at is not None @pytest.mark.asyncio @pytest.mark.integration async def test_invalidate_edges_no_effect(): - existing_edges, nodes = create_complex_test_data() + existing_edges, nodes = create_complex_test_data() - # Create a new edge that doesn't invalidate any existing edges - new_edge = ( - nodes[2], - EntityEdge( - uuid='e8', - source_node_uuid='3', - target_node_uuid='4', - name='APPLIED_TO', - fact='Charlie applied to Company XYZ', - created_at=datetime.now(), - ), - nodes[3], - ) + # Create a new edge that doesn't invalidate any existing edges + new_edge = ( + nodes[2], + EntityEdge( + uuid='e8', + source_node_uuid='3', + target_node_uuid='4', + name='APPLIED_TO', + fact='Charlie applied to Company XYZ', + created_at=datetime.now(), + ), + nodes[3], + ) - invalidated_edges = await invalidate_edges(setup_llm_client(), existing_edges, [new_edge]) + invalidated_edges = await invalidate_edges(setup_llm_client(), existing_edges, [new_edge]) - assert len(invalidated_edges) == 0 + assert len(invalidated_edges) == 0 @pytest.mark.asyncio @pytest.mark.integration async def test_invalidate_edges_partial_update(): - existing_edges, nodes = create_complex_test_data() + existing_edges, nodes = create_complex_test_data() - # Create a new edge that partially updates an existing one - new_edge = ( - nodes[1], - EntityEdge( - uuid='e9', - source_node_uuid='2', - target_node_uuid='4', - name='CHANGED_POSITION', - fact='Bob changed his position at Company XYZ', - created_at=datetime.now(), - ), - nodes[3], - ) + # Create a new edge that partially updates an existing one + new_edge = ( + nodes[1], + EntityEdge( + uuid='e9', + source_node_uuid='2', + target_node_uuid='4', + name='CHANGED_POSITION', + fact='Bob changed his position at Company XYZ', + created_at=datetime.now(), + ), + nodes[3], + ) - invalidated_edges = await invalidate_edges(setup_llm_client(), existing_edges, [new_edge]) + invalidated_edges = await invalidate_edges(setup_llm_client(), existing_edges, [new_edge]) - assert len(invalidated_edges) == 0 # The existing edge is not invalidated, just updated + assert len(invalidated_edges) == 0 # The existing edge is not invalidated, just updated @pytest.mark.asyncio @pytest.mark.integration async def test_invalidate_edges_empty_inputs(): - invalidated_edges = await invalidate_edges(setup_llm_client(), [], []) + invalidated_edges = await invalidate_edges(setup_llm_client(), [], []) - assert len(invalidated_edges) == 0 + assert len(invalidated_edges) == 0
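
The tests above all hand invalidate_edges (source_node, edge, target_node) triplets. For reference, a toy sketch of that call shape outside pytest, using the five-argument form from test_invalidate_edges and the setup_llm_client helper defined in this test file (it needs the TEST_OPENAI_API_KEY and TEST_OPENAI_MODEL environment variables); all data values here are made up:

import asyncio
from datetime import datetime, timedelta

from core.edges import EntityEdge
from core.nodes import EntityNode, EpisodicNode
from core.utils.maintenance.temporal_operations import invalidate_edges


async def invalidation_demo():
    now = datetime.now()
    alice = EntityNode(uuid='1', name='Alice', labels=['Person'], created_at=now)
    bob = EntityNode(uuid='2', name='Bob', labels=['Person'], created_at=now)

    old_edge = EntityEdge(
        uuid='e1',
        source_node_uuid='1',
        target_node_uuid='2',
        name='LIKES',
        fact='Alice likes Bob',
        created_at=now - timedelta(days=1),
    )
    new_edge = EntityEdge(
        uuid='e2',
        source_node_uuid='1',
        target_node_uuid='2',
        name='DISLIKES',
        fact='Alice dislikes Bob',
        created_at=now,
    )

    episode = EpisodicNode(
        name='Current Episode',
        content='Alice now dislikes Bob',
        created_at=now,
        valid_at=now,
        source='test',
        source_description='demo episode',
    )

    # Triplets are (source_node, edge, target_node); the LLM weighs the old
    # facts against the new ones in the context of the episodes and returns
    # the edges it decided to expire, with expired_at populated.
    invalidated = await invalidate_edges(
        setup_llm_client(),
        [(alice, old_edge, bob)],
        [(alice, new_edge, bob)],
        episode,
        [],
    )
    print([edge.uuid for edge in invalidated])


asyncio.run(invalidation_demo())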
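
Relatedly, EntityEdge.generate_embedding is a coroutine (see core/edges.py), so it has to be awaited before save() or fact_embedding is never populated; the bare call in tests_int_graphiti.py only created an unawaited coroutine. A minimal sketch of the corrected sequence, assuming the same NEO4J_* constants as the integration tests and an OPENAI_API_KEY in the environment:

import asyncio
from datetime import datetime

from neo4j import AsyncGraphDatabase
from openai import OpenAI

from core.edges import EntityEdge


async def embed_and_save():
    driver = AsyncGraphDatabase.driver(NEO4J_URI, auth=(NEO4j_USER, NEO4j_PASSWORD))
    embedder = OpenAI().embeddings

    edge = EntityEdge(
        source_node_uuid='1',
        target_node_uuid='2',
        name='likes',
        fact='Alice likes Bob',
        created_at=datetime.now(),
    )

    # Awaiting populates edge.fact_embedding before the edge is persisted.
    await edge.generate_embedding(embedder)
    await edge.save(driver)
    await driver.close()


asyncio.run(embed_and_save())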
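
Finally, a minimal sketch of replaying the new temporal fixtures from runner.py (dates_mentioned, times_mentioned, time_range_mentioned) through the same add_episode loop used for bmw_sales. The neo4j_* settings and fixture lists are the ones defined in runner.py; the clear_data import path is an assumption, since its import is not part of this diff:

import asyncio
from datetime import datetime

from core.graphiti import Graphiti
# Assumed location of clear_data; runner.py calls it but its import is not shown here.
from core.utils.maintenance.graph_data_operations import clear_data


async def replay_temporal_fixtures():
    client = Graphiti(neo4j_uri, neo4j_user, neo4j_password)
    await clear_data(client.driver)
    await client.build_indices_and_constraints()

    # Each fixture exercises a different temporal pattern (absolute dates,
    # relative times, ranges); reference_time anchors the relative ones.
    for fixture in (dates_mentioned, times_mentioned, time_range_mentioned):
        for i, message in enumerate(fixture):
            await client.add_episode(
                name=f'Message {i}',
                episode_body=message['episode_body'],
                source_description='',
                reference_time=datetime.now(),
            )


asyncio.run(replay_temporal_fixtures())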