Skip to content

Commit

Permalink
Improve node distance reranker speed (#107)
Browse files Browse the repository at this point in the history
* much faster

* clean up code

* variable rename
  • Loading branch information
prasmussen15 committed Sep 12, 2024
1 parent 8085b52 commit 85cf8e5
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 28 deletions.
2 changes: 1 addition & 1 deletion examples/podcast/podcast_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ async def main(use_bulk: bool = True):
messages = parse_podcast_messages()

if not use_bulk:
for i, message in enumerate(messages[3:14]):
for i, message in enumerate(messages[3:130]):
await client.add_episode(
name=f'Message {i}',
episode_body=f'{message.speaker_name} ({message.role}): {message.content}',
Expand Down
49 changes: 27 additions & 22 deletions graphiti_core/search/search_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -496,34 +496,39 @@ async def node_distance_reranker(
sorted_uuids = rrf(results)
scores: dict[str, float] = {}

for uuid in sorted_uuids:
# Find the shortest path to center node
records, _, _ = await driver.execute_query(
"""
# Find the shortest path to center node
query = Query("""
MATCH (source:Entity)-[r:RELATES_TO {uuid: $edge_uuid}]->(target:Entity)
MATCH p = SHORTEST 1 (center:Entity)-[:RELATES_TO*1..10]->(n:Entity)
WHERE center.uuid = $center_uuid AND n.uuid IN [source.uuid, target.uuid]
RETURN min(length(p)) AS score, source.uuid AS source_uuid, target.uuid AS target_uuid
""",
edge_uuid=uuid,
center_uuid=center_node_uuid,
)
distance = 0.01
MATCH p = SHORTEST 1 (center:Entity {uuid: $center_uuid})-[:RELATES_TO]-+(n:Entity {uuid: source.uuid})
RETURN length(p) AS score, source.uuid AS source_uuid, target.uuid AS target_uuid
""")

for record in records:
if (
record['source_uuid'] == center_node_uuid
or record['target_uuid'] == center_node_uuid
):
continue
distance = record['score']
path_results = await asyncio.gather(
*[
driver.execute_query(
query,
edge_uuid=uuid,
center_uuid=center_node_uuid,
)
for uuid in sorted_uuids
]
)

for uuid, result in zip(sorted_uuids, path_results):
records = result[0]
record = records[0] if len(records) > 0 else None
distance: float = record['score'] if record is not None else float('inf')
if record is not None and (
record['source_uuid'] == center_node_uuid or record['target_uuid'] == center_node_uuid
):
distance = 0

if uuid in scores:
scores[uuid] = min(1 / distance, scores[uuid])
scores[uuid] = min(distance, scores[uuid])
else:
scores[uuid] = 1 / distance
scores[uuid] = distance

# rerank on shortest distance
sorted_uuids.sort(reverse=True, key=lambda cur_uuid: scores[cur_uuid])
sorted_uuids.sort(key=lambda cur_uuid: scores[cur_uuid])

return sorted_uuids
5 changes: 0 additions & 5 deletions tests/test_graphiti_int.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,11 +73,6 @@ def format_context(facts):
async def test_graphiti_init():
logger = setup_logging()
graphiti = Graphiti(NEO4J_URI, NEO4j_USER, NEO4j_PASSWORD)
await graphiti.build_communities()

edges = await graphiti.search('Freakenomics guest', group_ids=['1'])

logger.info('\nQUERY: Freakenomics guest\n' + format_context([edge.fact for edge in edges]))

edges = await graphiti.search('tania tetlow', group_ids=['1'])

Expand Down

0 comments on commit 85cf8e5

Please sign in to comment.