add bulk temporal extraction and improve bulk quality and performance (#67)

* parallelize edge deduping more

* parallelize node insertion more

* improve bulk behavior performance

* dedupe nodes actually works

* add a reranker to search

* bulk dedupe episodes only across the same nodes

* add temporal extraction bulk function

* cleaned up bulk

* default to 4o

* format

* mypy

* mypy

* mypy ignore
prasmussen15 authored Aug 30, 2024
1 parent aac06d9 commit 35a4e51
Showing 8 changed files with 203 additions and 61 deletions.
21 changes: 13 additions & 8 deletions graphiti_core/graphiti.py
@@ -29,6 +29,7 @@
 from graphiti_core.nodes import EntityNode, EpisodeType, EpisodicNode
 from graphiti_core.search.search import Reranker, SearchConfig, SearchMethod, hybrid_search
 from graphiti_core.search.search_utils import (
+    RELEVANT_SCHEMA_LIMIT,
     get_relevant_edges,
     get_relevant_nodes,
     hybrid_node_search,
Expand All @@ -41,6 +42,7 @@
RawEpisode,
dedupe_edges_bulk,
dedupe_nodes_bulk,
extract_edge_dates_bulk,
extract_nodes_and_edges_bulk,
resolve_edge_pointers,
retrieve_previous_episodes_bulk,
@@ -319,26 +321,24 @@ async def add_episode_endpoint(episode_data: EpisodeData):
                 valid_at, invalid_at, _ = await extract_edge_dates(
                     self.llm_client,
                     edge,
-                    episode.valid_at,
+                    episode,
                     previous_episodes,
                 )
                 edge.valid_at = valid_at
                 edge.invalid_at = invalid_at
                 if edge.invalid_at:
-                    edge.expired_at = datetime.now()
+                    edge.expired_at = now
             for edge in existing_edges:
                 valid_at, invalid_at, _ = await extract_edge_dates(
                     self.llm_client,
                     edge,
-                    episode.valid_at,
+                    episode,
                     previous_episodes,
                 )
                 edge.valid_at = valid_at
                 edge.invalid_at = invalid_at
                 if edge.invalid_at:
-                    edge.expired_at = datetime.now()
+                    edge.expired_at = now
             (
                 old_edges_with_nodes_pending_invalidation,
                 new_edges_with_nodes,
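Both loops now stamp expirations with a single `now`, presumably captured once earlier in the method (outside this hunk), instead of calling datetime.now() per edge, so every edge invalidated in the same pass carries an identical expired_at. A minimal standalone sketch of that pattern, using a hypothetical Edge stand-in rather than the library's EntityEdge:

from datetime import datetime

class Edge:  # hypothetical stand-in for EntityEdge, for illustration only
    invalid_at: datetime | None = None
    expired_at: datetime | None = None

edges = [Edge(), Edge()]
for edge in edges:
    edge.invalid_at = datetime(2024, 8, 30)

# Capture the timestamp once per batch: edges expired in the same pass
# compare equal, and the value stays stable across slow LLM calls.
now = datetime.now()
for edge in edges:
    if edge.invalid_at:
        edge.expired_at = now

assert edges[0].expired_at == edges[1].expired_at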
@@ -481,15 +481,18 @@ async def add_episode_bulk(
                 *[edge.generate_embedding(embedder) for edge in extracted_edges],
             )
 
-            # Dedupe extracted nodes
-            nodes, uuid_map = await dedupe_nodes_bulk(self.driver, self.llm_client, extracted_nodes)
+            # Dedupe extracted nodes and extract edge dates
+            (nodes, uuid_map), extracted_edges_timestamped = await asyncio.gather(
+                dedupe_nodes_bulk(self.driver, self.llm_client, extracted_nodes),
+                extract_edge_dates_bulk(self.llm_client, extracted_edges, episode_pairs),
+            )
 
             # save nodes to KG
             await asyncio.gather(*[node.save(self.driver) for node in nodes])
 
             # re-map edge pointers so that they don't point to discarded dupe nodes
             extracted_edges_with_resolved_pointers: list[EntityEdge] = resolve_edge_pointers(
-                extracted_edges, uuid_map
+                extracted_edges_timestamped, uuid_map
             )
             episodic_edges_with_resolved_pointers: list[EpisodicEdge] = resolve_edge_pointers(
                 episodic_edges, uuid_map
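Node deduplication and bulk edge-date extraction do not depend on each other's output, so the bulk path now runs them concurrently. asyncio.gather returns results in argument order, which is what makes the positional tuple unpacking above safe. A minimal runnable sketch of the pattern, with placeholder coroutines standing in for the real bulk helpers:

import asyncio

async def dedupe_nodes(nodes: list[str]) -> tuple[list[str], dict[str, str]]:
    return nodes, {}  # placeholder for dedupe_nodes_bulk

async def date_edges(edges: list[str]) -> list[str]:
    return edges  # placeholder for extract_edge_dates_bulk

async def main() -> None:
    # gather preserves argument order, so results unpack positionally
    (nodes, uuid_map), edges = await asyncio.gather(
        dedupe_nodes(['alice', 'bob']),
        date_edges(['alice-KNOWS-bob']),
    )
    print(nodes, uuid_map, edges)

asyncio.run(main())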
@@ -579,7 +582,9 @@ async def _search(
             self.driver, self.llm_client.get_embedder(), query, timestamp, config, center_node_uuid
         )
 
-    async def get_nodes_by_query(self, query: str, limit: int | None = None) -> list[EntityNode]:
+    async def get_nodes_by_query(
+        self, query: str, limit: int = RELEVANT_SCHEMA_LIMIT
+    ) -> list[EntityNode]:
         """
         Retrieve nodes from the graph database based on a text query.
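With the new signature, callers can omit limit and get RELEVANT_SCHEMA_LIMIT rather than None. A hypothetical usage sketch; the import path and connection details are assumptions, not part of this diff:

import asyncio
from graphiti_core import Graphiti

async def demo() -> None:
    # Assumed local Neo4j credentials for illustration.
    graphiti = Graphiti('bolt://localhost:7687', 'neo4j', 'password')
    # limit now defaults to RELEVANT_SCHEMA_LIMIT instead of None:
    nodes = await graphiti.get_nodes_by_query('Who is Tania?')
    capped = await graphiti.get_nodes_by_query('Who is Tania?', limit=5)
    print(len(nodes), len(capped))

asyncio.run(demo())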
12 changes: 10 additions & 2 deletions graphiti_core/prompts/dedupe_nodes.py
@@ -53,7 +53,9 @@ def v1(context: dict[str, Any]) -> list[Message]:
        1. start with the list of nodes from New Nodes
        2. If any node in New Nodes is a duplicate of a node in Existing Nodes, replace the new node with the existing
          node in the list
-       3. Respond with the resulting list of nodes
+       3. when deduplicating nodes, synthesize their summaries into a short new summary that contains the relevant information
+         of the summaries of the new and existing nodes
+       4. Respond with the resulting list of nodes
 
        Guidelines:
        1. Use both the name and summary of nodes to determine if they are duplicates,
@@ -64,6 +66,7 @@ def v1(context: dict[str, Any]) -> list[Message]:
            "new_nodes": [
                {{
                    "name": "Unique identifier for the node",
+                   "summary": "Brief summary of the node's role or significance"
                }}
            ]
        }}
@@ -92,6 +95,8 @@ def v2(context: dict[str, Any]) -> list[Message]:
        If a node in the new nodes is describing the same entity as a node in the existing nodes, mark it as a duplicate!!!
 
        Task:
        If any node in New Nodes is a duplicate of a node in Existing Nodes, add their names to the output list
+       When finding duplicate nodes, synthesize their summaries into a short new summary that contains the
+       relevant information of the summaries of the new and existing nodes.
 
        Guidelines:
        1. Use both the name and summary of nodes to determine if they are duplicates,
@@ -104,7 +109,8 @@ def v2(context: dict[str, Any]) -> list[Message]:
            "duplicates": [
                {{
                    "name": "name of the new node",
-                   "duplicate_of": "name of the existing node"
+                   "duplicate_of": "name of the existing node",
+                   "summary": "Brief summary of the node's role or significance. Takes information from the new and existing nodes"
                }}
            ]
        }}
@@ -130,6 +136,7 @@ def node_list(context: dict[str, Any]) -> list[Message]:
        Task:
        1. Group nodes together such that all duplicate nodes are in the same list of names
        2. All duplicate names should be grouped together in the same list
+       3. Also return a new summary that synthesizes the grouped nodes' summaries into one short summary
 
        Guidelines:
        1. Each name from the list of nodes should appear EXACTLY once in your response
@@ -140,6 +147,7 @@ def node_list(context: dict[str, Any]) -> list[Message]:
            "nodes": [
                {{
                    "names": ["myNode", "node that is a duplicate of myNode"],
+                   "summary": "Brief summary of the node summaries that appear in the list of names."
                }}
            ]
        }}
3 changes: 2 additions & 1 deletion graphiti_core/prompts/extract_edges.py
@@ -110,10 +110,11 @@ def v2(context: dict[str, Any]) -> list[Message]:
        Guidelines:
        1. Create edges only between the provided nodes.
-       2. Each edge should represent a clear relationship between two nodes.
+       2. Each edge should represent a clear relationship between two DISTINCT nodes.
        3. The relation_type should be a concise, all-caps description of the relationship (e.g., LOVES, IS_FRIENDS_WITH, WORKS_FOR).
        4. Provide a more detailed fact describing the relationship.
        5. Consider temporal aspects of relationships when relevant.
+       6. Avoid using the same node as the source and target of a relationship.
 
        Respond with a JSON object in the following format:
        {{
12 changes: 6 additions & 6 deletions graphiti_core/search/search.py
@@ -63,12 +63,12 @@ class SearchResults(BaseModel):
 
 
 async def hybrid_search(
-        driver: AsyncDriver,
-        embedder,
-        query: str,
-        timestamp: datetime,
-        config: SearchConfig,
-        center_node_uuid: str | None = None,
+    driver: AsyncDriver,
+    embedder,
+    query: str,
+    timestamp: datetime,
+    config: SearchConfig,
+    center_node_uuid: str | None = None,
 ) -> SearchResults:
     start = time()
32 changes: 15 additions & 17 deletions graphiti_core/search/search_utils.py
@@ -268,13 +268,13 @@ async def hybrid_node_search(
     queries: list[str],
     embeddings: list[list[float]],
     driver: AsyncDriver,
-    limit: int | None = None,
+    limit: int = RELEVANT_SCHEMA_LIMIT,
 ) -> list[EntityNode]:
     """
     Perform a hybrid search for nodes using both text queries and embeddings.
 
     This method combines fulltext search and vector similarity search to find
-    relevant nodes in the graph database.
+    relevant nodes in the graph database. It uses an RRF reranker.
 
     Parameters
     ----------
@@ -307,27 +307,25 @@ async def hybrid_node_search(
     """
 
     start = time()
-    relevant_nodes: list[EntityNode] = []
-    relevant_node_uuids = set()
 
-    results = await asyncio.gather(
-        *[entity_fulltext_search(q, driver, 2 * (limit or RELEVANT_SCHEMA_LIMIT)) for q in queries],
-        *[
-            entity_similarity_search(e, driver, 2 * (limit or RELEVANT_SCHEMA_LIMIT))
-            for e in embeddings
-        ],
+    results: list[list[EntityNode]] = list(
+        await asyncio.gather(
+            *[entity_fulltext_search(q, driver, 2 * limit) for q in queries],
+            *[entity_similarity_search(e, driver, 2 * limit) for e in embeddings],
+        )
     )
 
-    for result in results:
-        for node in result:
-            if node.uuid in relevant_node_uuids:
-                continue
+    node_uuid_map: dict[str, EntityNode] = {
+        node.uuid: node for result in results for node in result
+    }
+    result_uuids = [[node.uuid for node in result] for result in results]
 
-            relevant_node_uuids.add(node.uuid)
-            relevant_nodes.append(node)
+    ranked_uuids = rrf(result_uuids)
+
+    relevant_nodes: list[EntityNode] = [node_uuid_map[uuid] for uuid in ranked_uuids]
 
     end = time()
-    logger.info(f'Found relevant nodes: {relevant_node_uuids} in {(end - start) * 1000} ms')
+    logger.info(f'Found relevant nodes: {ranked_uuids} in {(end - start) * 1000} ms')
     return relevant_nodes
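The rrf function itself is not part of this diff. A standard reciprocal rank fusion sketch consistent with the call above; the rank_const parameter and its default are assumptions, not the library's actual values:

from collections import defaultdict

def rrf(results: list[list[str]], rank_const: int = 1) -> list[str]:
    # Fuse multiple ranked uuid lists: each appearance of a uuid contributes
    # 1 / (rank_const + rank), and uuids are returned by descending fused score,
    # so items ranked highly by several searches rise to the top.
    scores: dict[str, float] = defaultdict(float)
    for result in results:
        for rank, uuid in enumerate(result):
            scores[uuid] += 1.0 / (rank + rank_const)
    return [uuid for uuid, _ in sorted(scores.items(), key=lambda item: item[1], reverse=True)]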


(The remaining 3 changed files are collapsed and not shown.)
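Among the collapsed files is the definition of extract_edge_dates_bulk. Judging only from its call site in add_episode_bulk above, a plausible sketch fans the existing per-edge extract_edge_dates helper (seen in the first graphiti.py hunk) out over the whole batch; the zip pairing of edges to (episode, previous_episodes) tuples below is an assumption, not the commit's actual implementation:

import asyncio
from datetime import datetime

async def extract_edge_dates_bulk(llm_client, extracted_edges, episode_pairs):
    # Hypothetical sketch only: the signature matches the call site above,
    # but how each edge is matched to its episode context is assumed.
    now = datetime.now()

    async def date_edge(edge, episode, previous_episodes):
        valid_at, invalid_at, _ = await extract_edge_dates(
            llm_client, edge, episode, previous_episodes
        )
        edge.valid_at = valid_at
        edge.invalid_at = invalid_at
        if edge.invalid_at:
            edge.expired_at = now  # one shared timestamp per batch
        return edge

    return await asyncio.gather(
        *[
            date_edge(edge, episode, previous_episodes)
            for edge, (episode, previous_episodes) in zip(extracted_edges, episode_pairs)
        ]
    )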
