From f8643f14023a271933d6e3fa5a97c4fc89e92344 Mon Sep 17 00:00:00 2001
From: prestonrasmussen
Date: Thu, 22 Aug 2024 11:26:21 -0400
Subject: [PATCH] search updates

---
 core/graphiti.py                              | 109 ++++--------------
 core/prompts/dedupe_nodes.py                  |   2 +-
 core/{utils => }/search/__init__.py           |   0
 core/search/search.py                         |  77 +++++++++++++
 core/{utils => }/search/search_utils.py       |  62 ++++++++--
 core/utils/bulk_utils.py                      |   3 +-
 .../maintenance/graph_data_operations.py      |  62 +++++++++-
 podcast_runner.py                             |   2 +-
 tests/graphiti_tests_int.py                   |   9 +-
 9 files changed, 221 insertions(+), 105 deletions(-)
 rename core/{utils => }/search/__init__.py (100%)
 create mode 100644 core/search/search.py
 rename core/{utils => }/search/search_utils.py (84%)

diff --git a/core/graphiti.py b/core/graphiti.py
index 14f5130b..e9967159 100644
--- a/core/graphiti.py
+++ b/core/graphiti.py
@@ -1,15 +1,15 @@
 import asyncio
 from datetime import datetime
 import logging
-from typing import Callable, LiteralString
+from typing import Callable
 
 from neo4j import AsyncGraphDatabase
 from dotenv import load_dotenv
 from time import time
 import os
-from core.llm_client.config import EMBEDDING_DIM
-from core.nodes import EntityNode, EpisodicNode, Node
-from core.edges import EntityEdge, Edge, EpisodicEdge
+from core.nodes import EntityNode, EpisodicNode
+from core.edges import EntityEdge, EpisodicEdge
+from core.search.search import search, SearchConfig
 from core.utils import (
     build_episodic_edges,
     retrieve_episodes,
@@ -19,22 +19,21 @@
     BulkEpisode,
     extract_nodes_and_edges_bulk,
     retrieve_previous_episodes_bulk,
-    compress_nodes,
     dedupe_nodes_bulk,
     resolve_edge_pointers,
     dedupe_edges_bulk,
 )
 from core.utils.maintenance.edge_operations import extract_edges, dedupe_extracted_edges
-from core.utils.maintenance.graph_data_operations import EPISODE_WINDOW_LEN
+from core.utils.maintenance.graph_data_operations import (
+    EPISODE_WINDOW_LEN,
+    build_indices_and_constraints,
+)
 from core.utils.maintenance.node_operations import dedupe_extracted_nodes, extract_nodes
 from core.utils.maintenance.temporal_operations import (
     invalidate_edges,
     prepare_edges_for_invalidation,
 )
-from core.utils.search.search_utils import (
-    edge_similarity_search,
-    entity_fulltext_search,
-    bfs,
+from core.search.search_utils import (
     get_relevant_nodes,
     get_relevant_edges,
 )
@@ -64,10 +63,13 @@ def __init__(
     def close(self):
         self.driver.close()
 
+    async def build_indices_and_constraints(self):
+        await build_indices_and_constraints(self.driver)
+
     async def retrieve_episodes(
         self,
         reference_time: datetime,
-        last_n: int,
+        last_n: int = EPISODE_WINDOW_LEN,
         sources: list[str] | None = "messages",
     ) -> list[EpisodicNode]:
         """Retrieve the last n episodic nodes from the graph"""
@@ -103,9 +105,7 @@ async def add_episode(
             embedder = self.llm_client.client.embeddings
             now = datetime.now()
 
-            previous_episodes = await self.retrieve_episodes(
-                reference_time, last_n=EPISODE_WINDOW_LEN
-            )
+            previous_episodes = await self.retrieve_episodes(reference_time)
             episode = EpisodicNode(
                 name=name,
                 labels=[],
@@ -220,80 +220,6 @@ async def add_episode(
         else:
             raise e
 
-    async def build_indices(self):
-        index_queries: list[LiteralString] = [
-            "CREATE INDEX entity_uuid IF NOT EXISTS FOR (n:Entity) ON (n.uuid)",
-            "CREATE INDEX episode_uuid IF NOT EXISTS FOR (n:Episodic) ON (n.uuid)",
-            "CREATE INDEX relation_uuid IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.uuid)",
-            "CREATE INDEX mention_uuid IF NOT EXISTS FOR ()-[e:MENTIONS]-() ON (e.uuid)",
-            "CREATE INDEX name_entity_index IF NOT EXISTS FOR (n:Entity) ON (n.name)",
-            "CREATE INDEX created_at_entity_index IF NOT EXISTS FOR (n:Entity) ON (n.created_at)",
-            "CREATE INDEX created_at_episodic_index IF NOT EXISTS FOR (n:Episodic) ON (n.created_at)",
-            "CREATE INDEX valid_at_episodic_index IF NOT EXISTS FOR (n:Episodic) ON (n.valid_at)",
-            "CREATE INDEX name_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.name)",
-            "CREATE INDEX created_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.created_at)",
-            "CREATE INDEX expired_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.expired_at)",
-            "CREATE INDEX valid_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.valid_at)",
-            "CREATE INDEX invalid_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.invalid_at)",
-            "CREATE FULLTEXT INDEX name_and_summary IF NOT EXISTS FOR (n:Entity) ON EACH [n.name, n.summary]",
-            "CREATE FULLTEXT INDEX name_and_fact IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON EACH [e.name, e.fact]",
-            """
-            CREATE VECTOR INDEX fact_embedding IF NOT EXISTS
-            FOR ()-[r:RELATES_TO]-() ON (r.fact_embedding)
-            OPTIONS {indexConfig: {
-                `vector.dimensions`: 1024,
-                `vector.similarity_function`: 'cosine'
-            }}
-            """,
-            """
-            CREATE VECTOR INDEX name_embedding IF NOT EXISTS
-            FOR (n:Entity) ON (n.name_embedding)
-            OPTIONS {indexConfig: {
-                `vector.dimensions`: 1024,
-                `vector.similarity_function`: 'cosine'
-            }}
-            """,
-            """
-            CREATE CONSTRAINT entity_name IF NOT EXISTS
-            FOR (n:Entity) REQUIRE n.name IS UNIQUE
-            """,
-            """
-            CREATE CONSTRAINT edge_facts IF NOT EXISTS
-            FOR ()-[e:RELATES_TO]-() REQUIRE e.fact IS UNIQUE
-            """,
-        ]
-
-        await asyncio.gather(
-            *[self.driver.execute_query(query) for query in index_queries]
-        )
-
-    async def search(self, query: str) -> list[tuple[EntityNode, list[EntityEdge]]]:
-        text = query.replace("\n", " ")
-        search_vector = (
-            (
-                await self.llm_client.client.embeddings.create(
-                    input=[text], model="text-embedding-3-small"
-                )
-            )
-            .data[0]
-            .embedding[:EMBEDDING_DIM]
-        )
-
-        edges = await edge_similarity_search(search_vector, self.driver)
-        nodes = await entity_fulltext_search(query, self.driver)
-
-        node_ids = [node.uuid for node in nodes]
-
-        for edge in edges:
-            node_ids.append(edge.source_node_uuid)
-            node_ids.append(edge.target_node_uuid)
-
-        node_ids = list(dict.fromkeys(node_ids))
-
-        context = await bfs(node_ids, self.driver)
-
-        return context
-
     async def add_episode_bulk(
         self,
         bulk_episodes: list[BulkEpisode],
@@ -368,3 +294,10 @@ async def add_episode_bulk(
 
         except Exception as e:
             raise e
+
+    async def search(
+        self,
+        query: str,
+        timestamp: datetime | None = None,
+        config: SearchConfig | None = None,
+    ) -> dict[str, list]:
+        # defaults keep the one-argument call style used by the integration tests
+        return await search(
+            self.driver,
+            self.llm_client.client.embeddings,
+            query,
+            timestamp or datetime.now(),
+            config or SearchConfig(),
+        )
diff --git a/core/prompts/dedupe_nodes.py b/core/prompts/dedupe_nodes.py
index 3f54ef39..f1c89209 100644
--- a/core/prompts/dedupe_nodes.py
+++ b/core/prompts/dedupe_nodes.py
@@ -112,7 +112,7 @@ def node_list(context: dict[str, any]) -> list[Message]:
 
         Task:
         1. Group nodes together such that all duplicate nodes are in the same list of names
-        2. All dupolicate names should be grouped together in the same list
+        2. All duplicate names should be grouped together in the same list
 
         Guidelines:
         1. Each name from the list of nodes should appear EXACTLY once in your response
diff --git a/core/utils/search/__init__.py b/core/search/__init__.py
similarity index 100%
rename from core/utils/search/__init__.py
rename to core/search/__init__.py
diff --git a/core/search/search.py b/core/search/search.py
new file mode 100644
index 00000000..c272010d
--- /dev/null
+++ b/core/search/search.py
@@ -0,0 +1,77 @@
+import logging
+from datetime import datetime
+
+from neo4j import AsyncDriver
+from pydantic import BaseModel
+
+from core.edges import EntityEdge
+from core.llm_client.config import EMBEDDING_DIM
+from core.search.search_utils import (
+    edge_similarity_search,
+    edge_fulltext_search,
+    get_mentioned_nodes,
+    rrf,
+)
+from core.utils import retrieve_episodes
+from core.utils.maintenance.graph_data_operations import EPISODE_WINDOW_LEN
+
+logger = logging.getLogger(__name__)
+
+
+class SearchConfig(BaseModel):
+    num_results: int = 10
+    num_episodes: int = EPISODE_WINDOW_LEN
+    similarity_search: str = "cosine"
+    text_search: str = "BM25"
+    reranker: str = "rrf"
+
+
+async def search(
+    driver: AsyncDriver, embedder, query: str, timestamp: datetime, config: SearchConfig
+) -> dict[str, list]:
+    episodes = []
+    nodes = []
+    edges: list[EntityEdge] = []
+
+    search_results = []
+
+    if config.num_episodes > 0:
+        episodes.extend(await retrieve_episodes(driver, timestamp, config.num_episodes))
+        nodes.extend(await get_mentioned_nodes(driver, episodes))
+
+    if config.text_search:
+        text_search = await edge_fulltext_search(query, driver)
+        search_results.append(text_search)
+
+    if config.similarity_search:
+        query_text = query.replace("\n", " ")
+        search_vector = (
+            (await embedder.create(input=[query_text], model="text-embedding-3-small"))
+            .data[0]
+            .embedding[:EMBEDDING_DIM]
+        )
+
+        similarity_search = await edge_similarity_search(search_vector, driver)
+        search_results.append(similarity_search)
+
+    if len(search_results) == 1:
+        edges = search_results[0]
+
+    elif len(search_results) > 1 and not config.reranker:
+        logger.error("Multiple searches enabled without a reranker")
+        raise ValueError("Multiple searches enabled without a reranker")
+
+    elif config.reranker:
+        # fuse the rankings with RRF, then map the fused uuid ranking back to edges
+        search_result_uuids = [
+            [edge.uuid for edge in result] for result in search_results
+        ]
+        edge_uuid_map = {
+            edge.uuid: edge for result in search_results for edge in result
+        }
+        edges.extend(edge_uuid_map[uuid] for uuid in rrf(search_result_uuids))
+
+    context = {
+        "episodes": episodes,
+        "nodes": nodes,
+        "edges": edges[: config.num_results],
+    }
+
+    return context
diff --git a/core/utils/search/search_utils.py b/core/search/search_utils.py
similarity index 84%
rename from core/utils/search/search_utils.py
rename to core/search/search_utils.py
index 110b7a21..174e2d7e 100644
--- a/core/utils/search/search_utils.py
+++ b/core/search/search_utils.py
@@ -1,23 +1,54 @@
 import asyncio
 import logging
+from collections import defaultdict
 from datetime import datetime
 from time import time
 
 from neo4j import AsyncDriver
 
 from core.edges import EntityEdge
-from core.nodes import EntityNode
+from core.nodes import EntityNode, EpisodicNode
 
 logger = logging.getLogger(__name__)
 
 RELEVANT_SCHEMA_LIMIT = 3
 
 
+async def get_mentioned_nodes(driver: AsyncDriver, episodes: list[EpisodicNode]):
+    episode_uuids = [episode.uuid for episode in episodes]
+    records, _, _ = await driver.execute_query(
+        """
+        MATCH (episode:Episodic)-[:MENTIONS]->(n:Entity) WHERE episode.uuid IN $uuids
+        RETURN DISTINCT
+            n.uuid AS uuid,
+            n.name AS name,
+            n.created_at AS created_at,
+            n.summary AS summary
+        """,
+        uuids=episode_uuids,
+    )
+
+    nodes: list[EntityNode] = []
+
+    for record in records:
+        nodes.append(
+            EntityNode(
+                uuid=record["uuid"],
+                name=record["name"],
+                labels=["Entity"],
+                created_at=datetime.now(),
+                summary=record["summary"],
+            )
+        )
+
+    return nodes
+
+
 async def bfs(node_ids: list[str], driver: AsyncDriver):
     records, _, _ = await driver.execute_query(
         """
         MATCH (n WHERE n.uuid in $node_ids)-[r]->(m)
-        RETURN
+        RETURN DISTINCT
             n.uuid AS source_node_uuid,
             n.name AS source_name,
             n.summary AS source_summary,
@@ -71,7 +102,7 @@ async def edge_similarity_search(
         CALL db.index.vector.queryRelationships("fact_embedding", 5, $search_vector)
         YIELD relationship AS r, score
         MATCH (n)-[r:RELATES_TO]->(m)
-        RETURN
+        RETURN DISTINCT
             r.uuid AS uuid,
             n.uuid AS source_node_uuid,
             m.uuid AS target_node_uuid,
@@ -121,7 +152,7 @@ async def entity_similarity_search(
         """
         CALL db.index.vector.queryNodes("name_embedding", $limit, $search_vector)
         YIELD node AS n, score
-        RETURN
+        RETURN DISTINCT
             n.uuid As uuid,
             n.name AS name,
             n.created_at AS created_at,
@@ -138,7 +169,7 @@ async def entity_similarity_search(
             EntityNode(
                 uuid=record["uuid"],
                 name=record["name"],
-                labels=[],
+                labels=["Entity"],
                 created_at=datetime.now(),
                 summary=record["summary"],
             )
@@ -155,7 +186,7 @@ async def entity_fulltext_search(
     records, _, _ = await driver.execute_query(
         """
         CALL db.index.fulltext.queryNodes("name_and_summary", $query) YIELD node, score
-        RETURN
+        RETURN DISTINCT
             node.uuid As uuid,
             node.name AS name,
             node.created_at AS created_at,
@@ -173,7 +204,7 @@ async def entity_fulltext_search(
             EntityNode(
                 uuid=record["uuid"],
                 name=record["name"],
-                labels=[],
+                labels=["Entity"],
                 created_at=datetime.now(),
                 summary=record["summary"],
             )
@@ -193,7 +224,7 @@ async def edge_fulltext_search(
         """
         CALL db.index.fulltext.queryRelationships("name_and_fact", $query) YIELD relationship AS r, score
         MATCH (n:Entity)-[r]->(m:Entity)
-        RETURN
+        RETURN DISTINCT
             r.uuid AS uuid,
             n.uuid AS source_node_uuid,
             m.uuid AS target_node_uuid,
@@ -291,3 +322,18 @@ async def get_relevant_edges(
     )
 
     return relevant_edges
+
+
+# Reciprocal Rank Fusion: fuse several uuid rankings into one, best matches first
+def rrf(results: list[list[str]], rank_const: int = 1) -> list[str]:
+    scores: dict[str, float] = defaultdict(float)
+    for result in results:
+        for i, uuid in enumerate(result):
+            scores[uuid] += 1 / (i + rank_const)
+
+    scored_uuids = [term for term in scores.items()]
+    scored_uuids.sort(reverse=True, key=lambda term: term[1])
+
+    sorted_uuids = [term[0] for term in scored_uuids]
+
+    return sorted_uuids
diff --git a/core/utils/bulk_utils.py b/core/utils/bulk_utils.py
index a5b361ef..7004fe5d 100644
--- a/core/utils/bulk_utils.py
+++ b/core/utils/bulk_utils.py
@@ -1,5 +1,4 @@
 import asyncio
-from collections import defaultdict
 from datetime import datetime
 
 from neo4j import AsyncDriver
@@ -21,7 +20,7 @@
     dedupe_node_list,
     dedupe_extracted_nodes,
 )
-from core.utils.search.search_utils import get_relevant_nodes, get_relevant_edges
+from core.search.search_utils import get_relevant_nodes, get_relevant_edges
 
 CHUNK_SIZE = 10
diff --git a/core/utils/maintenance/graph_data_operations.py b/core/utils/maintenance/graph_data_operations.py
index 790400f3..221c9b44 100644
--- a/core/utils/maintenance/graph_data_operations.py
+++ b/core/utils/maintenance/graph_data_operations.py
@@ -1,4 +1,6 @@
+import asyncio
 from datetime import datetime, timezone
+from typing import LiteralString
 
 from core.nodes import EpisodicNode
 from neo4j import AsyncDriver
@@ -9,6 +11,64 @@
 logger = logging.getLogger(__name__)
 
 
+async def build_indices_and_constraints(driver: AsyncDriver):
+    constraints: list[LiteralString] = [
+        """
+        CREATE CONSTRAINT entity_name IF NOT EXISTS
+        FOR (n:Entity) REQUIRE n.name IS UNIQUE
+        """,
+        """
+        CREATE CONSTRAINT edge_facts IF NOT EXISTS
+        FOR ()-[e:RELATES_TO]-() REQUIRE e.fact IS UNIQUE
+        """,
+    ]
+
+    range_indices: list[LiteralString] = [
+        "CREATE INDEX entity_uuid IF NOT EXISTS FOR (n:Entity) ON (n.uuid)",
+        "CREATE INDEX episode_uuid IF NOT EXISTS FOR (n:Episodic) ON (n.uuid)",
+        "CREATE INDEX relation_uuid IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.uuid)",
+        "CREATE INDEX mention_uuid IF NOT EXISTS FOR ()-[e:MENTIONS]-() ON (e.uuid)",
+        "CREATE INDEX name_entity_index IF NOT EXISTS FOR (n:Entity) ON (n.name)",
+        "CREATE INDEX created_at_entity_index IF NOT EXISTS FOR (n:Entity) ON (n.created_at)",
+        "CREATE INDEX created_at_episodic_index IF NOT EXISTS FOR (n:Episodic) ON (n.created_at)",
+        "CREATE INDEX valid_at_episodic_index IF NOT EXISTS FOR (n:Episodic) ON (n.valid_at)",
+        "CREATE INDEX name_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.name)",
+        "CREATE INDEX created_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.created_at)",
+        "CREATE INDEX expired_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.expired_at)",
+        "CREATE INDEX valid_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.valid_at)",
+        "CREATE INDEX invalid_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.invalid_at)",
+    ]
+
+    fulltext_indices: list[LiteralString] = [
+        "CREATE FULLTEXT INDEX name_and_summary IF NOT EXISTS FOR (n:Entity) ON EACH [n.name, n.summary]",
+        "CREATE FULLTEXT INDEX name_and_fact IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON EACH [e.name, e.fact]",
+    ]
+
+    vector_indices: list[LiteralString] = [
+        """
+        CREATE VECTOR INDEX fact_embedding IF NOT EXISTS
+        FOR ()-[r:RELATES_TO]-() ON (r.fact_embedding)
+        OPTIONS {indexConfig: {
+            `vector.dimensions`: 1024,
+            `vector.similarity_function`: 'cosine'
+        }}
+        """,
+        """
+        CREATE VECTOR INDEX name_embedding IF NOT EXISTS
+        FOR (n:Entity) ON (n.name_embedding)
+        OPTIONS {indexConfig: {
+            `vector.dimensions`: 1024,
+            `vector.similarity_function`: 'cosine'
+        }}
+        """,
+    ]
+    index_queries: list[LiteralString] = (
+        constraints + range_indices + fulltext_indices + vector_indices
+    )
+
+    await asyncio.gather(*[driver.execute_query(query) for query in index_queries])
+
+
 async def clear_data(driver: AsyncDriver):
     async with driver.session() as session:
 
@@ -21,7 +81,7 @@ async def delete_all(tx):
 async def retrieve_episodes(
     driver: AsyncDriver,
     reference_time: datetime,
-    last_n: int,
+    last_n: int = EPISODE_WINDOW_LEN,
     sources: list[str] | None = "messages",
 ) -> list[EpisodicNode]:
     """Retrieve the last n episodic nodes from the graph"""
diff --git a/podcast_runner.py b/podcast_runner.py
index 2fff8285..7953e5a0 100644
--- a/podcast_runner.py
+++ b/podcast_runner.py
@@ -61,7 +61,7 @@ async def main(use_bulk: bool = True):
             episode_type="string",
             reference_time=message.actual_timestamp,
         )
-        for i, message in enumerate(messages[3:7])
+        for i, message in enumerate(messages[3:14])
     ]
 
     await client.add_episode_bulk(episodes)
diff --git a/tests/graphiti_tests_int.py b/tests/graphiti_tests_int.py
index 310586b2..f40438d9 100644
--- a/tests/graphiti_tests_int.py
+++ b/tests/graphiti_tests_int.py
@@ -68,19 +68,20 @@ def format_context(context):
 async def test_graphiti_init():
     logger = setup_logging()
     graphiti = Graphiti(NEO4J_URI, NEO4j_USER, NEO4j_PASSWORD, None)
-    await graphiti.build_indices()
 
     context = await graphiti.search("Freakenomics guest")
 
-    logger.info("QUERY: Freakenomics guest" + "RESULT:" + format_context(context))
+    logger.info("\nQUERY: Freakenomics guest" + "\nRESULT:\n" + format_context(context))
 
     context = await graphiti.search("tania tetlow")
 
-    logger.info("QUERY: Tania Tetlow" + "RESULT:" + format_context(context))
+    logger.info("\nQUERY: Tania Tetlow" + "\nRESULT:\n" + format_context(context))
 
     context = await graphiti.search("issues with higher ed")
 
-    logger.info("QUERY: issues with higher ed" + "RESULT:" + format_context(context))
+    logger.info(
+        "\nQUERY: issues with higher ed" + "\nRESULT:\n" + format_context(context)
+    )
 
     graphiti.close()
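--
Usage sketch (not part of the applied diff): a minimal driver for the reworked
search surface. The Neo4j credentials below are placeholders, and it assumes
the Graphiti constructor falls back to a default LLM client when passed None,
as the integration test above implies; SearchConfig defaults come straight
from core/search/search.py.

    import asyncio
    from datetime import datetime

    from core.graphiti import Graphiti
    from core.search.search import SearchConfig

    async def demo():
        # placeholder credentials; point these at your own Neo4j instance
        client = Graphiti("bolt://localhost:7687", "neo4j", "password", None)
        await client.build_indices_and_constraints()

        # defaults: BM25 text search plus cosine similarity over edges,
        # fused with RRF; the timestamp anchors the episode window
        context = await client.search("tania tetlow", datetime.now(), SearchConfig())
        for edge in context["edges"]:
            print(edge.fact)

        client.close()

    asyncio.run(demo())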
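The rrf fusion can be checked by hand. With the default rank_const=1 and the
zero-based ranks produced by enumerate, each uuid scores the sum of
1 / (rank + rank_const) over every ranking it appears in; a hypothetical
three-uuid example:

    from core.search.search_utils import rrf

    # "b" sits near the top of both rankings, so it wins the fusion:
    #   a: 1/1 + 1/3 = 1.33,  b: 1/2 + 1/1 = 1.50,  c: 1/3 + 1/2 = 0.83
    assert rrf([["a", "b", "c"], ["b", "c", "a"]]) == ["b", "a", "c"]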