Skip to content

Commit

Permalink
rrf prelim ranking
Browse files Browse the repository at this point in the history
  • Loading branch information
prasmussen15 committed Aug 26, 2024
1 parent 094a489 commit 517cedd
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 56 deletions.
36 changes: 18 additions & 18 deletions graphiti_core/graphiti.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,9 +173,9 @@ async def build_indices_and_constraints(self):
await build_indices_and_constraints(self.driver)

async def retrieve_episodes(
self,
reference_time: datetime,
last_n: int = EPISODE_WINDOW_LEN,
self,
reference_time: datetime,
last_n: int = EPISODE_WINDOW_LEN,
) -> list[EpisodicNode]:
"""
Retrieve the last n episodic nodes from the graph.
Expand Down Expand Up @@ -203,14 +203,14 @@ async def retrieve_episodes(
return await retrieve_episodes(self.driver, reference_time, last_n)

async def add_episode(
self,
name: str,
episode_body: str,
source_description: str,
reference_time: datetime,
source: EpisodeType = EpisodeType.message,
success_callback: Callable | None = None,
error_callback: Callable | None = None,
self,
name: str,
episode_body: str,
source_description: str,
reference_time: datetime,
source: EpisodeType = EpisodeType.message,
success_callback: Callable | None = None,
error_callback: Callable | None = None,
):
"""
Process an episode and update the graph.
Expand Down Expand Up @@ -405,8 +405,8 @@ async def add_episode_endpoint(episode_data: EpisodeData):
raise e

