Search node centering (#45)
* add new search reranker and update search

* node distance reranking

* format

* rebase

* no need for enumerate

* mypy typing

* defaultdict update

* rrf prelim ranking
prasmussen15 committed Aug 26, 2024
1 parent fc4bf3b commit 2d01e5d
Showing 3 changed files with 101 additions and 22 deletions.
26 changes: 21 additions & 5 deletions graphiti_core/graphiti.py
@@ -26,7 +26,7 @@
 from graphiti_core.edges import EntityEdge, EpisodicEdge
 from graphiti_core.llm_client import LLMClient, OpenAIClient
 from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode
-from graphiti_core.search.search import SearchConfig, hybrid_search
+from graphiti_core.search.search import Reranker, SearchConfig, SearchMethod, hybrid_search
 from graphiti_core.search.search_utils import (
     get_relevant_edges,
     get_relevant_nodes,
@@ -515,7 +515,7 @@ async def add_episode_bulk(
         except Exception as e:
             raise e
 
-    async def search(self, query: str, num_results=10):
+    async def search(self, query: str, center_node_uuid: str | None = None, num_results=10):
         """
         Perform a hybrid search on the knowledge graph.
@@ -526,6 +526,8 @@ async def search(self, query: str, num_results=10):
         ----------
         query : str
             The search query string.
+        center_node_uuid: str, optional
+            Facts will be reranked based on proximity to this node
         num_results : int, optional
             The maximum number of results to return. Defaults to 10.
@@ -543,22 +545,36 @@ async def search(self, query: str, num_results=10):
         The search is performed using the current date and time as the reference
         point for temporal relevance.
         """
-        search_config = SearchConfig(num_episodes=0, num_results=num_results)
+        reranker = Reranker.rrf if center_node_uuid is None else Reranker.node_distance
+        search_config = SearchConfig(
+            num_episodes=0,
+            num_edges=num_results,
+            num_nodes=0,
+            search_methods=[SearchMethod.bm25, SearchMethod.cosine_similarity],
+            reranker=reranker,
+        )
         edges = (
             await hybrid_search(
                 self.driver,
                 self.llm_client.get_embedder(),
                 query,
                 datetime.now(),
                 search_config,
+                center_node_uuid,
             )
         ).edges
 
         facts = [edge.fact for edge in edges]
 
         return facts
 
-    async def _search(self, query: str, timestamp: datetime, config: SearchConfig):
+    async def _search(
+        self,
+        query: str,
+        timestamp: datetime,
+        config: SearchConfig,
+        center_node_uuid: str | None = None,
+    ):
         return await hybrid_search(
-            self.driver, self.llm_client.get_embedder(), query, timestamp, config
+            self.driver, self.llm_client.get_embedder(), query, timestamp, config, center_node_uuid
         )
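
With this change a caller can steer results toward a particular entity by passing its UUID. A minimal usage sketch, assuming an already-initialized Graphiti instance; the query text and the center UUID are placeholders, and only the `search` signature shown above comes from this commit:

# Illustrative only: `graphiti` is an existing Graphiti instance.
from graphiti_core.graphiti import Graphiti

async def centered_facts(graphiti: Graphiti, center_uuid: str) -> tuple[list[str], list[str]]:
    # Default behavior: BM25 + cosine-similarity results fused with RRF.
    baseline = await graphiti.search('Who founded Acme?', num_results=10)
    # Passing a center node switches the reranker to node distance, so facts
    # attached to nodes near that entity are returned first.
    centered = await graphiti.search(
        'Who founded Acme?', center_node_uuid=center_uuid, num_results=10
    )
    return baseline, centered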
55 changes: 39 additions & 16 deletions graphiti_core/search/search.py
@@ -16,6 +16,7 @@
 
 import logging
 from datetime import datetime
+from enum import Enum
 from time import time
 
 from neo4j import AsyncDriver
@@ -28,6 +29,7 @@
     edge_fulltext_search,
     edge_similarity_search,
     get_mentioned_nodes,
+    node_distance_reranker,
     rrf,
 )
 from graphiti_core.utils import retrieve_episodes
@@ -36,12 +38,22 @@
 logger = logging.getLogger(__name__)
 
 
+class SearchMethod(Enum):
+    cosine_similarity = 'cosine_similarity'
+    bm25 = 'bm25'
+
+
+class Reranker(Enum):
+    rrf = 'reciprocal_rank_fusion'
+    node_distance = 'node_distance'
+
+
 class SearchConfig(BaseModel):
-    num_results: int = 10
+    num_edges: int = 10
+    num_nodes: int = 10
     num_episodes: int = EPISODE_WINDOW_LEN
-    similarity_search: str = 'cosine'
-    text_search: str = 'BM25'
-    reranker: str = 'rrf'
+    search_methods: list[SearchMethod]
+    reranker: Reranker | None
 
 
 class SearchResults(BaseModel):
@@ -51,7 +63,12 @@ class SearchResults(BaseModel):
 
 
 async def hybrid_search(
-    driver: AsyncDriver, embedder, query: str, timestamp: datetime, config: SearchConfig
+    driver: AsyncDriver,
+    embedder,
+    query: str,
+    timestamp: datetime,
+    config: SearchConfig,
+    center_node_uuid: str | None = None,
 ) -> SearchResults:
     start = time()
 
@@ -65,11 +82,11 @@ async def hybrid_search(
         episodes.extend(await retrieve_episodes(driver, timestamp))
         nodes.extend(await get_mentioned_nodes(driver, episodes))
 
-    if config.text_search == 'BM25':
+    if SearchMethod.bm25 in config.search_methods:
         text_search = await edge_fulltext_search(query, driver)
         search_results.append(text_search)
 
-    if config.similarity_search == 'cosine':
+    if SearchMethod.cosine_similarity in config.search_methods:
         query_text = query.replace('\n', ' ')
         search_vector = (
             (await embedder.create(input=[query_text], model='text-embedding-3-small'))
@@ -80,19 +97,14 @@ async def hybrid_search(
         similarity_search = await edge_similarity_search(search_vector, driver)
         search_results.append(similarity_search)
 
-    if len(search_results) == 1:
-        edges = search_results[0]
-
-    elif len(search_results) > 1 and config.reranker != 'rrf':
+    if len(search_results) > 1 and config.reranker is None:
         logger.exception('Multiple searches enabled without a reranker')
         raise Exception('Multiple searches enabled without a reranker')
 
-    elif config.reranker == 'rrf':
+    else:
         edge_uuid_map = {}
         search_result_uuids = []
 
-        logger.info([[edge.fact for edge in result] for result in search_results])
-
         for result in search_results:
             result_uuids = []
             for edge in result:
@@ -103,12 +115,23 @@
 
         search_result_uuids = [[edge.uuid for edge in result] for result in search_results]
 
-        reranked_uuids = rrf(search_result_uuids)
+        reranked_uuids: list[str] = []
+        if config.reranker == Reranker.rrf:
+            reranked_uuids = rrf(search_result_uuids)
+        elif config.reranker == Reranker.node_distance:
+            if center_node_uuid is None:
+                logger.exception('No center node provided for Node Distance reranker')
+                raise Exception('No center node provided for Node Distance reranker')
+            reranked_uuids = await node_distance_reranker(
+                driver, search_result_uuids, center_node_uuid
+            )
 
         reranked_edges = [edge_uuid_map[uuid] for uuid in reranked_uuids]
         edges.extend(reranked_edges)
 
-    context = SearchResults(episodes=episodes, nodes=nodes, edges=edges)
+    context = SearchResults(
+        episodes=episodes, nodes=nodes[: config.num_nodes], edges=edges[: config.num_edges]
+    )
 
     end = time()
 
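The new `SearchConfig` makes the retrieval methods and the reranker explicit enums rather than loosely-typed strings. A minimal sketch of how the internal `_search` path might be driven with them, for illustration only; the query text and center UUID are placeholders, calling the private `_search` directly is just to show the plumbing, and only the config fields and enum members visible in this diff are assumed to exist:

from datetime import datetime

from graphiti_core.graphiti import Graphiti
from graphiti_core.search.search import Reranker, SearchConfig, SearchMethod

async def centered_edge_search(graphiti: Graphiti, query: str, center_uuid: str):
    # Edge-only search: both retrieval methods, reranked by distance from the center node.
    config = SearchConfig(
        num_edges=10,
        num_nodes=0,
        num_episodes=0,
        search_methods=[SearchMethod.bm25, SearchMethod.cosine_similarity],
        reranker=Reranker.node_distance,
    )
    return await graphiti._search(query, datetime.now(), config, center_node_uuid=center_uuid)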
42 changes: 41 additions & 1 deletion graphiti_core/search/search_utils.py
@@ -333,7 +333,7 @@ async def get_relevant_edges(
 
 # takes in a list of rankings of uuids
 def rrf(results: list[list[str]], rank_const=1) -> list[str]:
-    scores: dict[str, int] = defaultdict(int)
+    scores: dict[str, float] = defaultdict(float)
     for result in results:
         for i, uuid in enumerate(result):
             scores[uuid] += 1 / (i + rank_const)
@@ -344,3 +344,43 @@ def rrf(results: list[list[str]], rank_const=1) -> list[str]:
     sorted_uuids = [term[0] for term in scored_uuids]
 
     return sorted_uuids
+
+
+async def node_distance_reranker(
+    driver: AsyncDriver, results: list[list[str]], center_node_uuid: str
+) -> list[str]:
+    # use rrf as a preliminary ranker
+    sorted_uuids = rrf(results)
+    scores: dict[str, float] = {}
+
+    for uuid in sorted_uuids:
+        # Find shortest path to center node
+        records, _, _ = await driver.execute_query(
+            """
+            MATCH (source:Entity)-[r:RELATES_TO {uuid: $edge_uuid}]->(target:Entity)
+            MATCH p = SHORTEST 1 (center:Entity)-[:RELATES_TO]-+(n:Entity)
+            WHERE center.uuid = $center_uuid AND n.uuid IN [source.uuid, target.uuid]
+            RETURN min(length(p)) AS score, source.uuid AS source_uuid, target.uuid AS target_uuid
+            """,
+            edge_uuid=uuid,
+            center_uuid=center_node_uuid,
+        )
+        distance = 0.01
+
+        for record in records:
+            if (
+                record['source_uuid'] == center_node_uuid
+                or record['target_uuid'] == center_node_uuid
+            ):
+                continue
+            distance = record['score']
+
+        if uuid in scores:
+            scores[uuid] = min(1 / distance, scores[uuid])
+        else:
+            scores[uuid] = 1 / distance
+
+    # rerank on shortest distance
+    sorted_uuids.sort(reverse=True, key=lambda cur_uuid: scores[cur_uuid])
+
+    return sorted_uuids
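
Taken together, the two rerankers work as follows: `rrf` scores each edge UUID by summing `1 / (rank + rank_const)` over every result list it appears in, while `node_distance_reranker` uses that RRF ordering as a preliminary ranking and then rescores each edge as `1 / distance`, where distance is the shortest RELATES_TO path between the center node and either endpoint of the edge; edges that touch the center node itself keep the small default distance of 0.01 and therefore sort first. A toy illustration of the RRF arithmetic, with invented UUIDs:

# Toy reciprocal-rank-fusion example mirroring rrf() above (rank_const = 1).
from collections import defaultdict

rankings = [['e1', 'e2', 'e3'], ['e2', 'e1']]   # two result lists of edge uuids
scores: dict[str, float] = defaultdict(float)
for ranking in rankings:
    for i, uuid in enumerate(ranking):
        scores[uuid] += 1 / (i + 1)

# e1: 1/1 + 1/2 = 1.5, e2: 1/2 + 1/1 = 1.5, e3: 1/3 ≈ 0.33
print(sorted(scores, key=scores.get, reverse=True))   # ['e1', 'e2', 'e3']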
