Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Search node centering #45

Merged
merged 8 commits into from
Aug 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 21 additions & 5 deletions graphiti_core/graphiti.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from graphiti_core.edges import EntityEdge, EpisodicEdge
from graphiti_core.llm_client import LLMClient, OpenAIClient
from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode
from graphiti_core.search.search import SearchConfig, hybrid_search
from graphiti_core.search.search import Reranker, SearchConfig, SearchMethod, hybrid_search
from graphiti_core.search.search_utils import (
get_relevant_edges,
get_relevant_nodes,
Expand Down Expand Up @@ -515,7 +515,7 @@ async def add_episode_bulk(
except Exception as e:
raise e

async def search(self, query: str, num_results=10):
async def search(self, query: str, center_node_uuid: str | None = None, num_results=10):
"""
Perform a hybrid search on the knowledge graph.

Expand All @@ -526,6 +526,8 @@ async def search(self, query: str, num_results=10):
----------
query : str
The search query string.
center_node_uuid: str, optional
Facts will be reranked based on proximity to this node
num_results : int, optional
The maximum number of results to return. Defaults to 10.

Expand All @@ -543,22 +545,36 @@ async def search(self, query: str, num_results=10):
The search is performed using the current date and time as the reference
point for temporal relevance.
"""
search_config = SearchConfig(num_episodes=0, num_results=num_results)
reranker = Reranker.rrf if center_node_uuid is None else Reranker.node_distance
search_config = SearchConfig(
num_episodes=0,
num_edges=num_results,
num_nodes=0,
search_methods=[SearchMethod.bm25, SearchMethod.cosine_similarity],
reranker=reranker,
)
edges = (
await hybrid_search(
self.driver,
self.llm_client.get_embedder(),
query,
datetime.now(),
search_config,
center_node_uuid,
)
).edges

facts = [edge.fact for edge in edges]

return facts

async def _search(self, query: str, timestamp: datetime, config: SearchConfig):
async def _search(
self,
query: str,
timestamp: datetime,
config: SearchConfig,
center_node_uuid: str | None = None,
):
return await hybrid_search(
self.driver, self.llm_client.get_embedder(), query, timestamp, config
self.driver, self.llm_client.get_embedder(), query, timestamp, config, center_node_uuid
)
55 changes: 39 additions & 16 deletions graphiti_core/search/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import logging
from datetime import datetime
from enum import Enum
from time import time

from neo4j import AsyncDriver
Expand All @@ -28,6 +29,7 @@
edge_fulltext_search,
edge_similarity_search,
get_mentioned_nodes,
node_distance_reranker,
rrf,
)
from graphiti_core.utils import retrieve_episodes
Expand All @@ -36,12 +38,22 @@
logger = logging.getLogger(__name__)


class SearchMethod(Enum):
    """Edge retrieval strategies that a ``SearchConfig`` can enable.

    ``hybrid_search`` runs one search per member listed in
    ``SearchConfig.search_methods`` and collects each result list for
    reranking.
    """

    # Embed the query text and run an edge similarity search on the vector.
    cosine_similarity = 'cosine_similarity'
    # Run an edge full-text search on the raw query string.
    bm25 = 'bm25'


class Reranker(Enum):
    """Strategies for merging several edge result lists into one ranking."""

    # Fuse the per-method rankings with reciprocal rank fusion.
    rrf = 'reciprocal_rank_fusion'
    # Rank edges by proximity to a center node; hybrid_search raises when
    # this reranker is selected without a center node uuid.
    node_distance = 'node_distance'


class SearchConfig(BaseModel):
num_results: int = 10
num_edges: int = 10
num_nodes: int = 10
num_episodes: int = EPISODE_WINDOW_LEN
similarity_search: str = 'cosine'
text_search: str = 'BM25'
reranker: str = 'rrf'
search_methods: list[SearchMethod]
reranker: Reranker | None


class SearchResults(BaseModel):
Expand All @@ -51,7 +63,12 @@ class SearchResults(BaseModel):


async def hybrid_search(
driver: AsyncDriver, embedder, query: str, timestamp: datetime, config: SearchConfig
driver: AsyncDriver,
embedder,
query: str,
timestamp: datetime,
config: SearchConfig,
center_node_uuid: str | None = None,
) -> SearchResults:
start = time()

Expand All @@ -65,11 +82,11 @@ async def hybrid_search(
episodes.extend(await retrieve_episodes(driver, timestamp))
nodes.extend(await get_mentioned_nodes(driver, episodes))

if config.text_search == 'BM25':
if SearchMethod.bm25 in config.search_methods:
text_search = await edge_fulltext_search(query, driver)
search_results.append(text_search)

if config.similarity_search == 'cosine':
if SearchMethod.cosine_similarity in config.search_methods:
query_text = query.replace('\n', ' ')
search_vector = (
(await embedder.create(input=[query_text], model='text-embedding-3-small'))
Expand All @@ -80,19 +97,14 @@ async def hybrid_search(
similarity_search = await edge_similarity_search(search_vector, driver)
search_results.append(similarity_search)

if len(search_results) == 1:
edges = search_results[0]

elif len(search_results) > 1 and config.reranker != 'rrf':
if len(search_results) > 1 and config.reranker is None:
logger.exception('Multiple searches enabled without a reranker')
raise Exception('Multiple searches enabled without a reranker')

elif config.reranker == 'rrf':
else:
edge_uuid_map = {}
search_result_uuids = []

logger.info([[edge.fact for edge in result] for result in search_results])

for result in search_results:
result_uuids = []
for edge in result:
Expand All @@ -103,12 +115,23 @@ async def hybrid_search(

search_result_uuids = [[edge.uuid for edge in result] for result in search_results]

reranked_uuids = rrf(search_result_uuids)
reranked_uuids: list[str] = []
if config.reranker == Reranker.rrf:
reranked_uuids = rrf(search_result_uuids)
elif config.reranker == Reranker.node_distance:
if center_node_uuid is None:
logger.exception('No center node provided for Node Distance reranker')
raise Exception('No center node provided for Node Distance reranker')
reranked_uuids = await node_distance_reranker(
driver, search_result_uuids, center_node_uuid
)

reranked_edges = [edge_uuid_map[uuid] for uuid in reranked_uuids]
edges.extend(reranked_edges)

context = SearchResults(episodes=episodes, nodes=nodes, edges=edges)
context = SearchResults(
episodes=episodes, nodes=nodes[: config.num_nodes], edges=edges[: config.num_edges]
)

end = time()

Expand Down
42 changes: 41 additions & 1 deletion graphiti_core/search/search_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,7 +333,7 @@ async def get_relevant_edges(

# takes in a list of rankings of uuids
def rrf(results: list[list[str]], rank_const=1) -> list[str]:
scores: dict[str, int] = defaultdict(int)
scores: dict[str, float] = defaultdict(float)
for result in results:
for i, uuid in enumerate(result):
scores[uuid] += 1 / (i + rank_const)
Expand All @@ -344,3 +344,43 @@ def rrf(results: list[list[str]], rank_const=1) -> list[str]:
sorted_uuids = [term[0] for term in scored_uuids]

return sorted_uuids


async def node_distance_reranker(
    driver: AsyncDriver, results: list[list[str]], center_node_uuid: str
) -> list[str]:
    """
    Rerank edge uuids by how close each edge is to a center node.

    The rankings in ``results`` are first fused with ``rrf``. Each edge is then
    scored as ``1 / distance``, where ``distance`` is the shortest
    ``RELATES_TO`` path length from the center node to either endpoint of the
    edge. Edges that touch the center node directly use a fixed distance of
    0.01 so they always rank first. Edges for which the query returns no rows
    (no path to the center node, or the edge was not found) score 0.0 and rank
    last.

    Parameters
    ----------
    driver : AsyncDriver
        Driver used to run one shortest-path query per edge.
    results : list[list[str]]
        Ranked lists of edge uuids, one list per search method.
    center_node_uuid : str
        uuid of the entity node that distances are measured from.

    Returns
    -------
    list[str]
        Every uuid produced by ``rrf``, ordered from closest to farthest.
        Edges with equal scores keep their relative ``rrf`` order.
    """
    # Distance assigned to edges that have the center node as an endpoint.
    center_edge_distance = 0.01

    # use rrf as a preliminary ranker; it also de-duplicates the uuids
    sorted_uuids = rrf(results)
    scores: dict[str, float] = {}

    for uuid in sorted_uuids:
        # Find shortest path to center node
        records, _, _ = await driver.execute_query(
            """
            MATCH (source:Entity)-[r:RELATES_TO {uuid: $edge_uuid}]->(target:Entity)
            MATCH p = SHORTEST 1 (center:Entity)-[:RELATES_TO]-+(n:Entity)
            WHERE center.uuid = $center_uuid AND n.uuid IN [source.uuid, target.uuid]
            RETURN min(length(p)) AS score, source.uuid AS source_uuid, target.uuid AS target_uuid
            """,
            edge_uuid=uuid,
            center_uuid=center_node_uuid,
        )

        # None means "no usable record seen yet", i.e. the edge is unreachable.
        distance: float | None = None
        for record in records:
            if center_node_uuid in (record['source_uuid'], record['target_uuid']):
                # The edge touches the center node: keep any distance already
                # found, otherwise fall back to the fixed minimum.
                if distance is None:
                    distance = center_edge_distance
                continue
            distance = record['score']

        # An empty result must not inherit the "touches center" distance,
        # which would push unreachable edges to the top of the ranking.
        scores[uuid] = 0.0 if distance is None else 1 / distance

    # rerank on shortest distance (highest score first); the sort is stable,
    # so ties preserve the preliminary rrf ordering
    sorted_uuids.sort(reverse=True, key=lambda cur_uuid: scores[cur_uuid])

    return sorted_uuids