From ac5934d5df581f22ae008efef53f2a0cf24ef7d0 Mon Sep 17 00:00:00 2001 From: prestonrasmussen Date: Mon, 26 Aug 2024 15:10:14 -0400 Subject: [PATCH 1/8] add new search reranker and update search --- graphiti_core/graphiti.py | 24 ++++++++++--- graphiti_core/search/search.py | 53 +++++++++++++++++++--------- graphiti_core/search/search_utils.py | 43 ++++++++++++++++++++++ 3 files changed, 99 insertions(+), 21 deletions(-) diff --git a/graphiti_core/graphiti.py b/graphiti_core/graphiti.py index ec038f1..e657d46 100644 --- a/graphiti_core/graphiti.py +++ b/graphiti_core/graphiti.py @@ -26,7 +26,7 @@ from graphiti_core.edges import EntityEdge, EpisodicEdge from graphiti_core.llm_client import LLMClient, OpenAIClient from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode -from graphiti_core.search.search import SearchConfig, hybrid_search +from graphiti_core.search.search import Reranker, SearchConfig, SearchMethod, hybrid_search from graphiti_core.search.search_utils import ( get_relevant_edges, get_relevant_nodes, @@ -515,7 +515,7 @@ async def add_episode_bulk( except Exception as e: raise e - async def search(self, query: str, num_results=10): + async def search(self, query: str, center_node_uuid: str | None = None, num_results=10): """ Perform a hybrid search on the knowledge graph. @@ -543,7 +543,14 @@ async def search(self, query: str, num_results=10): The search is performed using the current date and time as the reference point for temporal relevance. """ - search_config = SearchConfig(num_episodes=0, num_results=num_results) + reranker = Reranker.rrf if center_node_uuid is None else Reranker.node_distance + search_config = SearchConfig( + num_episodes=0, + num_edges=num_results, + num_nodes=0, + search_mathods=[SearchMethod.bm25, SearchMethod.cosine_similarity], + reranker=reranker, + ) edges = ( await hybrid_search( self.driver, @@ -551,6 +558,7 @@ async def search(self, query: str, num_results=10): query, datetime.now(), search_config, + center_node_uuid, ) ).edges @@ -558,7 +566,13 @@ async def search(self, query: str, num_results=10): return facts - async def _search(self, query: str, timestamp: datetime, config: SearchConfig): + async def _search( + self, + query: str, + timestamp: datetime, + config: SearchConfig, + center_node_uuid: str | None = None, + ): return await hybrid_search( - self.driver, self.llm_client.get_embedder(), query, timestamp, config + self.driver, self.llm_client.get_embedder(), query, timestamp, config, center_node_uuid ) diff --git a/graphiti_core/search/search.py b/graphiti_core/search/search.py index 956ae65..7323447 100644 --- a/graphiti_core/search/search.py +++ b/graphiti_core/search/search.py @@ -16,6 +16,7 @@ import logging from datetime import datetime +from enum import Enum from time import time from neo4j import AsyncDriver @@ -28,6 +29,7 @@ edge_fulltext_search, edge_similarity_search, get_mentioned_nodes, + node_distance_reranker, rrf, ) from graphiti_core.utils import retrieve_episodes @@ -36,12 +38,22 @@ logger = logging.getLogger(__name__) +class SearchMethod(Enum): + cosine_similarity = 'cosine_similarity' + bm25 = 'bm25' + + +class Reranker(Enum): + rrf = 'reciprocal_rank_fusion' + node_distance = 'node_distance' + + class SearchConfig(BaseModel): - num_results: int = 10 + num_edges: int = 10 + num_nodes: int = 10 num_episodes: int = EPISODE_WINDOW_LEN - similarity_search: str = 'cosine' - text_search: str = 'BM25' - reranker: str = 'rrf' + search_methods: list[SearchMethod] + reranker: Reranker | None class SearchResults(BaseModel): @@ -51,7 +63,12 @@ class SearchResults(BaseModel): async def hybrid_search( - driver: AsyncDriver, embedder, query: str, timestamp: datetime, config: SearchConfig + driver: AsyncDriver, + embedder, + query: str, + timestamp: datetime, + config: SearchConfig, + center_node_uuid: str | None = None, ) -> SearchResults: start = time() @@ -65,11 +82,11 @@ async def hybrid_search( episodes.extend(await retrieve_episodes(driver, timestamp)) nodes.extend(await get_mentioned_nodes(driver, episodes)) - if config.text_search == 'BM25': + if SearchMethod.bm25 in config.search_methods: text_search = await edge_fulltext_search(query, driver) search_results.append(text_search) - if config.similarity_search == 'cosine': + if SearchMethod.cosine_similarity in config.search_methods: query_text = query.replace('\n', ' ') search_vector = ( (await embedder.create(input=[query_text], model='text-embedding-3-small')) @@ -80,19 +97,14 @@ async def hybrid_search( similarity_search = await edge_similarity_search(search_vector, driver) search_results.append(similarity_search) - if len(search_results) == 1: - edges = search_results[0] - - elif len(search_results) > 1 and config.reranker != 'rrf': + if len(search_results) > 1 and config.reranker is None: logger.exception('Multiple searches enabled without a reranker') raise Exception('Multiple searches enabled without a reranker') - elif config.reranker == 'rrf': + else: edge_uuid_map = {} search_result_uuids = [] - logger.info([[edge.fact for edge in result] for result in search_results]) - for result in search_results: result_uuids = [] for edge in result: @@ -103,12 +115,21 @@ async def hybrid_search( search_result_uuids = [[edge.uuid for edge in result] for result in search_results] - reranked_uuids = rrf(search_result_uuids) + reranked_uuids: list[str] = [] + if config.reranker == Reranker.rrf: + reranked_uuids = rrf(search_result_uuids) + elif config.reranker == Reranker.node_distance: + if center_node_uuid is None: + logger.exception('No center node provided for Node Distance reranker') + raise Exception('No center node provided for Node Distance reranker') + reranked_uuids = node_distance_reranker(driver, search_result_uuids, center_node_uuid) reranked_edges = [edge_uuid_map[uuid] for uuid in reranked_uuids] edges.extend(reranked_edges) - context = SearchResults(episodes=episodes, nodes=nodes, edges=edges) + context = SearchResults( + episodes=episodes, nodes=nodes[: config.num_nodes], edges=edges[: config.num_edges] + ) end = time() diff --git a/graphiti_core/search/search_utils.py b/graphiti_core/search/search_utils.py index d73ea5e..164809d 100644 --- a/graphiti_core/search/search_utils.py +++ b/graphiti_core/search/search_utils.py @@ -344,3 +344,46 @@ def rrf(results: list[list[str]], rank_const=1) -> list[str]: sorted_uuids = [term[0] for term in scored_uuids] return sorted_uuids + + +def node_distance_reranker( + driver: AsyncDriver, results: list[list[str]], center_node_uuid: str +) -> list[str]: + scores: dict[str, int] = defaultdict(int) + for result in results: + for i, uuid in enumerate(result): + # Find shortest paths + records, _, _ = driver.execute_query( + """ + MATCH (source:Entity)-[r:RELATES_TO {uuid: $edge_uuid}]->(target:Entity) YIELD source, target + MATCH (center:Entity {uuid: $center_uuid} YIELD center + CALL gds.graph.project( + 'shortest_path_source', + source, + center + ) YIELD source_total_cost + CALL gds.graph.project( + 'shortest_path_target', + target, + center + ) YIELD target_total_cost + RETURN min(source_total_cost, target_total_cost) AS score + """, + edge_uuid=uuid, + center_uuid=center_node_uuid, + ) + + distance = 0.01 + + for record in records: + if record['score'] > distance: + distance = record['score'] + + scores[uuid] += 1 / distance + + scored_uuids = [term for term in scores.items()] + scored_uuids.sort(reverse=True, key=lambda term: term[1]) + + sorted_uuids = [term[0] for term in scored_uuids] + + return sorted_uuids From 60ac601ee16b9ea47b70d2421a1e029df8055e91 Mon Sep 17 00:00:00 2001 From: prestonrasmussen Date: Mon, 26 Aug 2024 17:26:10 -0400 Subject: [PATCH 2/8] node distance reranking --- graphiti_core/graphiti.py | 38 +++++++++---------- graphiti_core/search/search.py | 14 +++---- graphiti_core/search/search_utils.py | 57 +++++++++++++--------------- 3 files changed, 52 insertions(+), 57 deletions(-) diff --git a/graphiti_core/graphiti.py b/graphiti_core/graphiti.py index e657d46..26c6d7a 100644 --- a/graphiti_core/graphiti.py +++ b/graphiti_core/graphiti.py @@ -173,9 +173,9 @@ async def build_indices_and_constraints(self): await build_indices_and_constraints(self.driver) async def retrieve_episodes( - self, - reference_time: datetime, - last_n: int = EPISODE_WINDOW_LEN, + self, + reference_time: datetime, + last_n: int = EPISODE_WINDOW_LEN, ) -> list[EpisodicNode]: """ Retrieve the last n episodic nodes from the graph. @@ -203,14 +203,14 @@ async def retrieve_episodes( return await retrieve_episodes(self.driver, reference_time, last_n) async def add_episode( - self, - name: str, - episode_body: str, - source_description: str, - reference_time: datetime, - source: EpisodeType = EpisodeType.message, - success_callback: Callable | None = None, - error_callback: Callable | None = None, + self, + name: str, + episode_body: str, + source_description: str, + reference_time: datetime, + source: EpisodeType = EpisodeType.message, + success_callback: Callable | None = None, + error_callback: Callable | None = None, ): """ Process an episode and update the graph. @@ -405,8 +405,8 @@ async def add_episode_endpoint(episode_data: EpisodeData): raise e async def add_episode_bulk( - self, - bulk_episodes: list[RawEpisode], + self, + bulk_episodes: list[RawEpisode], ): """ Process multiple episodes in bulk and update the graph. @@ -548,7 +548,7 @@ async def search(self, query: str, center_node_uuid: str | None = None, num_resu num_episodes=0, num_edges=num_results, num_nodes=0, - search_mathods=[SearchMethod.bm25, SearchMethod.cosine_similarity], + search_methods=[SearchMethod.bm25, SearchMethod.cosine_similarity], reranker=reranker, ) edges = ( @@ -567,11 +567,11 @@ async def search(self, query: str, center_node_uuid: str | None = None, num_resu return facts async def _search( - self, - query: str, - timestamp: datetime, - config: SearchConfig, - center_node_uuid: str | None = None, + self, + query: str, + timestamp: datetime, + config: SearchConfig, + center_node_uuid: str | None = None, ): return await hybrid_search( self.driver, self.llm_client.get_embedder(), query, timestamp, config, center_node_uuid diff --git a/graphiti_core/search/search.py b/graphiti_core/search/search.py index 7323447..0892123 100644 --- a/graphiti_core/search/search.py +++ b/graphiti_core/search/search.py @@ -63,12 +63,12 @@ class SearchResults(BaseModel): async def hybrid_search( - driver: AsyncDriver, - embedder, - query: str, - timestamp: datetime, - config: SearchConfig, - center_node_uuid: str | None = None, + driver: AsyncDriver, + embedder, + query: str, + timestamp: datetime, + config: SearchConfig, + center_node_uuid: str | None = None, ) -> SearchResults: start = time() @@ -122,7 +122,7 @@ async def hybrid_search( if center_node_uuid is None: logger.exception('No center node provided for Node Distance reranker') raise Exception('No center node provided for Node Distance reranker') - reranked_uuids = node_distance_reranker(driver, search_result_uuids, center_node_uuid) + reranked_uuids = await node_distance_reranker(driver, search_result_uuids, center_node_uuid) reranked_edges = [edge_uuid_map[uuid] for uuid in reranked_uuids] edges.extend(reranked_edges) diff --git a/graphiti_core/search/search_utils.py b/graphiti_core/search/search_utils.py index 164809d..a111db6 100644 --- a/graphiti_core/search/search_utils.py +++ b/graphiti_core/search/search_utils.py @@ -101,7 +101,7 @@ async def bfs(node_ids: list[str], driver: AsyncDriver): async def edge_similarity_search( - search_vector: list[float], driver: AsyncDriver, limit=RELEVANT_SCHEMA_LIMIT + search_vector: list[float], driver: AsyncDriver, limit=RELEVANT_SCHEMA_LIMIT ) -> list[EntityEdge]: # vector similarity search over embedded facts records, _, _ = await driver.execute_query( @@ -150,7 +150,7 @@ async def edge_similarity_search( async def entity_similarity_search( - search_vector: list[float], driver: AsyncDriver, limit=RELEVANT_SCHEMA_LIMIT + search_vector: list[float], driver: AsyncDriver, limit=RELEVANT_SCHEMA_LIMIT ) -> list[EntityNode]: # vector similarity search over entity names records, _, _ = await driver.execute_query( @@ -184,7 +184,7 @@ async def entity_similarity_search( async def entity_fulltext_search( - query: str, driver: AsyncDriver, limit=RELEVANT_SCHEMA_LIMIT + query: str, driver: AsyncDriver, limit=RELEVANT_SCHEMA_LIMIT ) -> list[EntityNode]: # BM25 search to get top nodes fuzzy_query = re.sub(r'[^\w\s]', '', query) + '~' @@ -219,7 +219,7 @@ async def entity_fulltext_search( async def edge_fulltext_search( - query: str, driver: AsyncDriver, limit=RELEVANT_SCHEMA_LIMIT + query: str, driver: AsyncDriver, limit=RELEVANT_SCHEMA_LIMIT ) -> list[EntityEdge]: # fulltext search over facts fuzzy_query = re.sub(r'[^\w\s]', '', query) + '~' @@ -270,8 +270,8 @@ async def edge_fulltext_search( async def get_relevant_nodes( - nodes: list[EntityNode], - driver: AsyncDriver, + nodes: list[EntityNode], + driver: AsyncDriver, ) -> list[EntityNode]: start = time() relevant_nodes: list[EntityNode] = [] @@ -301,8 +301,8 @@ async def get_relevant_nodes( async def get_relevant_edges( - edges: list[EntityEdge], - driver: AsyncDriver, + edges: list[EntityEdge], + driver: AsyncDriver, ) -> list[EntityEdge]: start = time() relevant_edges: list[EntityEdge] = [] @@ -346,40 +346,35 @@ def rrf(results: list[list[str]], rank_const=1) -> list[str]: return sorted_uuids -def node_distance_reranker( - driver: AsyncDriver, results: list[list[str]], center_node_uuid: str +async def node_distance_reranker( + driver: AsyncDriver, results: list[list[str]], center_node_uuid: str ) -> list[str]: - scores: dict[str, int] = defaultdict(int) + scores: dict[str, int] = {} + for result in results: for i, uuid in enumerate(result): # Find shortest paths - records, _, _ = driver.execute_query( - """ - MATCH (source:Entity)-[r:RELATES_TO {uuid: $edge_uuid}]->(target:Entity) YIELD source, target - MATCH (center:Entity {uuid: $center_uuid} YIELD center - CALL gds.graph.project( - 'shortest_path_source', - source, - center - ) YIELD source_total_cost - CALL gds.graph.project( - 'shortest_path_target', - target, - center - ) YIELD target_total_cost - RETURN min(source_total_cost, target_total_cost) AS score + records, _, _ = await driver.execute_query( + """ + MATCH (source:Entity)-[r:RELATES_TO {uuid: $edge_uuid}]->(target:Entity) + MATCH p = SHORTEST 1 (center:Entity)-[:RELATES_TO]-+(n:Entity) + WHERE center.uuid = $center_uuid AND n.uuid IN [source.uuid, target.uuid] + RETURN min(length(p)) AS score, source.uuid AS source_uuid, target.uuid AS target_uuid """, edge_uuid=uuid, center_uuid=center_node_uuid, ) - distance = 0.01 for record in records: - if record['score'] > distance: - distance = record['score'] - - scores[uuid] += 1 / distance + if record["source_uuid"] == center_node_uuid or record["target_uuid"] == center_node_uuid: + continue + distance = record["score"] + + if uuid in scores: + scores[uuid] = min(1 / distance, scores[uuid]) + else: + scores[uuid] = 1 / distance scored_uuids = [term for term in scores.items()] scored_uuids.sort(reverse=True, key=lambda term: term[1]) From 33c3f7729020fb6677d0962d76a5337fd3fc9714 Mon Sep 17 00:00:00 2001 From: prestonrasmussen Date: Mon, 26 Aug 2024 17:26:55 -0400 Subject: [PATCH 3/8] format --- graphiti_core/graphiti.py | 36 ++++++++++++++-------------- graphiti_core/search/search.py | 16 +++++++------ graphiti_core/search/search_utils.py | 25 ++++++++++--------- 3 files changed, 41 insertions(+), 36 deletions(-) diff --git a/graphiti_core/graphiti.py b/graphiti_core/graphiti.py index 26c6d7a..842a1db 100644 --- a/graphiti_core/graphiti.py +++ b/graphiti_core/graphiti.py @@ -173,9 +173,9 @@ async def build_indices_and_constraints(self): await build_indices_and_constraints(self.driver) async def retrieve_episodes( - self, - reference_time: datetime, - last_n: int = EPISODE_WINDOW_LEN, + self, + reference_time: datetime, + last_n: int = EPISODE_WINDOW_LEN, ) -> list[EpisodicNode]: """ Retrieve the last n episodic nodes from the graph. @@ -203,14 +203,14 @@ async def retrieve_episodes( return await retrieve_episodes(self.driver, reference_time, last_n) async def add_episode( - self, - name: str, - episode_body: str, - source_description: str, - reference_time: datetime, - source: EpisodeType = EpisodeType.message, - success_callback: Callable | None = None, - error_callback: Callable | None = None, + self, + name: str, + episode_body: str, + source_description: str, + reference_time: datetime, + source: EpisodeType = EpisodeType.message, + success_callback: Callable | None = None, + error_callback: Callable | None = None, ): """ Process an episode and update the graph. @@ -405,8 +405,8 @@ async def add_episode_endpoint(episode_data: EpisodeData): raise e async def add_episode_bulk( - self, - bulk_episodes: list[RawEpisode], + self, + bulk_episodes: list[RawEpisode], ): """ Process multiple episodes in bulk and update the graph. @@ -567,11 +567,11 @@ async def search(self, query: str, center_node_uuid: str | None = None, num_resu return facts async def _search( - self, - query: str, - timestamp: datetime, - config: SearchConfig, - center_node_uuid: str | None = None, + self, + query: str, + timestamp: datetime, + config: SearchConfig, + center_node_uuid: str | None = None, ): return await hybrid_search( self.driver, self.llm_client.get_embedder(), query, timestamp, config, center_node_uuid diff --git a/graphiti_core/search/search.py b/graphiti_core/search/search.py index 0892123..0311122 100644 --- a/graphiti_core/search/search.py +++ b/graphiti_core/search/search.py @@ -63,12 +63,12 @@ class SearchResults(BaseModel): async def hybrid_search( - driver: AsyncDriver, - embedder, - query: str, - timestamp: datetime, - config: SearchConfig, - center_node_uuid: str | None = None, + driver: AsyncDriver, + embedder, + query: str, + timestamp: datetime, + config: SearchConfig, + center_node_uuid: str | None = None, ) -> SearchResults: start = time() @@ -122,7 +122,9 @@ async def hybrid_search( if center_node_uuid is None: logger.exception('No center node provided for Node Distance reranker') raise Exception('No center node provided for Node Distance reranker') - reranked_uuids = await node_distance_reranker(driver, search_result_uuids, center_node_uuid) + reranked_uuids = await node_distance_reranker( + driver, search_result_uuids, center_node_uuid + ) reranked_edges = [edge_uuid_map[uuid] for uuid in reranked_uuids] edges.extend(reranked_edges) diff --git a/graphiti_core/search/search_utils.py b/graphiti_core/search/search_utils.py index a111db6..9ddd3a1 100644 --- a/graphiti_core/search/search_utils.py +++ b/graphiti_core/search/search_utils.py @@ -101,7 +101,7 @@ async def bfs(node_ids: list[str], driver: AsyncDriver): async def edge_similarity_search( - search_vector: list[float], driver: AsyncDriver, limit=RELEVANT_SCHEMA_LIMIT + search_vector: list[float], driver: AsyncDriver, limit=RELEVANT_SCHEMA_LIMIT ) -> list[EntityEdge]: # vector similarity search over embedded facts records, _, _ = await driver.execute_query( @@ -150,7 +150,7 @@ async def edge_similarity_search( async def entity_similarity_search( - search_vector: list[float], driver: AsyncDriver, limit=RELEVANT_SCHEMA_LIMIT + search_vector: list[float], driver: AsyncDriver, limit=RELEVANT_SCHEMA_LIMIT ) -> list[EntityNode]: # vector similarity search over entity names records, _, _ = await driver.execute_query( @@ -184,7 +184,7 @@ async def entity_similarity_search( async def entity_fulltext_search( - query: str, driver: AsyncDriver, limit=RELEVANT_SCHEMA_LIMIT + query: str, driver: AsyncDriver, limit=RELEVANT_SCHEMA_LIMIT ) -> list[EntityNode]: # BM25 search to get top nodes fuzzy_query = re.sub(r'[^\w\s]', '', query) + '~' @@ -219,7 +219,7 @@ async def entity_fulltext_search( async def edge_fulltext_search( - query: str, driver: AsyncDriver, limit=RELEVANT_SCHEMA_LIMIT + query: str, driver: AsyncDriver, limit=RELEVANT_SCHEMA_LIMIT ) -> list[EntityEdge]: # fulltext search over facts fuzzy_query = re.sub(r'[^\w\s]', '', query) + '~' @@ -270,8 +270,8 @@ async def edge_fulltext_search( async def get_relevant_nodes( - nodes: list[EntityNode], - driver: AsyncDriver, + nodes: list[EntityNode], + driver: AsyncDriver, ) -> list[EntityNode]: start = time() relevant_nodes: list[EntityNode] = [] @@ -301,8 +301,8 @@ async def get_relevant_nodes( async def get_relevant_edges( - edges: list[EntityEdge], - driver: AsyncDriver, + edges: list[EntityEdge], + driver: AsyncDriver, ) -> list[EntityEdge]: start = time() relevant_edges: list[EntityEdge] = [] @@ -347,7 +347,7 @@ def rrf(results: list[list[str]], rank_const=1) -> list[str]: async def node_distance_reranker( - driver: AsyncDriver, results: list[list[str]], center_node_uuid: str + driver: AsyncDriver, results: list[list[str]], center_node_uuid: str ) -> list[str]: scores: dict[str, int] = {} @@ -367,9 +367,12 @@ async def node_distance_reranker( distance = 0.01 for record in records: - if record["source_uuid"] == center_node_uuid or record["target_uuid"] == center_node_uuid: + if ( + record['source_uuid'] == center_node_uuid + or record['target_uuid'] == center_node_uuid + ): continue - distance = record["score"] + distance = record['score'] if uuid in scores: scores[uuid] = min(1 / distance, scores[uuid]) From c5752e9e1ba4565cdd68bcb26df5fda2a3a9a35e Mon Sep 17 00:00:00 2001 From: prestonrasmussen Date: Mon, 26 Aug 2024 17:33:22 -0400 Subject: [PATCH 4/8] rebase --- graphiti_core/graphiti.py | 38 ++++++++++++++++++++------------------ 1 file changed, 20 insertions(+), 18 deletions(-) diff --git a/graphiti_core/graphiti.py b/graphiti_core/graphiti.py index 842a1db..9655f31 100644 --- a/graphiti_core/graphiti.py +++ b/graphiti_core/graphiti.py @@ -173,9 +173,9 @@ async def build_indices_and_constraints(self): await build_indices_and_constraints(self.driver) async def retrieve_episodes( - self, - reference_time: datetime, - last_n: int = EPISODE_WINDOW_LEN, + self, + reference_time: datetime, + last_n: int = EPISODE_WINDOW_LEN, ) -> list[EpisodicNode]: """ Retrieve the last n episodic nodes from the graph. @@ -203,14 +203,14 @@ async def retrieve_episodes( return await retrieve_episodes(self.driver, reference_time, last_n) async def add_episode( - self, - name: str, - episode_body: str, - source_description: str, - reference_time: datetime, - source: EpisodeType = EpisodeType.message, - success_callback: Callable | None = None, - error_callback: Callable | None = None, + self, + name: str, + episode_body: str, + source_description: str, + reference_time: datetime, + source: EpisodeType = EpisodeType.message, + success_callback: Callable | None = None, + error_callback: Callable | None = None, ): """ Process an episode and update the graph. @@ -405,8 +405,8 @@ async def add_episode_endpoint(episode_data: EpisodeData): raise e async def add_episode_bulk( - self, - bulk_episodes: list[RawEpisode], + self, + bulk_episodes: list[RawEpisode], ): """ Process multiple episodes in bulk and update the graph. @@ -526,6 +526,8 @@ async def search(self, query: str, center_node_uuid: str | None = None, num_resu ---------- query : str The search query string. + center_node_uuid: str, optional + Facts will be reranked based on proximity to this node num_results : int, optional The maximum number of results to return. Defaults to 10. @@ -567,11 +569,11 @@ async def search(self, query: str, center_node_uuid: str | None = None, num_resu return facts async def _search( - self, - query: str, - timestamp: datetime, - config: SearchConfig, - center_node_uuid: str | None = None, + self, + query: str, + timestamp: datetime, + config: SearchConfig, + center_node_uuid: str | None = None, ): return await hybrid_search( self.driver, self.llm_client.get_embedder(), query, timestamp, config, center_node_uuid From d901d887315689ea113636dc2c7f09131debbad1 Mon Sep 17 00:00:00 2001 From: prestonrasmussen Date: Mon, 26 Aug 2024 17:36:03 -0400 Subject: [PATCH 5/8] no need for enumerate --- graphiti_core/search/search_utils.py | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/graphiti_core/search/search_utils.py b/graphiti_core/search/search_utils.py index 9ddd3a1..feeb758 100644 --- a/graphiti_core/search/search_utils.py +++ b/graphiti_core/search/search_utils.py @@ -101,7 +101,7 @@ async def bfs(node_ids: list[str], driver: AsyncDriver): async def edge_similarity_search( - search_vector: list[float], driver: AsyncDriver, limit=RELEVANT_SCHEMA_LIMIT + search_vector: list[float], driver: AsyncDriver, limit=RELEVANT_SCHEMA_LIMIT ) -> list[EntityEdge]: # vector similarity search over embedded facts records, _, _ = await driver.execute_query( @@ -150,7 +150,7 @@ async def edge_similarity_search( async def entity_similarity_search( - search_vector: list[float], driver: AsyncDriver, limit=RELEVANT_SCHEMA_LIMIT + search_vector: list[float], driver: AsyncDriver, limit=RELEVANT_SCHEMA_LIMIT ) -> list[EntityNode]: # vector similarity search over entity names records, _, _ = await driver.execute_query( @@ -184,7 +184,7 @@ async def entity_similarity_search( async def entity_fulltext_search( - query: str, driver: AsyncDriver, limit=RELEVANT_SCHEMA_LIMIT + query: str, driver: AsyncDriver, limit=RELEVANT_SCHEMA_LIMIT ) -> list[EntityNode]: # BM25 search to get top nodes fuzzy_query = re.sub(r'[^\w\s]', '', query) + '~' @@ -219,7 +219,7 @@ async def entity_fulltext_search( async def edge_fulltext_search( - query: str, driver: AsyncDriver, limit=RELEVANT_SCHEMA_LIMIT + query: str, driver: AsyncDriver, limit=RELEVANT_SCHEMA_LIMIT ) -> list[EntityEdge]: # fulltext search over facts fuzzy_query = re.sub(r'[^\w\s]', '', query) + '~' @@ -270,8 +270,8 @@ async def edge_fulltext_search( async def get_relevant_nodes( - nodes: list[EntityNode], - driver: AsyncDriver, + nodes: list[EntityNode], + driver: AsyncDriver, ) -> list[EntityNode]: start = time() relevant_nodes: list[EntityNode] = [] @@ -301,8 +301,8 @@ async def get_relevant_nodes( async def get_relevant_edges( - edges: list[EntityEdge], - driver: AsyncDriver, + edges: list[EntityEdge], + driver: AsyncDriver, ) -> list[EntityEdge]: start = time() relevant_edges: list[EntityEdge] = [] @@ -347,12 +347,12 @@ def rrf(results: list[list[str]], rank_const=1) -> list[str]: async def node_distance_reranker( - driver: AsyncDriver, results: list[list[str]], center_node_uuid: str + driver: AsyncDriver, results: list[list[str]], center_node_uuid: str ) -> list[str]: scores: dict[str, int] = {} for result in results: - for i, uuid in enumerate(result): + for uuid in result: # Find shortest paths records, _, _ = await driver.execute_query( """ @@ -368,8 +368,8 @@ async def node_distance_reranker( for record in records: if ( - record['source_uuid'] == center_node_uuid - or record['target_uuid'] == center_node_uuid + record['source_uuid'] == center_node_uuid + or record['target_uuid'] == center_node_uuid ): continue distance = record['score'] From 13aa9fac67cc838722fc6ba51cbff1b7d9743239 Mon Sep 17 00:00:00 2001 From: prestonrasmussen Date: Mon, 26 Aug 2024 17:50:26 -0400 Subject: [PATCH 6/8] mypy typing --- graphiti_core/search/search_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/graphiti_core/search/search_utils.py b/graphiti_core/search/search_utils.py index feeb758..55530d7 100644 --- a/graphiti_core/search/search_utils.py +++ b/graphiti_core/search/search_utils.py @@ -333,7 +333,7 @@ async def get_relevant_edges( # takes in a list of rankings of uuids def rrf(results: list[list[str]], rank_const=1) -> list[str]: - scores: dict[str, int] = defaultdict(int) + scores: dict[str, float] = defaultdict(int) for result in results: for i, uuid in enumerate(result): scores[uuid] += 1 / (i + rank_const) @@ -349,7 +349,7 @@ def rrf(results: list[list[str]], rank_const=1) -> list[str]: async def node_distance_reranker( driver: AsyncDriver, results: list[list[str]], center_node_uuid: str ) -> list[str]: - scores: dict[str, int] = {} + scores: dict[str, float] = {} for result in results: for uuid in result: From 094a4897d9a0f3a542442875c4d1ee4b58b5cc4e Mon Sep 17 00:00:00 2001 From: prestonrasmussen Date: Mon, 26 Aug 2024 17:56:54 -0400 Subject: [PATCH 7/8] defaultdict update --- graphiti_core/search/search_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/graphiti_core/search/search_utils.py b/graphiti_core/search/search_utils.py index 55530d7..5acae5a 100644 --- a/graphiti_core/search/search_utils.py +++ b/graphiti_core/search/search_utils.py @@ -333,7 +333,7 @@ async def get_relevant_edges( # takes in a list of rankings of uuids def rrf(results: list[list[str]], rank_const=1) -> list[str]: - scores: dict[str, float] = defaultdict(int) + scores: dict[str, float] = defaultdict(float) for result in results: for i, uuid in enumerate(result): scores[uuid] += 1 / (i + rank_const) From 517cedd56093cb86f566227a0f35d0ad1e342d32 Mon Sep 17 00:00:00 2001 From: prestonrasmussen Date: Mon, 26 Aug 2024 18:31:51 -0400 Subject: [PATCH 8/8] rrf prelim ranking --- graphiti_core/graphiti.py | 36 ++++++------- graphiti_core/search/search_utils.py | 75 ++++++++++++++-------------- 2 files changed, 55 insertions(+), 56 deletions(-) diff --git a/graphiti_core/graphiti.py b/graphiti_core/graphiti.py index 9655f31..6ff5b52 100644 --- a/graphiti_core/graphiti.py +++ b/graphiti_core/graphiti.py @@ -173,9 +173,9 @@ async def build_indices_and_constraints(self): await build_indices_and_constraints(self.driver) async def retrieve_episodes( - self, - reference_time: datetime, - last_n: int = EPISODE_WINDOW_LEN, + self, + reference_time: datetime, + last_n: int = EPISODE_WINDOW_LEN, ) -> list[EpisodicNode]: """ Retrieve the last n episodic nodes from the graph. @@ -203,14 +203,14 @@ async def retrieve_episodes( return await retrieve_episodes(self.driver, reference_time, last_n) async def add_episode( - self, - name: str, - episode_body: str, - source_description: str, - reference_time: datetime, - source: EpisodeType = EpisodeType.message, - success_callback: Callable | None = None, - error_callback: Callable | None = None, + self, + name: str, + episode_body: str, + source_description: str, + reference_time: datetime, + source: EpisodeType = EpisodeType.message, + success_callback: Callable | None = None, + error_callback: Callable | None = None, ): """ Process an episode and update the graph. @@ -405,8 +405,8 @@ async def add_episode_endpoint(episode_data: EpisodeData): raise e async def add_episode_bulk( - self, - bulk_episodes: list[RawEpisode], + self, + bulk_episodes: list[RawEpisode], ): """ Process multiple episodes in bulk and update the graph. @@ -569,11 +569,11 @@ async def search(self, query: str, center_node_uuid: str | None = None, num_resu return facts async def _search( - self, - query: str, - timestamp: datetime, - config: SearchConfig, - center_node_uuid: str | None = None, + self, + query: str, + timestamp: datetime, + config: SearchConfig, + center_node_uuid: str | None = None, ): return await hybrid_search( self.driver, self.llm_client.get_embedder(), query, timestamp, config, center_node_uuid diff --git a/graphiti_core/search/search_utils.py b/graphiti_core/search/search_utils.py index 5acae5a..e9d658e 100644 --- a/graphiti_core/search/search_utils.py +++ b/graphiti_core/search/search_utils.py @@ -101,7 +101,7 @@ async def bfs(node_ids: list[str], driver: AsyncDriver): async def edge_similarity_search( - search_vector: list[float], driver: AsyncDriver, limit=RELEVANT_SCHEMA_LIMIT + search_vector: list[float], driver: AsyncDriver, limit=RELEVANT_SCHEMA_LIMIT ) -> list[EntityEdge]: # vector similarity search over embedded facts records, _, _ = await driver.execute_query( @@ -150,7 +150,7 @@ async def edge_similarity_search( async def entity_similarity_search( - search_vector: list[float], driver: AsyncDriver, limit=RELEVANT_SCHEMA_LIMIT + search_vector: list[float], driver: AsyncDriver, limit=RELEVANT_SCHEMA_LIMIT ) -> list[EntityNode]: # vector similarity search over entity names records, _, _ = await driver.execute_query( @@ -184,7 +184,7 @@ async def entity_similarity_search( async def entity_fulltext_search( - query: str, driver: AsyncDriver, limit=RELEVANT_SCHEMA_LIMIT + query: str, driver: AsyncDriver, limit=RELEVANT_SCHEMA_LIMIT ) -> list[EntityNode]: # BM25 search to get top nodes fuzzy_query = re.sub(r'[^\w\s]', '', query) + '~' @@ -219,7 +219,7 @@ async def entity_fulltext_search( async def edge_fulltext_search( - query: str, driver: AsyncDriver, limit=RELEVANT_SCHEMA_LIMIT + query: str, driver: AsyncDriver, limit=RELEVANT_SCHEMA_LIMIT ) -> list[EntityEdge]: # fulltext search over facts fuzzy_query = re.sub(r'[^\w\s]', '', query) + '~' @@ -270,8 +270,8 @@ async def edge_fulltext_search( async def get_relevant_nodes( - nodes: list[EntityNode], - driver: AsyncDriver, + nodes: list[EntityNode], + driver: AsyncDriver, ) -> list[EntityNode]: start = time() relevant_nodes: list[EntityNode] = [] @@ -301,8 +301,8 @@ async def get_relevant_nodes( async def get_relevant_edges( - edges: list[EntityEdge], - driver: AsyncDriver, + edges: list[EntityEdge], + driver: AsyncDriver, ) -> list[EntityEdge]: start = time() relevant_edges: list[EntityEdge] = [] @@ -347,41 +347,40 @@ def rrf(results: list[list[str]], rank_const=1) -> list[str]: async def node_distance_reranker( - driver: AsyncDriver, results: list[list[str]], center_node_uuid: str + driver: AsyncDriver, results: list[list[str]], center_node_uuid: str ) -> list[str]: + # use rrf as a preliminary ranker + sorted_uuids = rrf(results) scores: dict[str, float] = {} - for result in results: - for uuid in result: - # Find shortest paths - records, _, _ = await driver.execute_query( - """ - MATCH (source:Entity)-[r:RELATES_TO {uuid: $edge_uuid}]->(target:Entity) - MATCH p = SHORTEST 1 (center:Entity)-[:RELATES_TO]-+(n:Entity) - WHERE center.uuid = $center_uuid AND n.uuid IN [source.uuid, target.uuid] - RETURN min(length(p)) AS score, source.uuid AS source_uuid, target.uuid AS target_uuid - """, - edge_uuid=uuid, - center_uuid=center_node_uuid, - ) - distance = 0.01 - - for record in records: - if ( - record['source_uuid'] == center_node_uuid - or record['target_uuid'] == center_node_uuid - ): - continue - distance = record['score'] + for uuid in sorted_uuids: + # Find shortest path to center node + records, _, _ = await driver.execute_query( + """ + MATCH (source:Entity)-[r:RELATES_TO {uuid: $edge_uuid}]->(target:Entity) + MATCH p = SHORTEST 1 (center:Entity)-[:RELATES_TO]-+(n:Entity) + WHERE center.uuid = $center_uuid AND n.uuid IN [source.uuid, target.uuid] + RETURN min(length(p)) AS score, source.uuid AS source_uuid, target.uuid AS target_uuid + """, + edge_uuid=uuid, + center_uuid=center_node_uuid, + ) + distance = 0.01 - if uuid in scores: - scores[uuid] = min(1 / distance, scores[uuid]) - else: - scores[uuid] = 1 / distance + for record in records: + if ( + record['source_uuid'] == center_node_uuid + or record['target_uuid'] == center_node_uuid + ): + continue + distance = record['score'] - scored_uuids = [term for term in scores.items()] - scored_uuids.sort(reverse=True, key=lambda term: term[1]) + if uuid in scores: + scores[uuid] = min(1 / distance, scores[uuid]) + else: + scores[uuid] = 1 / distance - sorted_uuids = [term[0] for term in scored_uuids] + # rerank on shortest distance + sorted_uuids.sort(reverse=True, key=lambda cur_uuid: scores[cur_uuid]) return sorted_uuids