Skip to content

Commit

Permalink
add opinionated search
Browse files Browse the repository at this point in the history
  • Loading branch information
prasmussen15 committed Aug 22, 2024
1 parent 4696de1 commit 704a032
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 35 deletions.
24 changes: 21 additions & 3 deletions core/graphiti.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from core.nodes import EntityNode, EpisodicNode
from core.edges import EntityEdge, EpisodicEdge
from core.search.search import search, SearchConfig
from core.search.search import SearchConfig, hybrid_search
from core.utils import (
build_episodic_edges,
retrieve_episodes,
Expand Down Expand Up @@ -295,7 +295,25 @@ async def add_episode_bulk(
except Exception as e:
raise e

async def search(self, query: str, timestamp: datetime, config: SearchConfig):
return await search(
async def search(self, query: str, num_results=10):
    """Run an opinionated hybrid search and return the matching facts.

    Builds a ``SearchConfig`` with ``num_episodes=0`` and the given
    ``num_results``, runs the module-level ``hybrid_search`` using the
    current time as the reference timestamp, and keeps only the ``fact``
    attribute of each entry under the ``"edges"`` key of the result.

    Args:
        query: Free-text search query.
        num_results: Forwarded to ``SearchConfig`` as ``num_results``.

    Returns:
        A list with the ``fact`` of every edge returned by the search.
    """
    config = SearchConfig(num_episodes=0, num_results=num_results)

    # Inside a method body the bare name resolves to the imported
    # module-level function, not to the sibling method of the same name.
    results = await hybrid_search(
        self.driver,
        self.llm_client.client.embeddings,
        query,
        datetime.now(),
        config,
    )

    return [found_edge.fact for found_edge in results["edges"]]

async def hybrid_search(
    self, query: str, timestamp: datetime, config: SearchConfig
):
    """Run a fully configurable hybrid search.

    Thin wrapper that supplies this instance's driver and embeddings
    client to the module-level ``hybrid_search`` and returns its result
    unchanged.

    Args:
        query: Free-text search query.
        timestamp: Reference time forwarded to the search.
        config: Search settings forwarded to the search.

    Returns:
        Whatever the module-level ``hybrid_search`` returns.
    """
    driver = self.driver
    embedder = self.llm_client.client.embeddings
    # The bare name below is the imported module-level function; a method
    # name is not in scope inside its own class's method bodies.
    return await hybrid_search(driver, embedder, query, timestamp, config)
12 changes: 7 additions & 5 deletions core/search/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ class SearchConfig(BaseModel):
reranker: str = "rrf"


async def search(
async def hybrid_search(
driver: AsyncDriver, embedder, query: str, timestamp: datetime, config: SearchConfig
) -> dict[str, [Node | Edge]]:
start = time()
Expand All @@ -44,11 +44,11 @@ async def search(
episodes.extend(await retrieve_episodes(driver, timestamp))
nodes.extend(await get_mentioned_nodes(driver, episodes))

if config.text_search:
if config.text_search == "BM25":
text_search = await edge_fulltext_search(query, driver)
search_results.append(text_search)

if config.similarity_search:
if config.similarity_search == "cosine":
query_text = query.replace("\n", " ")
search_vector = (
(await embedder.create(input=[query_text], model="text-embedding-3-small"))
Expand All @@ -62,14 +62,16 @@ async def search(
if len(search_results) == 1:
edges = search_results[0]

elif len(search_results) > 1 and not config.reranker:
elif len(search_results) > 1 and not config.reranker == "rrf":
logger.exception("Multiple searches enabled without a reranker")
raise Exception("Multiple searches enabled without a reranker")

elif config.reranker:
elif config.reranker == "rrf":
edge_uuid_map = {}
search_result_uuids = []

logger.info([[edge.fact for edge in result] for result in search_results])

for result in search_results:
result_uuids = []
for edge in result:
Expand Down
2 changes: 1 addition & 1 deletion core/search/search_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,7 +332,7 @@ def rrf(results: list[list[str]], rank_const=1) -> list[str]:
scores[uuid] += 1 / (i + rank_const)

scored_uuids = [term for term in scores.items()]
scored_uuids.sort(key=lambda term: term[1])
scored_uuids.sort(reverse=True, key=lambda term: term[1])

sorted_uuids = [term[0] for term in scored_uuids]

Expand Down
36 changes: 10 additions & 26 deletions tests/tests_int_graphiti.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,21 +53,11 @@ def setup_logging():
return logger


def format_context(context):
def format_context(facts):
formatted_string = ""
episodes = context["episodes"]
nodes = context["nodes"]
edges = context["edges"]

"Entities:\n"
for node in nodes:
formatted_string += f" UUID: {node.uuid}\n"
formatted_string += f" Name: {node.name}\n"
formatted_string += f" Summary: {node.summary}\n"

formatted_string += "Facts:\n"
for edge in edges:
formatted_string += f" - {edge.fact}\n"
formatted_string += "FACTS:\n"
for fact in facts:
formatted_string += f" - {fact}\n"
formatted_string += "\n"

return formatted_string.strip()
Expand All @@ -78,23 +68,17 @@ async def test_graphiti_init():
logger = setup_logging()
graphiti = Graphiti(NEO4J_URI, NEO4j_USER, NEO4j_PASSWORD, None)

search_config = SearchConfig()
facts = await graphiti.search("Freakenomics guest")

context = await graphiti.search("Freakenomics guest", datetime.now(), search_config)
logger.info("\nQUERY: Freakenomics guest\n" + format_context(facts))

logger.info("\nQUERY: Freakenomics guest" + "\nRESULT:\n" + format_context(context))
facts = await graphiti.search("tania tetlow\n")

context = await graphiti.search("tania tetlow", datetime.now(), search_config)
logger.info("\nQUERY: Tania Tetlow\n