search updates (#14)
* search updates

* test updates

* add opinionated search

* update
prasmussen15 committed Aug 22, 2024
1 parent 8141a78 commit 63b9790
Showing 9 changed files with 269 additions and 115 deletions.
123 changes: 35 additions & 88 deletions core/graphiti.py
@@ -1,15 +1,15 @@
 import asyncio
 from datetime import datetime
 import logging
-from typing import Callable, LiteralString
+from typing import Callable
 from neo4j import AsyncGraphDatabase
 from dotenv import load_dotenv
 from time import time
 import os

-from core.llm_client.config import EMBEDDING_DIM
-from core.nodes import EntityNode, EpisodicNode, Node
-from core.edges import EntityEdge, Edge, EpisodicEdge
+from core.nodes import EntityNode, EpisodicNode
+from core.edges import EntityEdge, EpisodicEdge
+from core.search.search import SearchConfig, hybrid_search
 from core.utils import (
     build_episodic_edges,
     retrieve_episodes,
@@ -19,22 +19,21 @@
     BulkEpisode,
     extract_nodes_and_edges_bulk,
     retrieve_previous_episodes_bulk,
-    compress_nodes,
     dedupe_nodes_bulk,
     resolve_edge_pointers,
     dedupe_edges_bulk,
 )
 from core.utils.maintenance.edge_operations import extract_edges, dedupe_extracted_edges
-from core.utils.maintenance.graph_data_operations import EPISODE_WINDOW_LEN
+from core.utils.maintenance.graph_data_operations import (
+    EPISODE_WINDOW_LEN,
+    build_indices_and_constraints,
+)
 from core.utils.maintenance.node_operations import dedupe_extracted_nodes, extract_nodes
 from core.utils.maintenance.temporal_operations import (
     invalidate_edges,
     prepare_edges_for_invalidation,
 )
-from core.utils.search.search_utils import (
-    edge_similarity_search,
-    entity_fulltext_search,
-    bfs,
+from core.search.search_utils import (
     get_relevant_nodes,
     get_relevant_edges,
 )
@@ -64,10 +63,13 @@ def __init__(
     def close(self):
         self.driver.close()

+    async def build_indices_and_constraints(self):
+        await build_indices_and_constraints(self.driver)
+
     async def retrieve_episodes(
         self,
         reference_time: datetime,
-        last_n: int,
+        last_n: int = EPISODE_WINDOW_LEN,
+        sources: list[str] | None = "messages",
     ) -> list[EpisodicNode]:
         """Retrieve the last n episodic nodes from the graph"""
@@ -103,9 +105,7 @@ async def add_episode(
         embedder = self.llm_client.client.embeddings
         now = datetime.now()

-        previous_episodes = await self.retrieve_episodes(
-            reference_time, last_n=EPISODE_WINDOW_LEN
-        )
+        previous_episodes = await self.retrieve_episodes(reference_time)
         episode = EpisodicNode(
             name=name,
             labels=[],
@@ -220,80 +220,6 @@ async def add_episode(
         else:
             raise e

-    async def build_indices(self):
-        index_queries: list[LiteralString] = [
-            "CREATE INDEX entity_uuid IF NOT EXISTS FOR (n:Entity) ON (n.uuid)",
-            "CREATE INDEX episode_uuid IF NOT EXISTS FOR (n:Episodic) ON (n.uuid)",
-            "CREATE INDEX relation_uuid IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.uuid)",
-            "CREATE INDEX mention_uuid IF NOT EXISTS FOR ()-[e:MENTIONS]-() ON (e.uuid)",
-            "CREATE INDEX name_entity_index IF NOT EXISTS FOR (n:Entity) ON (n.name)",
-            "CREATE INDEX created_at_entity_index IF NOT EXISTS FOR (n:Entity) ON (n.created_at)",
-            "CREATE INDEX created_at_episodic_index IF NOT EXISTS FOR (n:Episodic) ON (n.created_at)",
-            "CREATE INDEX valid_at_episodic_index IF NOT EXISTS FOR (n:Episodic) ON (n.valid_at)",
-            "CREATE INDEX name_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.name)",
-            "CREATE INDEX created_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.created_at)",
-            "CREATE INDEX expired_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.expired_at)",
-            "CREATE INDEX valid_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.valid_at)",
-            "CREATE INDEX invalid_at_edge_index IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON (e.invalid_at)",
-            "CREATE FULLTEXT INDEX name_and_summary IF NOT EXISTS FOR (n:Entity) ON EACH [n.name, n.summary]",
-            "CREATE FULLTEXT INDEX name_and_fact IF NOT EXISTS FOR ()-[e:RELATES_TO]-() ON EACH [e.name, e.fact]",
-            """
-            CREATE VECTOR INDEX fact_embedding IF NOT EXISTS
-            FOR ()-[r:RELATES_TO]-() ON (r.fact_embedding)
-            OPTIONS {indexConfig: {
-                `vector.dimensions`: 1024,
-                `vector.similarity_function`: 'cosine'
-            }}
-            """,
-            """
-            CREATE VECTOR INDEX name_embedding IF NOT EXISTS
-            FOR (n:Entity) ON (n.name_embedding)
-            OPTIONS {indexConfig: {
-                `vector.dimensions`: 1024,
-                `vector.similarity_function`: 'cosine'
-            }}
-            """,
-            """
-            CREATE CONSTRAINT entity_name IF NOT EXISTS
-            FOR (n:Entity) REQUIRE n.name IS UNIQUE
-            """,
-            """
-            CREATE CONSTRAINT edge_facts IF NOT EXISTS
-            FOR ()-[e:RELATES_TO]-() REQUIRE e.fact IS UNIQUE
-            """,
-        ]
-
-        await asyncio.gather(
-            *[self.driver.execute_query(query) for query in index_queries]
-        )
-
-    async def search(self, query: str) -> list[tuple[EntityNode, list[EntityEdge]]]:
-        text = query.replace("\n", " ")
-        search_vector = (
-            (
-                await self.llm_client.client.embeddings.create(
-                    input=[text], model="text-embedding-3-small"
-                )
-            )
-            .data[0]
-            .embedding[:EMBEDDING_DIM]
-        )
-
-        edges = await edge_similarity_search(search_vector, self.driver)
-        nodes = await entity_fulltext_search(query, self.driver)
-
-        node_ids = [node.uuid for node in nodes]
-
-        for edge in edges:
-            node_ids.append(edge.source_node_uuid)
-            node_ids.append(edge.target_node_uuid)
-
-        node_ids = list(dict.fromkeys(node_ids))
-
-        context = await bfs(node_ids, self.driver)
-
-        return context
-
     async def add_episode_bulk(
         self,
         bulk_episodes: list[BulkEpisode],
@@ -368,3 +294,24 @@ async def add_episode_bulk(

         except Exception as e:
             raise e
+
+    async def search(self, query: str, num_results=10):
+        search_config = SearchConfig(num_episodes=0, num_results=num_results)
+        edges = (
+            await hybrid_search(
+                self.driver,
+                self.llm_client.client.embeddings,
+                query,
+                datetime.now(),
+                search_config,
+            )
+        )["edges"]
+
+        facts = [edge.fact for edge in edges]
+
+        return facts
+
+    async def _search(self, query: str, timestamp: datetime, config: SearchConfig):
+        return await hybrid_search(
+            self.driver, self.llm_client.client.embeddings, query, timestamp, config
+        )
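
Taken together, the graphiti.py changes replace the old do-it-yourself search (embed the query, run similarity and fulltext searches, then BFS) with a single opinionated entry point. Below is a minimal usage sketch; the Graphiti constructor arguments and surrounding setup are assumptions for illustration and are not shown in this diff:

import asyncio

from core.graphiti import Graphiti


async def main():
    # Hypothetical constructor arguments (Neo4j URI, user, password);
    # the actual signature is not part of this diff.
    client = Graphiti("bolt://localhost:7687", "neo4j", "password")

    # New in this commit: index and constraint setup is delegated to
    # build_indices_and_constraints from graph_data_operations.
    await client.build_indices_and_constraints()

    # The opinionated search returns a plain list of edge facts (strings),
    # built internally with SearchConfig(num_episodes=0, num_results=num_results).
    facts = await client.search("Who founded the company?", num_results=5)
    for fact in facts:
        print(fact)

    client.close()


asyncio.run(main())

Callers who need episodes and nodes in the result, or a different reranker, can drop down to the private _search wrapper with an explicit SearchConfig and timestamp.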
2 changes: 1 addition & 1 deletion core/prompts/dedupe_nodes.py
@@ -112,7 +112,7 @@ def node_list(context: dict[str, any]) -> list[Message]:
 Task:
 1. Group nodes together such that all duplicate nodes are in the same list of names
-2. All dupolicate names should be grouped together in the same list
+2. All duplicate names should be grouped together in the same list
 Guidelines:
 1. Each name from the list of nodes should appear EXACTLY once in your response
File renamed without changes.
104 changes: 104 additions & 0 deletions core/search/search.py
@@ -0,0 +1,104 @@ (new file; all lines added)
import asyncio
import logging
from datetime import datetime
from time import time

from neo4j import AsyncDriver
from pydantic import BaseModel

from core.edges import EntityEdge, Edge
from core.llm_client.config import EMBEDDING_DIM
from core.nodes import Node
from core.search.search_utils import (
    edge_similarity_search,
    edge_fulltext_search,
    get_mentioned_nodes,
    rrf,
)
from core.utils import retrieve_episodes
from core.utils.maintenance.graph_data_operations import EPISODE_WINDOW_LEN

logger = logging.getLogger(__name__)


class SearchConfig(BaseModel):
    num_results: int = 10
    num_episodes: int = EPISODE_WINDOW_LEN
    similarity_search: str = "cosine"
    text_search: str = "BM25"
    reranker: str = "rrf"


async def hybrid_search(
    driver: AsyncDriver, embedder, query: str, timestamp: datetime, config: SearchConfig
) -> dict[str, [Node | Edge]]:
    start = time()

    episodes = []
    nodes = []
    edges = []

    search_results = []

    if config.num_episodes > 0:
        episodes.extend(await retrieve_episodes(driver, timestamp))
        nodes.extend(await get_mentioned_nodes(driver, episodes))

    if config.text_search == "BM25":
        text_search = await edge_fulltext_search(query, driver)
        search_results.append(text_search)

    if config.similarity_search == "cosine":
        query_text = query.replace("\n", " ")
        search_vector = (
            (await embedder.create(input=[query_text], model="text-embedding-3-small"))
            .data[0]
            .embedding[:EMBEDDING_DIM]
        )

        similarity_search = await edge_similarity_search(search_vector, driver)
        search_results.append(similarity_search)

    if len(search_results) == 1:
        edges = search_results[0]

    elif len(search_results) > 1 and not config.reranker == "rrf":
        logger.exception("Multiple searches enabled without a reranker")
        raise Exception("Multiple searches enabled without a reranker")

    elif config.reranker == "rrf":
        edge_uuid_map = {}
        search_result_uuids = []

        logger.info([[edge.fact for edge in result] for result in search_results])

        for result in search_results:
            result_uuids = []
            for edge in result:
                result_uuids.append(edge.uuid)
                edge_uuid_map[edge.uuid] = edge

            search_result_uuids.append(result_uuids)

        search_result_uuids = [
            [edge.uuid for edge in result] for result in search_results
        ]

        reranked_uuids = rrf(search_result_uuids)

        reranked_edges = [edge_uuid_map[uuid] for uuid in reranked_uuids]
        edges.extend(reranked_edges)

    context = {
        "episodes": episodes,
        "nodes": nodes,
        "edges": edges,
    }

    end = time()

    logger.info(
        f"search returned context for query {query} in {(end - start) * 1000} ms"
    )

    return context
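
The rrf helper imported from core.search.search_utils is not included in this diff. The name suggests reciprocal rank fusion, which combines several ranked lists by scoring each item with the sum of 1 / (k + rank) across the lists it appears in. A minimal sketch under that assumption follows; the constant k = 60 is the conventional default, and the repo's actual implementation may differ:

from collections import defaultdict


def rrf(results: list[list[str]], rank_const: int = 60) -> list[str]:
    # Score each UUID by its summed reciprocal rank across all result lists,
    # so items ranked highly by both BM25 and cosine search float to the top.
    scores: dict[str, float] = defaultdict(float)
    for result in results:
        for rank, uuid in enumerate(result, start=1):
            scores[uuid] += 1 / (rank_const + rank)

    # Highest combined score first.
    return [uuid for uuid, _ in sorted(scores.items(), key=lambda item: -item[1])]

With this shape, hybrid_search can fuse the BM25 and cosine result lists by UUID and then map the fused ordering back to EntityEdge objects through edge_uuid_map, which is what the elif config.reranker == "rrf" branch above does.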