search updates #14

Merged
4 commits merged on Aug 22, 2024
Changes from 2 commits
107 changes: 19 additions & 88 deletions core/graphiti.py
@@ -1,15 +1,15 @@
import asyncio
from datetime import datetime
import logging
from typing import Callable, LiteralString
from typing import Callable
from neo4j import AsyncGraphDatabase
from dotenv import load_dotenv
from time import time
import os

from core.llm_client.config import EMBEDDING_DIM
from core.nodes import EntityNode, EpisodicNode, Node
from core.edges import EntityEdge, Edge, EpisodicEdge
from core.nodes import EntityNode, EpisodicNode
from core.edges import EntityEdge, EpisodicEdge
from core.search.search import search, SearchConfig
from core.utils import (
build_episodic_edges,
retrieve_episodes,
@@ -19,22 +19,21 @@
BulkEpisode,
extract_nodes_and_edges_bulk,
retrieve_previous_episodes_bulk,
compress_nodes,
dedupe_nodes_bulk,
resolve_edge_pointers,
dedupe_edges_bulk,
)
from core.utils.maintenance.edge_operations import extract_edges, dedupe_extracted_edges
from core.utils.maintenance.graph_data_operations import EPISODE_WINDOW_LEN
from core.utils.maintenance.graph_data_operations import (
EPISODE_WINDOW_LEN,
build_indices_and_constraints,
)
from core.utils.maintenance.node_operations import dedupe_extracted_nodes, extract_nodes
from core.utils.maintenance.temporal_operations import (
invalidate_edges,
prepare_edges_for_invalidation,
)
from core.utils.search.search_utils import (
edge_similarity_search,
entity_fulltext_search,
bfs,
from core.search.search_utils import (
get_relevant_nodes,
get_relevant_edges,
)
@@ -64,10 +63,13 @@ def __init__(
def close(self):
self.driver.close()

async def build_indices_and_constraints(self):
await build_indices_and_constraints(self.driver)

async def retrieve_episodes(
self,
reference_time: datetime,
last_n: int,
last_n: int = EPISODE_WINDOW_LEN,
sources: list[str] | str | None = "messages",
) -> list[EpisodicNode]:
"""Retrieve the last n episodic nodes from the graph"""
@@ -103,9 +105,7 @@ async def add_episode(
embedder = self.llm_client.client.embeddings
now = datetime.now()

previous_episodes = await self.retrieve_episodes(
reference_time, last_n=EPISODE_WINDOW_LEN
)
previous_episodes = await self.retrieve_episodes(reference_time)
episode = EpisodicNode(
name=name,
labels=[],
@@ -220,80 +220,6 @@ async def add_episode
else:
raise e

async def build_indices(self):
index_queries: list[LiteralString] = [
"CREATE INDEX entity_uuid IF NOT EXISTS FOR (n:Entity) ON (n.uuid)",
"CREATE INDEX episode_uuid IF NOT EXISTS FOR (n:Episodic) ON (n.uuid)",
"CREATE INDEX relation_uuid IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.uuid)",
"CREATE INDEX mention_uuid IF NOT EXISTS FOR ()-[e:MENTIONS]-() ON (e.uuid)",
"CREATE INDEX name_entity_index IF NOT EXISTS FOR (n:Entity) ON (n.name)",
"CREATE INDEX created_at_entity_index IF NOT EXISTS FOR (n:Entity) ON (n.created_at)",
"CREATE INDEX created_at_episodic_index IF NOT EXISTS FOR (n:Episodic) ON (n.created_at)",
"CREATE INDEX valid_at_episodic_index IF NOT EXISTS FOR (n:Episodic) ON (n.valid_at)",
"CREATE INDEX name_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.name)",
"CREATE INDEX created_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.created_at)",
"CREATE INDEX expired_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.expired_at)",
"CREATE INDEX valid_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.valid_at)",
"CREATE INDEX invalid_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.invalid_at)",
"CREATE FULLTEXT INDEX name_and_summary IF NOT EXISTS FOR (n:Entity) ON EACH [n.name, n.summary]",
"CREATE FULLTEXT INDEX name_and_fact IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON EACH [e.name, e.fact]",
"""
CREATE VECTOR INDEX fact_embedding IF NOT EXISTS
FOR ()-[r:RELATES_TO]-() ON (r.fact_embedding)
OPTIONS {indexConfig: {
`vector.dimensions`: 1024,
`vector.similarity_function`: 'cosine'
}}
""",
"""
CREATE VECTOR INDEX name_embedding IF NOT EXISTS
FOR (n:Entity) ON (n.name_embedding)
OPTIONS {indexConfig: {
`vector.dimensions`: 1024,
`vector.similarity_function`: 'cosine'
}}
""",
"""
CREATE CONSTRAINT entity_name IF NOT EXISTS
FOR (n:Entity) REQUIRE n.name IS UNIQUE
""",
"""
CREATE CONSTRAINT edge_facts IF NOT EXISTS
FOR ()-[e:RELATES_TO]-() REQUIRE e.fact IS UNIQUE
""",
]

await asyncio.gather(
*[self.driver.execute_query(query) for query in index_queries]
)

async def search(self, query: str) -> list[tuple[EntityNode, list[EntityEdge]]]:
text = query.replace("\n", " ")
search_vector = (
(
await self.llm_client.client.embeddings.create(
input=[text], model="text-embedding-3-small"
)
)
.data[0]
.embedding[:EMBEDDING_DIM]
)

edges = await edge_similarity_search(search_vector, self.driver)
nodes = await entity_fulltext_search(query, self.driver)

node_ids = [node.uuid for node in nodes]

for edge in edges:
node_ids.append(edge.source_node_uuid)
node_ids.append(edge.target_node_uuid)

node_ids = list(dict.fromkeys(node_ids))

context = await bfs(node_ids, self.driver)

return context

async def add_episode_bulk(
self,
bulk_episodes: list[BulkEpisode],
@@ -368,3 +294,8 @@ async def add_episode_bulk(

except Exception as e:
raise e

async def search(self, query: str, timestamp: datetime, config: SearchConfig):
return await search(
self.driver, self.llm_client.client.embeddings, query, timestamp, config
)
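For reviewers, a minimal sketch of the call pattern after this diff. The constructor arguments and LLM-client wiring are assumptions for illustration; only build_indices_and_constraints and the new search signature come from this PR:

import asyncio
from datetime import datetime, timezone

from core.graphiti import Graphiti
from core.search.search import SearchConfig


async def main():
    # Assumed constructor arguments; the actual driver/LLM wiring is
    # outside this diff.
    client = Graphiti("bolt://localhost:7687", "neo4j", "password")

    # New in this PR: index and constraint setup is an explicit call that
    # delegates to graph_data_operations.build_indices_and_constraints.
    await client.build_indices_and_constraints()

    # The old single-argument search() is removed above; the new method
    # takes a reference timestamp and a SearchConfig and returns a
    # context dict of episodes, nodes, and edges.
    context = await client.search(
        "What does Alice work on?", datetime.now(timezone.utc), SearchConfig()
    )
    print(context["edges"])


asyncio.run(main())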
2 changes: 1 addition & 1 deletion core/prompts/dedupe_nodes.py
@@ -112,7 +112,7 @@ def node_list(context: dict[str, any]) -> list[Message]:

Task:
1. Group nodes together such that all duplicate nodes are in the same list of names
2. All dupolicate names should be grouped together in the same list
2. All duplicate names should be grouped together in the same list

Guidelines:
1. Each name from the list of nodes should appear EXACTLY once in your response
File renamed without changes.
102 changes: 102 additions & 0 deletions core/search/search.py
@@ -0,0 +1,102 @@
import asyncio
import logging
from datetime import datetime
from time import time

from neo4j import AsyncDriver
from pydantic import BaseModel

from core.edges import EntityEdge, Edge
from core.llm_client.config import EMBEDDING_DIM
from core.nodes import Node
from core.search.search_utils import (
edge_similarity_search,
edge_fulltext_search,
get_mentioned_nodes,
rrf,
)
from core.utils import retrieve_episodes
from core.utils.maintenance.graph_data_operations import EPISODE_WINDOW_LEN

logger = logging.getLogger(__name__)


class SearchConfig(BaseModel):
num_results: int = 10
num_episodes: int = EPISODE_WINDOW_LEN
similarity_search: str = "cosine"
text_search: str = "BM25"
reranker: str = "rrf"


async def search(
driver: AsyncDriver, embedder, query: str, timestamp: datetime, config: SearchConfig
) -> dict[str, list[Node | Edge]]:
start = time()

episodes = []
nodes = []
edges = []

search_results = []

if config.num_episodes > 0:
episodes.extend(await retrieve_episodes(driver, timestamp))
nodes.extend(await get_mentioned_nodes(driver, episodes))

if config.text_search:
text_search = await edge_fulltext_search(query, driver)
search_results.append(text_search)

if config.similarity_search:
query_text = query.replace("\n", " ")
search_vector = (
(await embedder.create(input=[query_text], model="text-embedding-3-small"))
.data[0]
.embedding[:EMBEDDING_DIM]
)

similarity_search = await edge_similarity_search(search_vector, driver)
search_results.append(similarity_search)

if len(search_results) == 1:
edges = search_results[0]

elif len(search_results) > 1 and not config.reranker:
logger.exception("Multiple searches enabled without a reranker")
raise Exception("Multiple searches enabled without a reranker")

elif config.reranker:
edge_uuid_map = {}
search_result_uuids = []

for result in search_results:
result_uuids = []
for edge in result:
result_uuids.append(edge.uuid)
edge_uuid_map[edge.uuid] = edge

search_result_uuids.append(result_uuids)

reranked_uuids = rrf(search_result_uuids)

reranked_edges = [edge_uuid_map[uuid] for uuid in reranked_uuids]
edges.extend(reranked_edges)

context = {
"episodes": episodes,
"nodes": nodes,
"edges": edges,
}

end = time()

logger.info(
f"search returned context for query {query} in {(end - start) * 1000} ms"
)

return context
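A quick sketch of how the config fields steer the branches above. The string fields act as on/off flags via truthiness, so an empty string disables a method; the driver and embedder are assumed to be set up elsewhere:

from datetime import datetime, timezone

from core.search.search import search, SearchConfig

# Fulltext only: with a single search method enabled, its results are
# returned directly and the reranker branch is never reached.
fulltext_only = SearchConfig(similarity_search="", num_episodes=0)

# Defaults: BM25 fulltext plus cosine similarity search, fused with RRF.
# Enabling both with an empty reranker raises, per the guard above.
hybrid = SearchConfig()

# Hypothetical call site (driver: neo4j.AsyncDriver, embedder: an
# OpenAI-style embeddings client, as this module assumes):
# context = await search(driver, embedder, "acme corp", datetime.now(timezone.utc), hybrid)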
58 changes: 52 additions & 6 deletions core/utils/search/search_utils.py → core/search/search_utils.py
@@ -1,23 +1,54 @@
import asyncio
import logging
from collections import defaultdict
from datetime import datetime
from time import time

from neo4j import AsyncDriver

from core.edges import EntityEdge
from core.nodes import EntityNode
from core.nodes import EntityNode, EpisodicNode

logger = logging.getLogger(__name__)

RELEVANT_SCHEMA_LIMIT = 3


async def get_mentioned_nodes(driver: AsyncDriver, episodes: list[EpisodicNode]):
episode_uuids = [episode.uuid for episode in episodes]
records, _, _ = await driver.execute_query(
"""
MATCH (episode:Episodic)-[:MENTIONS]->(n:Entity) WHERE episode.uuid IN $uuids
RETURN DISTINCT
n.uuid As uuid,
n.name AS name,
n.created_at AS created_at,
n.summary AS summary
""",
uuids=episode_uuids,
)

nodes: list[EntityNode] = []

for record in records:
nodes.append(
EntityNode(
uuid=record["uuid"],
name=record["name"],
labels=["Entity"],
created_at=datetime.now(),
summary=record["summary"],
)
)

return nodes


async def bfs(node_ids: list[str], driver: AsyncDriver):
records, _, _ = await driver.execute_query(
"""
MATCH (n WHERE n.uuid in $node_ids)-[r]->(m)
RETURN
RETURN DISTINCT
n.uuid AS source_node_uuid,
n.name AS source_name,
n.summary AS source_summary,
@@ -138,7 +169,7 @@ async def entity_similarity_search(
EntityNode(
uuid=record["uuid"],
name=record["name"],
labels=[],
labels=["Entity"],
created_at=datetime.now(),
summary=record["summary"],
)
@@ -155,7 +186,7 @@ async def entity_fulltext_search(
records, _, _ = await driver.execute_query(
"""
CALL db.index.fulltext.queryNodes("name_and_summary", $query) YIELD node, score
RETURN
RETURN
node.uuid As uuid,
node.name AS name,
node.created_at AS created_at,
@@ -173,7 +204,7 @@ async def entity_fulltext_search(
EntityNode(
uuid=record["uuid"],
name=record["name"],
labels=[],
labels=["Entity"],
created_at=datetime.now(),
summary=record["summary"],
)
Expand All @@ -193,7 +224,7 @@ async def edge_fulltext_search(
CALL db.index.fulltext.queryRelationships("name_and_fact", $query)
YIELD relationship AS r, score
MATCH (n:Entity)-[r]->(m:Entity)
RETURN
RETURN
r.uuid AS uuid,
n.uuid AS source_node_uuid,
m.uuid AS target_node_uuid,
@@ -291,3 +322,18 @@ async def get_relevant_edges(
)

return relevant_edges


# Takes in a list of rankings of uuids and fuses them with reciprocal rank fusion.
def rrf(results: list[list[str]], rank_const=1) -> list[str]:
    scores: dict[str, float] = defaultdict(float)
    for result in results:
        for i, uuid in enumerate(result):
            scores[uuid] += 1 / (i + rank_const)

    scored_uuids = list(scores.items())
    # Highest fused score first; Python's sort is stable, so ties keep
    # their insertion order.
    scored_uuids.sort(reverse=True, key=lambda term: term[1])

    sorted_uuids = [term[0] for term in scored_uuids]

    return sorted_uuids
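A worked example of the fusion with toy uuids: with rank_const=1, an item ranked first contributes 1/1 and an item ranked second contributes 1/2, so appearing near the top of several rankings beats topping just one. Output order assumes the best-first (descending) sort above:

from core.search.search_utils import rrf

fulltext_ranking = ["e1", "e2", "e3"]    # best first
similarity_ranking = ["e2", "e1", "e4"]  # best first

# Fused scores:
#   e1: 1/1 + 1/2 = 1.5    e2: 1/2 + 1/1 = 1.5
#   e3: 1/3 ~= 0.33        e4: 1/3 ~= 0.33
# Ties keep insertion order because Python's sort is stable.
print(rrf([fulltext_ranking, similarity_ranking]))
# -> ['e1', 'e2', 'e3', 'e4']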