async def add_episode_bulk(
self,
bulk_episodes: list[RawEpisode],
self,
bulk_episodes: list[RawEpisode],
):
"""
Process multiple episodes in bulk and update the graph.
Expand Down Expand Up @@ -569,11 +569,11 @@ async def search(self, query: str, center_node_uuid: str | None = None, num_resu
return facts

async def _search(
self,
query: str,
timestamp: datetime,
config: SearchConfig,
center_node_uuid: str | None = None,
self,
query: str,
timestamp: datetime,
config: SearchConfig,
center_node_uuid: str | None = None,
):
return await hybrid_search(
self.driver, self.llm_client.get_embedder(), query, timestamp, config, center_node_uuid
Expand Down
75 changes: 37 additions & 38 deletions graphiti_core/search/search_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ async def bfs(node_ids: list[str], driver: AsyncDriver):


async def edge_similarity_search(
search_vector: list[float], driver: AsyncDriver, limit=RELEVANT_SCHEMA_LIMIT
search_vector: list[float], driver: AsyncDriver, limit=RELEVANT_SCHEMA_LIMIT
) -> list[EntityEdge]:
# vector similarity search over embedded facts
records, _, _ = await driver.execute_query(
Expand Down Expand Up @@ -150,7 +150,7 @@ async def edge_similarity_search(


async def entity_similarity_search(
search_vector: list[float], driver: AsyncDriver, limit=RELEVANT_SCHEMA_LIMIT
search_vector: list[float], driver: AsyncDriver, limit=RELEVANT_SCHEMA_LIMIT
) -> list[EntityNode]:
# vector similarity search over entity names
records, _, _ = await driver.execute_query(
Expand Down Expand Up @@ -184,7 +184,7 @@ async def entity_similarity_search(


async def entity_fulltext_search(
query: str, driver: AsyncDriver, limit=RELEVANT_SCHEMA_LIMIT
query: str, driver: AsyncDriver, limit=RELEVANT_SCHEMA_LIMIT
) -> list[EntityNode]:
# BM25 search to get top nodes
fuzzy_query = re.sub(r'[^\w\s]', '', query) + '~'
Expand Down Expand Up @@ -219,7 +219,7 @@ async def entity_fulltext_search(


async def edge_fulltext_search(
query: str, driver: AsyncDriver, limit=RELEVANT_SCHEMA_LIMIT
query: str, driver: AsyncDriver, limit=RELEVANT_SCHEMA_LIMIT
) -> list[EntityEdge]:
# fulltext search over facts
fuzzy_query = re.sub(r'[^\w\s]', '', query) + '~'
Expand Down Expand Up @@ -270,8 +270,8 @@ async def edge_fulltext_search(


async def get_relevant_nodes(
nodes: list[EntityNode],
driver: AsyncDriver,
nodes: list[EntityNode],
driver: AsyncDriver,
) -> list[EntityNode]:
start = time()
relevant_nodes: list[EntityNode] = []
Expand Down Expand Up @@ -301,8 +301,8 @@ async def get_relevant_nodes(


async def get_relevant_edges(
edges: list[EntityEdge],
driver: AsyncDriver,
edges: list[EntityEdge],
driver: AsyncDriver,
) -> list[EntityEdge]:
start = time()
relevant_edges: list[EntityEdge] = []
Expand Down Expand Up @@ -347,41 +347,40 @@ def rrf(results: list[list[str]], rank_const=1) -> list[str]:


async def node_distance_reranker(
    driver: AsyncDriver, results: list[list[str]], center_node_uuid: str
) -> list[str]:
    """Rerank edge search results by graph distance from a center node.

    Uses reciprocal rank fusion (``rrf``) over the per-searcher result lists as a
    preliminary ranking, then re-sorts those fused uuids so that edges whose
    endpoints are closest (shortest RELATES_TO path) to ``center_node_uuid``
    come first.

    Parameters
    ----------
    driver : AsyncDriver
        Neo4j async driver used to run the shortest-path Cypher query.
    results : list[list[str]]
        One ranked list of edge uuids per upstream searcher (e.g. fulltext and
        similarity search outputs).
    center_node_uuid : str
        Uuid of the Entity node to measure distance from.

    Returns
    -------
    list[str]
        Edge uuids sorted by descending proximity score (1 / path length).
    """
    # NOTE: the defect in the original span was that it interleaved the pre- and
    # post-commit bodies of this function (a raw diff); this is the coherent
    # post-commit implementation, documented.

    # Use RRF as a preliminary ranker; it also deduplicates uuids across the
    # individual result lists, so each uuid below is visited exactly once.
    sorted_uuids = rrf(results)
    scores: dict[str, float] = {}

    for uuid in sorted_uuids:
        # Find the shortest path from the center node to either endpoint of
        # the edge identified by `uuid`.
        records, _, _ = await driver.execute_query(
            """
            MATCH (source:Entity)-[r:RELATES_TO {uuid: $edge_uuid}]->(target:Entity)
            MATCH p = SHORTEST 1 (center:Entity)-[:RELATES_TO]-+(n:Entity)
            WHERE center.uuid = $center_uuid AND n.uuid IN [source.uuid, target.uuid]
            RETURN min(length(p)) AS score, source.uuid AS source_uuid, target.uuid AS target_uuid
            """,
            edge_uuid=uuid,
            center_uuid=center_node_uuid,
        )
        # Default distance when no path exists (or the edge touches the center
        # node itself): a small value so 1/distance yields the highest score.
        distance = 0.01

        for record in records:
            if (
                record['source_uuid'] == center_node_uuid
                or record['target_uuid'] == center_node_uuid
            ):
                # Edge is incident to the center node — keep the minimal
                # default distance rather than the path length.
                continue
            distance = record['score']

        # `min` keeps the best (largest) proximity score if a uuid were ever
        # scored twice; with rrf deduplication the `else` branch is the one
        # normally taken.
        if uuid in scores:
            scores[uuid] = min(1 / distance, scores[uuid])
        else:
            scores[uuid] = 1 / distance

    # Rerank on shortest distance: higher proximity score first.
    sorted_uuids.sort(reverse=True, key=lambda cur_uuid: scores[cur_uuid])

    return sorted_uuids

0 comments on commit 517cedd

Please sign in to comment.