Speed up add episode (#77)
* WIP

* updates

* use uuid for node dedupe

* pre-testing

* parallelized node resolution

* working add_episode

* revert to 4o

* format

* mypy update

* update types
prasmussen15 authored Sep 3, 2024
1 parent db12ac5 commit e9e6039
Showing 12 changed files with 427 additions and 177 deletions.
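The headline change is concurrency: per-entity graph lookups and per-edge date extraction now fan out through asyncio.gather instead of running sequentially. A minimal sketch of the pattern, with a hypothetical lookup stub standing in for the library's get_relevant_nodes call:

import asyncio

# Hypothetical stand-in for a per-node graph lookup such as
# get_relevant_nodes([node], driver); the sleep simulates a DB round trip.
async def fetch_candidates(node: str) -> list[str]:
    await asyncio.sleep(0.1)
    return [f'existing-match-for-{node}']

async def resolve_concurrently(extracted: list[str]) -> list[list[str]]:
    # asyncio.gather runs the lookups concurrently and returns results
    # in input order, so each candidate list stays aligned with its node.
    return list(await asyncio.gather(*(fetch_candidates(n) for n in extracted)))

if __name__ == '__main__':
    lists = asyncio.run(resolve_concurrently(['Alice', 'Bob', 'Acme Corp']))
    print(lists)  # three lookups finish in roughly 0.1 s rather than 0.3 s

The same shape appears three times in the diff: node resolution, edge retrieval, and edge date extraction.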
5 changes: 3 additions & 2 deletions examples/podcast/podcast_runner.py
@@ -70,6 +70,7 @@ async def main(use_bulk: bool = True):
reference_time=message.actual_timestamp,
source_description='Podcast Transcript',
)
return

episodes: list[RawEpisode] = [
RawEpisode(
@@ -79,10 +80,10 @@ async def main(use_bulk: bool = True):
source_description='Podcast Transcript',
reference_time=message.actual_timestamp,
)
for i, message in enumerate(messages[3:14])
for i, message in enumerate(messages[3:20])
]

await client.add_episode_bulk(episodes)


asyncio.run(main(True))
asyncio.run(main(False))
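The runner now exercises the sequential path (main(False)) over a larger slice of messages. A self-contained sketch of the toggle, with a toy client standing in for the real Graphiti instance (only the two method names mirror the diff; everything else is assumed):

import asyncio

class FakeGraphiti:
    # Toy stand-in so the sketch runs on its own; the real client's
    # add_episode takes more parameters, as shown in graphiti.py below.
    async def add_episode(self, name: str, episode_body: str) -> None:
        print(f'add_episode: {name}')

    async def add_episode_bulk(self, episodes: list[str]) -> None:
        print(f'add_episode_bulk: {len(episodes)} episodes')

async def main(use_bulk: bool = True) -> None:
    client = FakeGraphiti()
    messages = [f'message {i}' for i in range(20)]
    if not use_bulk:
        # Sequential path exercised by asyncio.run(main(False)): every
        # episode runs through the full add_episode pipeline one at a time.
        for i, body in enumerate(messages[3:20]):
            await client.add_episode(name=f'Message {i}', episode_body=body)
        return
    await client.add_episode_bulk(messages[3:20])

asyncio.run(main(False))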
173 changes: 94 additions & 79 deletions graphiti_core/graphiti.py
@@ -48,14 +48,17 @@
retrieve_previous_episodes_bulk,
)
from graphiti_core.utils.maintenance.edge_operations import (
dedupe_extracted_edges,
extract_edges,
resolve_extracted_edges,
)
from graphiti_core.utils.maintenance.graph_data_operations import (
EPISODE_WINDOW_LEN,
build_indices_and_constraints,
)
from graphiti_core.utils.maintenance.node_operations import dedupe_extracted_nodes, extract_nodes
from graphiti_core.utils.maintenance.node_operations import (
extract_nodes,
resolve_extracted_nodes,
)
from graphiti_core.utils.maintenance.temporal_operations import (
extract_edge_dates,
invalidate_edges,
@@ -177,9 +180,9 @@ async def build_indices_and_constraints(self):
await build_indices_and_constraints(self.driver)

async def retrieve_episodes(
self,
reference_time: datetime,
last_n: int = EPISODE_WINDOW_LEN,
self,
reference_time: datetime,
last_n: int = EPISODE_WINDOW_LEN,
) -> list[EpisodicNode]:
"""
Retrieve the last n episodic nodes from the graph.
@@ -207,14 +210,14 @@ async def retrieve_episodes
return await retrieve_episodes(self.driver, reference_time, last_n)

async def add_episode(
self,
name: str,
episode_body: str,
source_description: str,
reference_time: datetime,
source: EpisodeType = EpisodeType.message,
success_callback: Callable | None = None,
error_callback: Callable | None = None,
self,
name: str,
episode_body: str,
source_description: str,
reference_time: datetime,
source: EpisodeType = EpisodeType.message,
success_callback: Callable | None = None,
error_callback: Callable | None = None,
):
"""
Process an episode and update the graph.
@@ -265,7 +268,6 @@ async def add_episode_endpoint(episode_data: EpisodeData):

nodes: list[EntityNode] = []
entity_edges: list[EntityEdge] = []
episodic_edges: list[EpisodicEdge] = []
embedder = self.llm_client.get_embedder()
now = datetime.now()

@@ -280,6 +282,8 @@ async def add_episode_endpoint(episode_data: EpisodeData):
valid_at=reference_time,
)

# Extract entities as nodes

extracted_nodes = await extract_nodes(self.llm_client, episode, previous_episodes)
logger.info(f'Extracted nodes: {[(n.name, n.uuid) for n in extracted_nodes]}')

@@ -288,57 +292,82 @@ async def add_episode_endpoint(episode_data: EpisodeData):
await asyncio.gather(
*[node.generate_name_embedding(embedder) for node in extracted_nodes]
)
existing_nodes = await get_relevant_nodes(extracted_nodes, self.driver)

# Resolve extracted nodes with nodes already in the graph
existing_nodes_lists: list[list[EntityNode]] = list(
await asyncio.gather(
*[get_relevant_nodes([node], self.driver) for node in extracted_nodes]
)
)

logger.info(f'Extracted nodes: {[(n.name, n.uuid) for n in extracted_nodes]}')
touched_nodes, _, brand_new_nodes = await dedupe_extracted_nodes(
self.llm_client, extracted_nodes, existing_nodes

mentioned_nodes, _ = await resolve_extracted_nodes(
self.llm_client, extracted_nodes, existing_nodes_lists
)
logger.info(f'Adjusted touched nodes: {[(n.name, n.uuid) for n in touched_nodes]}')
nodes.extend(touched_nodes)
logger.info(f'Adjusted mentioned nodes: {[(n.name, n.uuid) for n in mentioned_nodes]}')
nodes.extend(mentioned_nodes)

# Extract facts as edges given entity nodes
extracted_edges = await extract_edges(
self.llm_client, episode, touched_nodes, previous_episodes
self.llm_client, episode, mentioned_nodes, previous_episodes
)

# calculate embeddings
await asyncio.gather(*[edge.generate_embedding(embedder) for edge in extracted_edges])

existing_edges = await get_relevant_edges(extracted_edges, self.driver)
logger.info(f'Existing edges: {[(e.name, e.uuid) for e in existing_edges]}')
# Resolve extracted edges with edges already in the graph
existing_edges_list: list[list[EntityEdge]] = list(
await asyncio.gather(
*[
get_relevant_edges(
[edge],
self.driver,
RELEVANT_SCHEMA_LIMIT,
edge.source_node_uuid,
edge.target_node_uuid,
)
for edge in extracted_edges
]
)
)
logger.info(
f'Existing edges lists: {[(e.name, e.uuid) for edges_lst in existing_edges_list for e in edges_lst]}'
)
logger.info(f'Extracted edges: {[(e.name, e.uuid) for e in extracted_edges]}')

deduped_edges = await dedupe_extracted_edges(
self.llm_client,
extracted_edges,
existing_edges,
deduped_edges: list[EntityEdge] = await resolve_extracted_edges(
self.llm_client, extracted_edges, existing_edges_list
)

edge_touched_node_uuids = [n.uuid for n in brand_new_nodes]
for edge in deduped_edges:
edge_touched_node_uuids.append(edge.source_node_uuid)
edge_touched_node_uuids.append(edge.target_node_uuid)

for edge in deduped_edges:
valid_at, invalid_at, _ = await extract_edge_dates(
self.llm_client,
edge,
episode,
previous_episodes,
)
edge.valid_at = valid_at
edge.invalid_at = invalid_at
if edge.invalid_at:
edge.expired_at = now
for edge in existing_edges:
valid_at, invalid_at, _ = await extract_edge_dates(
self.llm_client,
edge,
episode,
previous_episodes,
)
# Extract dates for the newly extracted edges
edge_dates = await asyncio.gather(
*[
extract_edge_dates(
self.llm_client,
edge,
episode,
previous_episodes,
)
for edge in deduped_edges
]
)

for i, edge in enumerate(deduped_edges):
valid_at = edge_dates[i][0]
invalid_at = edge_dates[i][1]

edge.valid_at = valid_at
edge.invalid_at = invalid_at
if edge.invalid_at:
if edge.invalid_at is not None:
edge.expired_at = now

entity_edges.extend(deduped_edges)

existing_edges: list[EntityEdge] = [
e for edge_lst in existing_edges_list for e in edge_lst
]

(
old_edges_with_nodes_pending_invalidation,
new_edges_with_nodes,
@@ -361,30 +390,18 @@ async def add_episode_endpoint(episode_data: EpisodeData):
for deduped_edge in deduped_edges:
if deduped_edge.uuid == edge.uuid:
deduped_edge.expired_at = edge.expired_at
edge_touched_node_uuids.append(edge.source_node_uuid)
edge_touched_node_uuids.append(edge.target_node_uuid)
logger.info(f'Invalidated edges: {[(e.name, e.uuid) for e in invalidated_edges]}')

edges_to_save = existing_edges + deduped_edges

entity_edges.extend(edges_to_save)

edge_touched_node_uuids = list(set(edge_touched_node_uuids))
involved_nodes = [node for node in nodes if node.uuid in edge_touched_node_uuids]

logger.info(f'Edge touched nodes: {[(n.name, n.uuid) for n in involved_nodes]}')
entity_edges.extend(existing_edges)

logger.info(f'Deduped edges: {[(e.name, e.uuid) for e in deduped_edges]}')

episodic_edges.extend(
build_episodic_edges(
# There may be an overlap between new_nodes and affected_nodes, so we're deduplicating them
involved_nodes,
episode,
now,
)
episodic_edges: list[EpisodicEdge] = build_episodic_edges(
mentioned_nodes,
episode,
now,
)
# Important to append the episode to the nodes at the end so that self referencing episodic edges are not built

logger.info(f'Built episodic edges: {episodic_edges}')

# Future optimization would be using batch operations to save nodes and edges
Expand All @@ -395,9 +412,7 @@ async def add_episode_endpoint(episode_data: EpisodeData):

end = time()
logger.info(f'Completed add_episode in {(end - start) * 1000} ms')
# for node in nodes:
# if isinstance(node, EntityNode):
# await node.update_summary(self.driver)

if success_callback:
await success_callback(episode)
except Exception as e:
@@ -407,8 +422,8 @@ async def add_episode_endpoint(episode_data: EpisodeData):
raise e

async def add_episode_bulk(
self,
bulk_episodes: list[RawEpisode],
self,
bulk_episodes: list[RawEpisode],
):
"""
Process multiple episodes in bulk and update the graph.
@@ -572,18 +587,18 @@ async def search(self, query: str, center_node_uuid: str | None = None, num_resu
return edges

async def _search(
self,
query: str,
timestamp: datetime,
config: SearchConfig,
center_node_uuid: str | None = None,
self,
query: str,
timestamp: datetime,
config: SearchConfig,
center_node_uuid: str | None = None,
):
return await hybrid_search(
self.driver, self.llm_client.get_embedder(), query, timestamp, config, center_node_uuid
)

async def get_nodes_by_query(
self, query: str, limit: int = RELEVANT_SCHEMA_LIMIT
self, query: str, limit: int = RELEVANT_SCHEMA_LIMIT
) -> list[EntityNode]:
"""
Retrieve nodes from the graph database based on a text query.
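The date-extraction rewrite above replaces two sequential per-edge loops with one gather followed by an indexed loop. A runnable sketch of that order-preserving pattern, assuming a stub extractor in place of the LLM-backed extract_edge_dates:

import asyncio
from dataclasses import dataclass
from datetime import datetime

@dataclass
class Edge:
    fact: str
    valid_at: datetime | None = None
    invalid_at: datetime | None = None
    expired_at: datetime | None = None

# Hypothetical extractor; the real one prompts an LLM once per edge.
async def extract_edge_dates(edge: Edge) -> tuple[datetime | None, datetime | None]:
    await asyncio.sleep(0.05)  # simulates one LLM round trip
    return datetime(2024, 1, 1), None

async def date_edges(edges: list[Edge]) -> None:
    # One concurrent batch instead of awaiting inside a for-loop; gather
    # preserves input order, so edge_dates[i] belongs to edges[i].
    edge_dates = await asyncio.gather(*(extract_edge_dates(e) for e in edges))
    now = datetime.now()
    for i, edge in enumerate(edges):
        edge.valid_at, edge.invalid_at = edge_dates[i]
        if edge.invalid_at is not None:
            edge.expired_at = now  # an already-invalid fact expires immediately

edges = [Edge('Alice works at Acme'), Edge('Bob moved to Paris')]
asyncio.run(date_edges(edges))
print(edges)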
54 changes: 46 additions & 8 deletions graphiti_core/prompts/dedupe_edges.py
@@ -23,12 +23,14 @@
class Prompt(Protocol):
v1: PromptVersion
v2: PromptVersion
v3: PromptVersion
edge_list: PromptVersion


class Versions(TypedDict):
v1: PromptFunction
v2: PromptFunction
v3: PromptFunction
edge_list: PromptFunction


@@ -41,17 +43,17 @@ def v1(context: dict[str, Any]) -> list[Message]:
Message(
role='user',
content=f"""
Given the following context, deduplicate facts from a list of new facts given a list of existing facts:
Given the following context, deduplicate facts from a list of new facts given a list of existing edges:
Existing Facts:
Existing Edges:
{json.dumps(context['existing_edges'], indent=2)}
New Facts:
New Edges:
{json.dumps(context['extracted_edges'], indent=2)}
Task:
If any facts in New Facts is a duplicate of a fact in Existing Facts,
do not return it in the list of unique facts.
If any edge in New Edges is a duplicate of an edge in Existing Edges, add their uuids to the output list.
When finding duplicate edges, synthesize their facts into a short new fact.
Guidelines:
1. identical or near identical facts are duplicates
@@ -60,9 +62,11 @@ def v1(context: dict[str, Any]) -> list[Message]:
Respond with a JSON object in the following format:
{{
"unique_facts": [
"duplicates": [
{{
"uuid": "unique identifier of the fact"
"uuid": "uuid of the new node like 5d643020624c42fa9de13f97b1b3fa39",
"duplicate_of": "uuid of the existing node",
"fact": "one sentence description of the fact"
}}
]
}}
@@ -113,6 +117,40 @@ def v2(context: dict[str, Any]) -> list[Message]:
]


def v3(context: dict[str, Any]) -> list[Message]:
return [
Message(
role='system',
content='You are a helpful assistant that de-duplicates edges from edge lists.',
),
Message(
role='user',
content=f"""
Given the following context, determine whether the New Edge represents any of the edges in the list of Existing Edges.
Existing Edges:
{json.dumps(context['existing_edges'], indent=2)}
New Edge:
{json.dumps(context['extracted_edges'], indent=2)}
Task:
1. If the New Edge represents the same factual information as any edge in Existing Edges, return 'is_duplicate: true' in the
response. Otherwise, return 'is_duplicate: false'
2. If is_duplicate is true, also return the uuid of the existing edge in the response
Guidelines:
1. The facts do not need to be completely identical to be duplicates, they just need to express the same information.
Respond with a JSON object in the following format:
{{
"is_duplicate": true or false,
"uuid": uuid of the existing edge like "5d643020624c42fa9de13f97b1b3fa39" or null,
}}
""",
),
]


def edge_list(context: dict[str, Any]) -> list[Message]:
return [
Message(
Expand Down Expand Up @@ -151,4 +189,4 @@ def edge_list(context: dict[str, Any]) -> list[Message]:
]


versions: Versions = {'v1': v1, 'v2': v2, 'edge_list': edge_list}
versions: Versions = {'v1': v1, 'v2': v2, 'v3': v3, 'edge_list': edge_list}
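For a sense of how a caller might act on the new v3 response shape, here is a small sketch; the helper and its inputs are illustrative assumptions, not code from this commit:

import json

def apply_dedupe_decision(
    llm_output: str, new_edge: dict, existing_edges: dict[str, dict]
) -> dict:
    # v3 returns {'is_duplicate': bool, 'uuid': str | None}; on a duplicate
    # we keep the existing edge instead of inserting a near-identical one.
    decision = json.loads(llm_output)
    if decision['is_duplicate'] and decision.get('uuid'):
        return existing_edges[decision['uuid']]
    return new_edge

existing = {'5d643020624c42fa9de13f97b1b3fa39': {'fact': 'Alice works at Acme'}}
response = '{"is_duplicate": true, "uuid": "5d643020624c42fa9de13f97b1b3fa39"}'
print(apply_dedupe_decision(response, {'fact': 'Alice is employed by Acme'}, existing))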