Skip to content

Commit

Permalink
test updates
Browse files Browse the repository at this point in the history
  • Loading branch information
prasmussen15 committed Aug 22, 2024
1 parent f8643f1 commit 4696de1
Show file tree
Hide file tree
Showing 4 changed files with 58 additions and 22 deletions.
4 changes: 1 addition & 3 deletions core/graphiti.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,9 +295,7 @@ async def add_episode_bulk(
except Exception as e:
raise e

async def search(self, query: str, timestamp: datetime, config: SearchConfig):
    """Run a hybrid search over the graph for *query*.

    Delegates to ``core.search.search.search`` using this instance's Neo4j
    driver and the LLM client's embedding endpoint.

    Args:
        query: Free-text search query.
        timestamp: Reference time for the search (passed through unchanged).
        config: Search configuration controlling which searches/rerankers run.

    Returns:
        The context dict built by ``core.search.search.search`` — keys
        ``"episodes"``, ``"nodes"`` and ``"edges"`` mapping to result lists.
    """
    return await search(
        self.driver, self.llm_client.client.embeddings, query, timestamp, config
    )
31 changes: 28 additions & 3 deletions core/search/search.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
import asyncio
import logging
from datetime import datetime
from time import time

from neo4j import AsyncDriver
from pydantic import BaseModel

from core.edges import EntityEdge
from core.edges import EntityEdge, Edge
from core.llm_client.config import EMBEDDING_DIM
from core.nodes import Node
from core.search.search_utils import (
edge_similarity_search,
edge_fulltext_search,
Expand All @@ -29,7 +31,9 @@ class SearchConfig(BaseModel):

async def search(
driver: AsyncDriver, embedder, query: str, timestamp: datetime, config: SearchConfig
):
) -> dict[str, [Node | Edge]]:
start = time()

episodes = []
nodes = []
edges = []
Expand Down Expand Up @@ -63,15 +67,36 @@ async def search(
raise Exception("Multiple searches enabled without a reranker")

elif config.reranker:
edge_uuid_map = {}
search_result_uuids = []

for result in search_results:
result_uuids = []
for edge in result:
result_uuids.append(edge.uuid)
edge_uuid_map[edge.uuid] = edge

search_result_uuids.append(result_uuids)

search_result_uuids = [
[edge.uuid for edge in result] for result in search_results
]
edges.extend(rrf(search_result_uuids))

reranked_uuids = rrf(search_result_uuids)

reranked_edges = [edge_uuid_map[uuid] for uuid in reranked_uuids]
edges.extend(reranked_edges)

context = {
"episodes": episodes,
"nodes": nodes,
"edges": edges,
}

end = time()

logger.info(
f"search returned context for query {query} in {(end - start) * 1000} ms"
)

return context
10 changes: 5 additions & 5 deletions core/search/search_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ async def edge_similarity_search(
CALL db.index.vector.queryRelationships("fact_embedding", 5, $search_vector)
YIELD relationship AS r, score
MATCH (n)-[r:RELATES_TO]->(m)
RETURN DISTINCT
RETURN
r.uuid AS uuid,
n.uuid AS source_node_uuid,
m.uuid AS target_node_uuid,
Expand Down Expand Up @@ -152,7 +152,7 @@ async def entity_similarity_search(
"""
CALL db.index.vector.queryNodes("name_embedding", $limit, $search_vector)
YIELD node AS n, score
RETURN DISTINCT
RETURN
n.uuid As uuid,
n.name AS name,
n.created_at AS created_at,
Expand Down Expand Up @@ -186,7 +186,7 @@ async def entity_fulltext_search(
records, _, _ = await driver.execute_query(
"""
CALL db.index.fulltext.queryNodes("name_and_summary", $query) YIELD node, score
RETURN DISTINCT
RETURN
node.uuid As uuid,
node.name AS name,
node.created_at AS created_at,
Expand Down Expand Up @@ -224,7 +224,7 @@ async def edge_fulltext_search(
CALL db.index.fulltext.queryRelationships("name_and_fact", $query)
YIELD relationship AS r, score
MATCH (n:Entity)-[r]->(m:Entity)
RETURN DISTINCT
RETURN
r.uuid AS uuid,
n.uuid AS source_node_uuid,
m.uuid AS target_node_uuid,
Expand Down Expand Up @@ -329,7 +329,7 @@ def rrf(results: list[list[str]], rank_const=1) -> list[str]:
scores: dict[str, int] = defaultdict(int)
for result in results:
for i, uuid in enumerate(result):
scores[uuid] += 1 / i + rank_const
scores[uuid] += 1 / (i + rank_const)

scored_uuids = [term for term in scores.items()]
scored_uuids.sort(key=lambda term: term[1])
Expand Down
35 changes: 24 additions & 11 deletions tests/graphiti_tests_int.py → tests/tests_int_graphiti.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

import pytest

from core.search.search import SearchConfig

pytestmark = pytest.mark.integration

import asyncio
Expand Down Expand Up @@ -53,14 +55,21 @@ def setup_logging():

def format_context(context):
    """Render a search context dict as a human-readable string for logging.

    Args:
        context: Dict with keys ``"episodes"``, ``"nodes"`` and ``"edges"``;
            nodes expose ``uuid``/``name``/``summary`` and edges expose
            ``fact`` attributes.

    Returns:
        A stripped multi-line string listing the entities and their facts.
    """
    formatted_string = ""
    # Read all three keys up front; episodes are not rendered yet, but the
    # lookup still validates the context shape.
    episodes = context["episodes"]  # noqa: F841 — currently unused
    nodes = context["nodes"]
    edges = context["edges"]

    # BUG FIX: this line was a bare string expression with no effect; the
    # "Entities:" header never made it into the output.
    formatted_string += "Entities:\n"
    for node in nodes:
        formatted_string += f"  UUID: {node.uuid}\n"
        formatted_string += f"  Name: {node.name}\n"
        formatted_string += f"  Summary: {node.summary}\n"

    formatted_string += "Facts:\n"
    for edge in edges:
        formatted_string += f"  - {edge.fact}\n"
    formatted_string += "\n"

    return formatted_string.strip()


Expand All @@ -69,15 +78,19 @@ async def test_graphiti_init():
logger = setup_logging()
graphiti = Graphiti(NEO4J_URI, NEO4j_USER, NEO4j_PASSWORD, None)

context = await graphiti.search("Freakenomics guest")
search_config = SearchConfig()

context = await graphiti.search("Freakenomics guest", datetime.now(), search_config)

logger.info("\nQUERY: Freakenomics guest" + "\nRESULT:\n" + format_context(context))

context = await graphiti.search("tania tetlow")
context = await graphiti.search("tania tetlow", datetime.now(), search_config)

logger.info("\nQUERY: Tania Tetlow" + "\nRESULT:\n" + format_context(context))

context = await graphiti.search("issues with higher ed")
context = await graphiti.search(
"issues with higher ed", datetime.now(), search_config
)

logger.info(
"\nQUERY: issues with higher ed" + "\nRESULT:\n" + format_context(context)
Expand Down

0 comments on commit 4696de1

Please sign in to comment.