diff --git a/examples/podcast/podcast_runner.py b/examples/podcast/podcast_runner.py index ad49a71c..22e1d90b 100644 --- a/examples/podcast/podcast_runner.py +++ b/examples/podcast/podcast_runner.py @@ -83,6 +83,7 @@ async def main(use_bulk: bool = True): reference_time=message.actual_timestamp, source_description='Podcast Transcript', group_id='1', + update_communities=True, ) return diff --git a/graphiti_core/edges.py b/graphiti_core/edges.py index 8459842b..a8d6f8d9 100644 --- a/graphiti_core/edges.py +++ b/graphiti_core/edges.py @@ -109,13 +109,36 @@ async def get_by_uuid(cls, driver: AsyncDriver, uuid: str): raise EdgeNotFoundError(uuid) return edges[0] + @classmethod + async def get_by_uuids(cls, driver: AsyncDriver, uuids: list[str]): + records, _, _ = await driver.execute_query( + """ + MATCH (n:Episodic)-[e:MENTIONS]->(m:Entity) + WHERE e.uuid IN $uuids + RETURN + e.uuid As uuid, + e.group_id AS group_id, + n.uuid AS source_node_uuid, + m.uuid AS target_node_uuid, + e.created_at AS created_at + """, + uuids=uuids, + ) + + edges = [get_episodic_edge_from_record(record) for record in records] + + logger.info(f'Found Edges: {uuids}') + if len(edges) == 0: + raise EdgeNotFoundError(uuids[0]) + return edges + class EntityEdge(Edge): name: str = Field(description='name of the edge, relation name') fact: str = Field(description='fact representing the edge and nodes that it connects') fact_embedding: list[float] | None = Field(default=None, description='embedding of the fact') - episodes: list[str] | None = Field( - default=None, + episodes: list[str] = Field( + default=[], description='list of episode ids that reference these entity edges', ) expired_at: datetime | None = Field( @@ -197,6 +220,36 @@ async def get_by_uuid(cls, driver: AsyncDriver, uuid: str): raise EdgeNotFoundError(uuid) return edges[0] + @classmethod + async def get_by_uuids(cls, driver: AsyncDriver, uuids: list[str]): + records, _, _ = await driver.execute_query( + """ + MATCH (n:Entity)-[e:RELATES_TO]->(m:Entity) + WHERE e.uuid IN $uuids + RETURN + e.uuid AS uuid, + n.uuid AS source_node_uuid, + m.uuid AS target_node_uuid, + e.created_at AS created_at, + e.name AS name, + e.group_id AS group_id, + e.fact AS fact, + e.fact_embedding AS fact_embedding, + e.episodes AS episodes, + e.expired_at AS expired_at, + e.valid_at AS valid_at, + e.invalid_at AS invalid_at + """, + uuids=uuids, + ) + + edges = [get_entity_edge_from_record(record) for record in records] + + logger.info(f'Found Edges: {uuids}') + if len(edges) == 0: + raise EdgeNotFoundError(uuids[0]) + return edges + class CommunityEdge(Edge): async def save(self, driver: AsyncDriver): @@ -239,6 +292,28 @@ async def get_by_uuid(cls, driver: AsyncDriver, uuid: str): return edges[0] + @classmethod + async def get_by_uuids(cls, driver: AsyncDriver, uuids: list[str]): + records, _, _ = await driver.execute_query( + """ + MATCH (n:Community)-[e:HAS_MEMBER]->(m:Entity | Community) + WHERE e.uuid IN $uuids + RETURN + e.uuid As uuid, + e.group_id AS group_id, + n.uuid AS source_node_uuid, + m.uuid AS target_node_uuid, + e.created_at AS created_at + """, + uuids=uuids, + ) + + edges = [get_community_edge_from_record(record) for record in records] + + logger.info(f'Found Edges: {uuids}') + + return edges + # Edge helpers def get_episodic_edge_from_record(record: Any) -> EpisodicEdge: diff --git a/graphiti_core/graphiti.py b/graphiti_core/graphiti.py index f029ce48..7cf9719a 100644 --- a/graphiti_core/graphiti.py +++ b/graphiti_core/graphiti.py @@ -35,6 +35,8 @@ ) from graphiti_core.search.search_utils import ( RELEVANT_SCHEMA_LIMIT, + get_communities_by_nodes, + get_mentioned_nodes, get_relevant_edges, get_relevant_nodes, ) @@ -249,8 +251,6 @@ async def add_episode( An id for the graph partition the episode is a part of. uuid : str | None Optional uuid of the episode. - update_communities: bool - Optional. Determines if we should update communities Returns ------- @@ -413,6 +413,8 @@ async def add_episode_endpoint(episode_data: EpisodeData): logger.info(f'Built episodic edges: {episodic_edges}') + episode.entity_edges = [edge.uuid for edge in entity_edges] + # Future optimization would be using batch operations to save nodes and edges await episode.save(self.driver) await asyncio.gather(*[node.save(self.driver) for node in nodes]) @@ -680,3 +682,19 @@ async def get_nodes_by_query( await search(self.driver, embedder, query, group_ids, search_config, center_node_uuid) ).nodes return nodes + + +async def get_episode_mentions(self, episode_uuids: list[str]) -> SearchResults: + episodes = await EpisodicNode.get_by_uuids(self.driver, episode_uuids) + + edges_list = await asyncio.gather( + *[EntityEdge.get_by_uuids(self.driver, episode.entity_edges) for episode in episodes] + ) + + edges: list[EntityEdge] = [edge for lst in edges_list for edge in lst] + + nodes = await get_mentioned_nodes(self.driver, episodes) + + communities = await get_communities_by_nodes(self.driver, nodes) + + return SearchResults(edges=edges, nodes=nodes, communities=communities) diff --git a/graphiti_core/nodes.py b/graphiti_core/nodes.py index ccbbb33f..769cfe5e 100644 --- a/graphiti_core/nodes.py +++ b/graphiti_core/nodes.py @@ -170,7 +170,8 @@ async def get_by_uuids(cls, driver: AsyncDriver, uuids: list[str]): records, _, _ = await driver.execute_query( """ MATCH (e:Episodic) WHERE e.uuid IN $uuids - RETURN e.content AS content, + RETURN DISTINCT + e.content AS content, e.created_at AS created_at, e.valid_at AS valid_at, e.uuid AS uuid, diff --git a/graphiti_core/search/search.py b/graphiti_core/search/search.py index c1a8979f..a8b1c9f9 100644 --- a/graphiti_core/search/search.py +++ b/graphiti_core/search/search.py @@ -42,6 +42,7 @@ community_similarity_search, edge_fulltext_search, edge_similarity_search, + episode_mentions_reranker, node_distance_reranker, node_fulltext_search, node_similarity_search, @@ -131,7 +132,7 @@ async def edge_search( edge_uuid_map = {edge.uuid: edge for result in search_results for edge in result} reranked_uuids: list[str] = [] - if config.reranker == EdgeReranker.rrf: + if config.reranker == EdgeReranker.rrf or config.reranker == EdgeReranker.episode_mentions: search_result_uuids = [[edge.uuid for edge in result] for result in search_results] reranked_uuids = rrf(search_result_uuids) @@ -150,6 +151,9 @@ async def edge_search( reranked_edges = [edge_uuid_map[uuid] for uuid in reranked_uuids] + if config.reranker == EdgeReranker.episode_mentions: + reranked_edges.sort(reverse=True, key=lambda edge: len(edge.episodes)) + return reranked_edges @@ -189,6 +193,8 @@ async def node_search( reranked_uuids: list[str] = [] if config.reranker == NodeReranker.rrf: reranked_uuids = rrf(search_result_uuids) + elif config.reranker == NodeReranker.episode_mentions: + reranked_uuids = await episode_mentions_reranker(driver, search_result_uuids) elif config.reranker == NodeReranker.node_distance: if center_node_uuid is None: raise SearchRerankerError('No center node provided for Node Distance reranker') diff --git a/graphiti_core/search/search_config.py b/graphiti_core/search/search_config.py index 3bd6b6cb..ceb644b9 100644 --- a/graphiti_core/search/search_config.py +++ b/graphiti_core/search/search_config.py @@ -42,11 +42,13 @@ class CommunitySearchMethod(Enum): class EdgeReranker(Enum): rrf = 'reciprocal_rank_fusion' node_distance = 'node_distance' + episode_mentions = 'episode_mentions' class NodeReranker(Enum): rrf = 'reciprocal_rank_fusion' node_distance = 'node_distance' + episode_mentions = 'episode_mentions' class CommunityReranker(Enum): diff --git a/graphiti_core/search/search_config_recipes.py b/graphiti_core/search/search_config_recipes.py index 5aa30198..8396307b 100644 --- a/graphiti_core/search/search_config_recipes.py +++ b/graphiti_core/search/search_config_recipes.py @@ -59,6 +59,14 @@ ) ) +# performs a hybrid search over edges with episode mention reranking +EDGE_HYBRID_SEARCH_EPISODE_MENTIONS = SearchConfig( + edge_config=EdgeSearchConfig( + search_methods=[EdgeSearchMethod.bm25, EdgeSearchMethod.cosine_similarity], + reranker=EdgeReranker.episode_mentions, + ) +) + # performs a hybrid search over nodes with rrf reranking NODE_HYBRID_SEARCH_RRF = SearchConfig( node_config=NodeSearchConfig( @@ -75,6 +83,14 @@ ) ) +# performs a hybrid search over nodes with episode mentions reranking +NODE_HYBRID_SEARCH_EPISODE_MENTIONS = SearchConfig( + node_config=NodeSearchConfig( + search_methods=[NodeSearchMethod.bm25, NodeSearchMethod.cosine_similarity], + reranker=NodeReranker.episode_mentions, + ) +) + # performs a hybrid search over communities with rrf reranking COMMUNITY_HYBRID_SEARCH_RRF = SearchConfig( community_config=CommunitySearchConfig( diff --git a/graphiti_core/search/search_utils.py b/graphiti_core/search/search_utils.py index ef17236c..0cc19be5 100644 --- a/graphiti_core/search/search_utils.py +++ b/graphiti_core/search/search_utils.py @@ -36,7 +36,9 @@ RELEVANT_SCHEMA_LIMIT = 3 -async def get_mentioned_nodes(driver: AsyncDriver, episodes: list[EpisodicNode]): +async def get_mentioned_nodes( + driver: AsyncDriver, episodes: list[EpisodicNode] +) -> list[EntityNode]: episode_uuids = [episode.uuid for episode in episodes] records, _, _ = await driver.execute_query( """ @@ -57,6 +59,29 @@ async def get_mentioned_nodes(driver: AsyncDriver, episodes: list[EpisodicNode]) return nodes +async def get_communities_by_nodes( + driver: AsyncDriver, nodes: list[EntityNode] +) -> list[CommunityNode]: + node_uuids = [node.uuid for node in nodes] + records, _, _ = await driver.execute_query( + """ + MATCH (c:Community)-[:HAS_MEMBER]->(n:Entity) WHERE n.uuid IN $uuids + RETURN DISTINCT + c.uuid As uuid, + c.group_id AS group_id, + c.name AS name, + c.name_embedding AS name_embedding + c.created_at AS created_at, + c.summary AS summary + """, + uuids=node_uuids, + ) + + communities = [get_community_node_from_record(record) for record in records] + + return communities + + async def edge_fulltext_search( driver: AsyncDriver, query: str, @@ -634,3 +659,34 @@ async def node_distance_reranker( sorted_uuids.sort(key=lambda cur_uuid: scores[cur_uuid]) return sorted_uuids + + +async def episode_mentions_reranker(driver: AsyncDriver, node_uuids: list[list[str]]) -> list[str]: + # use rrf as a preliminary ranker + sorted_uuids = rrf(node_uuids) + scores: dict[str, float] = {} + + # Find the shortest path to center node + query = Query(""" + MATCH (episode:Episodic)-[r:MENTIONS]->(n:Entity {uuid: $node_uuid}) + RETURN count(*) AS score + """) + + result_scores = await asyncio.gather( + *[ + driver.execute_query( + query, + node_uuid=uuid, + ) + for uuid in sorted_uuids + ] + ) + + for uuid, result in zip(sorted_uuids, result_scores): + record = result[0][0] + scores[uuid] = record['score'] + + # rerank on shortest distance + sorted_uuids.sort(key=lambda cur_uuid: scores[cur_uuid]) + + return sorted_uuids diff --git a/graphiti_core/utils/maintenance/edge_operations.py b/graphiti_core/utils/maintenance/edge_operations.py index 83334e73..d39594c7 100644 --- a/graphiti_core/utils/maintenance/edge_operations.py +++ b/graphiti_core/utils/maintenance/edge_operations.py @@ -163,6 +163,8 @@ async def dedupe_extracted_edges( if edge.uuid in duplicate_uuid_map: existing_uuid = duplicate_uuid_map[edge.uuid] existing_edge = edge_map[existing_uuid] + # Add current episode to the episodes list + existing_edge.episodes += edge.episodes edges.append(existing_edge) else: edges.append(edge)