Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Search node centering #45

Merged
merged 8 commits into from
Aug 26, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 34 additions & 18 deletions graphiti_core/graphiti.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from graphiti_core.edges import EntityEdge, EpisodicEdge
from graphiti_core.llm_client import LLMClient, OpenAIClient
from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode
from graphiti_core.search.search import SearchConfig, hybrid_search
from graphiti_core.search.search import Reranker, SearchConfig, SearchMethod, hybrid_search
from graphiti_core.search.search_utils import (
get_relevant_edges,
get_relevant_nodes,
Expand Down Expand Up @@ -173,9 +173,9 @@ async def build_indices_and_constraints(self):
await build_indices_and_constraints(self.driver)

async def retrieve_episodes(
self,
reference_time: datetime,
last_n: int = EPISODE_WINDOW_LEN,
self,
reference_time: datetime,
last_n: int = EPISODE_WINDOW_LEN,
) -> list[EpisodicNode]:
"""
Retrieve the last n episodic nodes from the graph.
Expand Down Expand Up @@ -203,14 +203,14 @@ async def retrieve_episodes(
return await retrieve_episodes(self.driver, reference_time, last_n)

async def add_episode(
self,
name: str,
episode_body: str,
source_description: str,
reference_time: datetime,
source: EpisodeType = EpisodeType.message,
success_callback: Callable | None = None,
error_callback: Callable | None = None,
self,
name: str,
episode_body: str,
source_description: str,
reference_time: datetime,
source: EpisodeType = EpisodeType.message,
success_callback: Callable | None = None,
error_callback: Callable | None = None,
):
"""
Process an episode and update the graph.
Expand Down Expand Up @@ -405,8 +405,8 @@ async def add_episode_endpoint(episode_data: EpisodeData):
raise e

async def add_episode_bulk(
self,
bulk_episodes: list[RawEpisode],
self,
bulk_episodes: list[RawEpisode],
):
"""
Process multiple episodes in bulk and update the graph.
Expand Down Expand Up @@ -515,7 +515,7 @@ async def add_episode_bulk(
except Exception as e:
raise e

async def search(self, query: str, num_results=10):
async def search(self, query: str, center_node_uuid: str | None = None, num_results=10):
"""
Perform a hybrid search on the knowledge graph.

Expand All @@ -526,6 +526,8 @@ async def search(self, query: str, num_results=10):
----------
query : str
The search query string.
center_node_uuid: str, optional
Facts will be reranked based on proximity to this node
num_results : int, optional
The maximum number of results to return. Defaults to 10.

Expand All @@ -543,22 +545,36 @@ async def search(self, query: str, num_results=10):
The search is performed using the current date and time as the reference
point for temporal relevance.
"""
search_config = SearchConfig(num_episodes=0, num_results=num_results)
reranker = Reranker.rrf if center_node_uuid is None else Reranker.node_distance
search_config = SearchConfig(
num_episodes=0,
num_edges=num_results,
num_nodes=0,
search_methods=[SearchMethod.bm25, SearchMethod.cosine_similarity],
reranker=reranker,
)
edges = (
await hybrid_search(
self.driver,
self.llm_client.get_embedder(),
query,
datetime.now(),
search_config,
center_node_uuid,
)
).edges

facts = [edge.fact for edge in edges]

return facts

async def _search(self, query: str, timestamp: datetime, config: SearchConfig):
async def _search(
self,
query: str,
timestamp: datetime,
config: SearchConfig,
center_node_uuid: str | None = None,
):
return await hybrid_search(
self.driver, self.llm_client.get_embedder(), query, timestamp, config
self.driver, self.llm_client.get_embedder(), query, timestamp, config, center_node_uuid
)
55 changes: 39 additions & 16 deletions graphiti_core/search/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import logging
from datetime import datetime
from enum import Enum
from time import time

from neo4j import AsyncDriver
Expand All @@ -28,6 +29,7 @@
edge_fulltext_search,
edge_similarity_search,
get_mentioned_nodes,
node_distance_reranker,
rrf,
)
from graphiti_core.utils import retrieve_episodes
Expand All @@ -36,12 +38,22 @@
logger = logging.getLogger(__name__)


class SearchMethod(Enum):
cosine_similarity = 'cosine_similarity'
bm25 = 'bm25'


class Reranker(Enum):
rrf = 'reciprocal_rank_fusion'
node_distance = 'node_distance'


class SearchConfig(BaseModel):
num_results: int = 10
num_edges: int = 10
num_nodes: int = 10
num_episodes: int = EPISODE_WINDOW_LEN
similarity_search: str = 'cosine'
text_search: str = 'BM25'
reranker: str = 'rrf'
search_methods: list[SearchMethod]
reranker: Reranker | None


class SearchResults(BaseModel):
Expand All @@ -51,7 +63,12 @@ class SearchResults(BaseModel):


async def hybrid_search(
driver: AsyncDriver, embedder, query: str, timestamp: datetime, config: SearchConfig
driver: AsyncDriver,
embedder,
query: str,
timestamp: datetime,
config: SearchConfig,
center_node_uuid: str | None = None,
) -> SearchResults:
start = time()

Expand All @@ -65,11 +82,11 @@ async def hybrid_search(
episodes.extend(await retrieve_episodes(driver, timestamp))
nodes.extend(await get_mentioned_nodes(driver, episodes))

if config.text_search == 'BM25':
if SearchMethod.bm25 in config.search_methods:
text_search = await edge_fulltext_search(query, driver)
search_results.append(text_search)

if config.similarity_search == 'cosine':
if SearchMethod.cosine_similarity in config.search_methods:
query_text = query.replace('\n', ' ')
search_vector = (
(await embedder.create(input=[query_text], model='text-embedding-3-small'))
Expand All @@ -80,19 +97,14 @@ async def hybrid_search(
similarity_search = await edge_similarity_search(search_vector, driver)
search_results.append(similarity_search)

if len(search_results) == 1:
edges = search_results[0]

elif len(search_results) > 1 and config.reranker != 'rrf':
if len(search_results) > 1 and config.reranker is None:
logger.exception('Multiple searches enabled without a reranker')
raise Exception('Multiple searches enabled without a reranker')

elif config.reranker == 'rrf':
else:
edge_uuid_map = {}
search_result_uuids = []

logger.info([[edge.fact for edge in result] for result in search_results])

for result in search_results:
result_uuids = []
for edge in result:
Expand All @@ -103,12 +115,23 @@ async def hybrid_search(

search_result_uuids = [[edge.uuid for edge in result] for result in search_results]

reranked_uuids = rrf(search_result_uuids)
reranked_uuids: list[str] = []
if config.reranker == Reranker.rrf:
reranked_uuids = rrf(search_result_uuids)
elif config.reranker == Reranker.node_distance:
if center_node_uuid is None:
logger.exception('No center node provided for Node Distance reranker')
raise Exception('No center node provided for Node Distance reranker')
reranked_uuids = await node_distance_reranker(
driver, search_result_uuids, center_node_uuid
)

reranked_edges = [edge_uuid_map[uuid] for uuid in reranked_uuids]
edges.extend(reranked_edges)

context = SearchResults(episodes=episodes, nodes=nodes, edges=edges)
context = SearchResults(
episodes=episodes, nodes=nodes[: config.num_nodes], edges=edges[: config.num_edges]
)

end = time()

Expand Down
41 changes: 41 additions & 0 deletions graphiti_core/search/search_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,3 +344,44 @@
sorted_uuids = [term[0] for term in scored_uuids]

return sorted_uuids


async def node_distance_reranker(
driver: AsyncDriver, results: list[list[str]], center_node_uuid: str
) -> list[str]:
scores: dict[str, int] = {}

for result in results:
for i, uuid in enumerate(result):

Check failure on line 355 in graphiti_core/search/search_utils.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (B007)

graphiti_core/search/search_utils.py:355:13: B007 Loop control variable `i` not used within loop body
# Find shortest paths
records, _, _ = await driver.execute_query(
"""
MATCH (source:Entity)-[r:RELATES_TO {uuid: $edge_uuid}]->(target:Entity)
MATCH p = SHORTEST 1 (center:Entity)-[:RELATES_TO]-+(n:Entity)
WHERE center.uuid = $center_uuid AND n.uuid IN [source.uuid, target.uuid]
RETURN min(length(p)) AS score, source.uuid AS source_uuid, target.uuid AS target_uuid
""",
edge_uuid=uuid,
center_uuid=center_node_uuid,
)
distance = 0.01
prasmussen15 marked this conversation as resolved.
Show resolved Hide resolved

for record in records:
if (
record['source_uuid'] == center_node_uuid
or record['target_uuid'] == center_node_uuid
):
continue
distance = record['score']

if uuid in scores:
scores[uuid] = min(1 / distance, scores[uuid])

Check failure on line 378 in graphiti_core/search/search_utils.py

View workflow job for this annotation

GitHub Actions / mypy

assignment

Incompatible types in assignment (expression has type "float", target has type "int")
else:
scores[uuid] = 1 / distance

Check failure on line 380 in graphiti_core/search/search_utils.py

View workflow job for this annotation

GitHub Actions / mypy

assignment

Incompatible types in assignment (expression has type "float", target has type "int")

scored_uuids = [term for term in scores.items()]
scored_uuids.sort(reverse=True, key=lambda term: term[1])

sorted_uuids = [term[0] for term in scored_uuids]

return sorted_uuids
Loading