Invalidation updates && improvements (#20)
* wip

* wip

* wip

* fix: Linter errors

* fix formatting

* chore: fix ruff

* fix: Duplication

---------

Co-authored-by: Daniel Chalef <131175+danielchalef@users.noreply.github.com>
paul-paliychuk and danielchalef authored Aug 22, 2024
1 parent 94873f1 commit 1f1652f
Showing 9 changed files with 505 additions and 52 deletions.
65 changes: 43 additions & 22 deletions core/graphiti.py
@@ -28,7 +28,10 @@
resolve_edge_pointers,
retrieve_previous_episodes_bulk,
)
from core.utils.maintenance.edge_operations import dedupe_extracted_edges, extract_edges
from core.utils.maintenance.edge_operations import (
dedupe_extracted_edges,
extract_edges,
)
from core.utils.maintenance.graph_data_operations import (
EPISODE_WINDOW_LEN,
build_indices_and_constraints,
@@ -116,6 +119,7 @@ async def add_episode(
)

extracted_nodes = await extract_nodes(self.llm_client, episode, previous_episodes)
logger.info(f'Extracted nodes: {[(n.name, n.uuid) for n in extracted_nodes]}')

# Calculate Embeddings

@@ -124,14 +128,14 @@ async def add_episode(
)
existing_nodes = await get_relevant_nodes(extracted_nodes, self.driver)
logger.info(f'Extracted nodes: {[(n.name, n.uuid) for n in extracted_nodes]}')
new_nodes, _ = await dedupe_extracted_nodes(
touched_nodes, _, brand_new_nodes = await dedupe_extracted_nodes(
self.llm_client, extracted_nodes, existing_nodes
)
logger.info(f'Deduped touched nodes: {[(n.name, n.uuid) for n in new_nodes]}')
nodes.extend(new_nodes)
logger.info(f'Adjusted touched nodes: {[(n.name, n.uuid) for n in touched_nodes]}')
nodes.extend(touched_nodes)

extracted_edges = await extract_edges(
self.llm_client, episode, new_nodes, previous_episodes
self.llm_client, episode, touched_nodes, previous_episodes
)

await asyncio.gather(*[edge.generate_embedding(embedder) for edge in extracted_edges])
@@ -140,10 +144,23 @@ async def add_episode(
logger.info(f'Existing edges: {[(e.name, e.uuid) for e in existing_edges]}')
logger.info(f'Extracted edges: {[(e.name, e.uuid) for e in extracted_edges]}')

# deduped_edges = await dedupe_extracted_edges_v2(
# self.llm_client,
# extract_node_and_edge_triplets(extracted_edges, nodes),
# extract_node_and_edge_triplets(existing_edges, nodes),
# )

deduped_edges = await dedupe_extracted_edges(
self.llm_client, extracted_edges, existing_edges
self.llm_client,
extracted_edges,
existing_edges,
)

edge_touched_node_uuids = [n.uuid for n in brand_new_nodes]
for edge in deduped_edges:
edge_touched_node_uuids.append(edge.source_node_uuid)
edge_touched_node_uuids.append(edge.target_node_uuid)

(
old_edges_with_nodes_pending_invalidation,
new_edges_with_nodes,
@@ -155,39 +172,43 @@ async def add_episode(
self.llm_client,
old_edges_with_nodes_pending_invalidation,
new_edges_with_nodes,
episode,
previous_episodes,
)

entity_edges.extend(invalidated_edges)
for edge in invalidated_edges:
edge_touched_node_uuids.append(edge.source_node_uuid)
edge_touched_node_uuids.append(edge.target_node_uuid)

logger.info(f'Invalidated edges: {[(e.name, e.uuid) for e in invalidated_edges]}')
edges_to_save = invalidated_edges

logger.info(f'Deduped edges: {[(e.name, e.uuid) for e in deduped_edges]}')
entity_edges.extend(deduped_edges)
# There may be an overlap between deduped and invalidated edges, so we want to make sure to save the invalidated one
for deduped_edge in deduped_edges:
if deduped_edge.uuid not in [edge.uuid for edge in invalidated_edges]:
edges_to_save.append(deduped_edge)

new_edges = await dedupe_extracted_edges(
self.llm_client, extracted_edges, existing_edges
)
entity_edges.extend(edges_to_save)

edge_touched_node_uuids = list(set(edge_touched_node_uuids))
involved_nodes = [node for node in nodes if node.uuid in edge_touched_node_uuids]

logger.info(f'Deduped edges: {[(e.name, e.uuid) for e in new_edges]}')
logger.info(f'Edge touched nodes: {[(n.name, n.uuid) for n in involved_nodes]}')

logger.info(f'Invalidated edges: {[(e.name, e.uuid) for e in invalidated_edges]}')

logger.info(f'Deduped edges: {[(e.name, e.uuid) for e in deduped_edges]}')

entity_edges.extend(new_edges)
episodic_edges.extend(
build_episodic_edges(
# There may be an overlap between new_nodes and affected_nodes, so we're deduplicating them
nodes,
involved_nodes,
episode,
now,
)
)
# Important to append the episode to the nodes at the end so that self referencing episodic edges are not built
logger.info(f'Built episodic edges: {episodic_edges}')

# invalidated_edges = await self.invalidate_edges(
# episode, new_nodes, new_edges, relevant_schema, previous_episodes
# )

# edges.extend(invalidated_edges)

# Future optimization would be using batch operations to save nodes and edges
await episode.save(self.driver)
await asyncio.gather(*[node.save(self.driver) for node in nodes])
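The heart of the change above: invalidated edges take precedence over deduped duplicates when saving, and episodic edges are now built only for nodes actually touched during the episode. A minimal sketch of that rule, using a simplified Edge stand-in rather than the repository's EntityEdge:

from dataclasses import dataclass


@dataclass
class Edge:
    uuid: str
    source_node_uuid: str
    target_node_uuid: str


def merge_edges_to_save(invalidated: list[Edge], deduped: list[Edge]) -> list[Edge]:
    # Invalidated edges always win; a deduped edge is saved only when no
    # invalidated edge already carries its uuid.
    invalidated_uuids = {edge.uuid for edge in invalidated}
    return invalidated + [e for e in deduped if e.uuid not in invalidated_uuids]


def touched_node_uuids(brand_new_uuids: list[str], edges: list[Edge]) -> set[str]:
    # A node receives an episodic edge if it is brand new or if a surviving
    # entity edge references it as source or target.
    uuids = set(brand_new_uuids)
    for edge in edges:
        uuids.update((edge.source_node_uuid, edge.target_node_uuid))
    return uuids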
47 changes: 45 additions & 2 deletions core/prompts/dedupe_edges.py
@@ -6,11 +6,12 @@

class Prompt(Protocol):
v1: PromptVersion
edge_list: PromptVersion
v2: PromptVersion


class Versions(TypedDict):
v1: PromptFunction
v2: PromptFunction
edge_list: PromptFunction


@@ -54,6 +55,48 @@ def v1(context: dict[str, any]) -> list[Message]:
]


def v2(context: dict[str, any]) -> list[Message]:
return [
Message(
role='system',
content='You are a helpful assistant that de-duplicates relationships from edge lists.',
),
Message(
role='user',
content=f"""
Given the following context, deduplicate the list of New Edges against the list of Existing Edges:
Existing Edges:
{json.dumps(context['existing_edges'], indent=2)}
New Edges:
{json.dumps(context['extracted_edges'], indent=2)}
Task:
1. Start with the list of edges from New Edges
2. If any edge in New Edges is a duplicate of an edge in Existing Edges, replace the new edge with the existing
edge in the list
3. Respond with the resulting list of edges
Guidelines:
1. Use both the triplet name and the fact of each edge to determine whether they are duplicates;
duplicate edges may have different names with the same meaning, as well as slight variations in their facts.
2. If you encounter facts that are semantically equivalent or very similar, keep the original edge
Respond with a JSON object in the following format:
{{
"new_edges": [
{{
"triplet": "source_node_name-edge_name-target_node_name",
"fact": "one sentence description of the fact"
}}
]
}}
""",
),
]


def edge_list(context: dict[str, any]) -> list[Message]:
return [
Message(
Expand Down Expand Up @@ -90,4 +133,4 @@ def edge_list(context: dict[str, any]) -> list[Message]:
]


versions: Versions = {'v1': v1, 'edge_list': edge_list}
versions: Versions = {'v1': v1, 'v2': v2, 'edge_list': edge_list}
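
For orientation, a hypothetical round trip through the new v2 prompt might look like the sketch below. The context keys and the generate_response(messages) -> dict contract are taken from the callers in this diff; the edge names and facts are invented.

from core.prompts import prompt_library


async def dedupe_with_v2(llm_client) -> list[dict]:
    # Context shape mirrors what dedupe_extracted_edges_v2 builds in
    # core/utils/maintenance/edge_operations.py (later in this diff).
    context = {
        'existing_edges': [
            {'triplet': 'Alice-WORKS_AT-Acme', 'fact': 'Alice works at Acme.'}
        ],
        'extracted_edges': [
            {'triplet': 'Alice-EMPLOYED_BY-Acme', 'fact': 'Alice is employed by Acme.'}
        ],
    }
    llm_response = await llm_client.generate_response(prompt_library.dedupe_edges.v2(context))
    # Per the prompt, a duplicate new edge is replaced by its existing
    # counterpart, so this list should carry the WORKS_AT triplet.
    return llm_response.get('new_edges', [])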
18 changes: 12 additions & 6 deletions core/prompts/invalidate_edges.py
@@ -15,34 +15,40 @@ def v1(context: dict[str, any]) -> list[Message]:
return [
Message(
role='system',
content='You are an AI assistant that helps determine which relationships in a knowledge graph should be invalidated based on newer information.',
content='You are an AI assistant that helps determine which relationships in a knowledge graph should be invalidated based solely on explicit contradictions in newer information.',
),
Message(
role='user',
content=f"""
Based on the provided existing edges and new edges with their timestamps, determine which existing relationships, if any, should be invalidated due to contradictions or updates in the new edges.
Only mark a relationship as invalid if there is clear evidence from new edges that the relationship is no longer true.
Do not invalidate relationships merely because they weren't mentioned in new edges.
Do not invalidate relationships merely because they weren't mentioned in new edges. You may use the current episode and previous episodes as well as the facts of each edge to understand the context of the relationships.
Previous Episodes:
{context['previous_episodes']}
Current Episode:
{context['current_episode']}
Existing Edges (sorted by timestamp, newest first):
{context['existing_edges']}
New Edges:
{context['new_edges']}
Each edge is formatted as: "UUID | SOURCE_NODE - EDGE_NAME - TARGET_NODE (TIMESTAMP)"
Each edge is formatted as: "UUID | SOURCE_NODE - EDGE_NAME - TARGET_NODE (fact: EDGE_FACT), TIMESTAMP)"
For each existing edge that should be invalidated, respond with a JSON object in the following format:
{{
"invalidated_edges": [
{{
"edge_uuid": "The UUID of the edge to be invalidated (the part before the | character)",
"reason": "Brief explanation of why this edge is being invalidated"
"fact": "Updated fact of the edge"
}}
]
}}
If no relationships need to be invalidated, return an empty list for "invalidated_edges".
If no relationships need to be invalidated based on these strict criteria, return an empty list for "invalidated_edges".
""",
),
]
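A sketch of how a caller might consume this response shape, inferred from the prompt text alone rather than from the repository's actual caller; the edge objects are assumed to expose a mutable fact attribute:

def apply_invalidations(llm_response: dict, edges_by_uuid: dict) -> list:
    # Response shape is taken from the v1 prompt above.
    invalidated = []
    for item in llm_response.get('invalidated_edges', []):
        edge = edges_by_uuid.get(item['edge_uuid'])
        if edge is None:
            continue
        # The prompt now returns an updated fact instead of a free-form reason.
        edge.fact = item.get('fact', edge.fact)
        invalidated.append(edge)
    return invalidated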
2 changes: 1 addition & 1 deletion core/utils/bulk_utils.py
@@ -98,7 +98,7 @@ async def dedupe_nodes_bulk(

existing_nodes = await get_relevant_nodes(compressed_nodes, driver)

nodes, partial_uuid_map = await dedupe_extracted_nodes(
nodes, partial_uuid_map, _ = await dedupe_extracted_nodes(
llm_client, compressed_nodes, existing_nodes
)

46 changes: 46 additions & 0 deletions core/utils/maintenance/edge_operations.py
@@ -8,6 +8,7 @@
from core.llm_client import LLMClient
from core.nodes import EntityNode, EpisodicNode
from core.prompts import prompt_library
from core.utils.maintenance.temporal_operations import NodeEdgeNodeTriplet

logger = logging.getLogger(__name__)

@@ -179,6 +180,51 @@ async def extract_edges(
return edges


def create_edge_identifier(
source_node: EntityNode, edge: EntityEdge, target_node: EntityNode
) -> str:
return f'{source_node.name}-{edge.name}-{target_node.name}'


async def dedupe_extracted_edges_v2(
llm_client: LLMClient,
extracted_edges: list[NodeEdgeNodeTriplet],
existing_edges: list[NodeEdgeNodeTriplet],
) -> list[NodeEdgeNodeTriplet]:
# Create edge map
edge_map = {}
for n1, edge, n2 in existing_edges:
edge_map[create_edge_identifier(n1, edge, n2)] = edge
for n1, edge, n2 in extracted_edges:
if create_edge_identifier(n1, edge, n2) in edge_map:
continue
edge_map[create_edge_identifier(n1, edge, n2)] = edge

# Prepare context for LLM
context = {
'extracted_edges': [
{'triplet': create_edge_identifier(n1, edge, n2), 'fact': edge.fact}
for n1, edge, n2 in extracted_edges
],
'existing_edges': [
{'triplet': create_edge_identifier(n1, edge, n2), 'fact': edge.fact}
for n1, edge, n2 in existing_edges
],
}
logger.info(prompt_library.dedupe_edges.v2(context))
llm_response = await llm_client.generate_response(prompt_library.dedupe_edges.v2(context))
new_edges_data = llm_response.get('new_edges', [])
logger.info(f'Extracted new edges: {new_edges_data}')

# Get full edge data
edges = []
for edge_data in new_edges_data:
edge = edge_map[edge_data['triplet']]
edges.append(edge)

return edges


async def dedupe_extracted_edges(
llm_client: LLMClient,
extracted_edges: list[EntityEdge],
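A hypothetical call site for the new helper, with stand-in nodes and edges (none of these names are from the repository). Note that dedupe_extracted_edges_v2 resolves the LLM's answer back to EntityEdge objects through edge_map, keyed by the "source-edge-target" identifier string, so distinct triplets that render to the same identifier collapse to a single entry.

async def dedupe_episode_edges(llm_client, alice, acme, works_at, employed_by):
    # `alice`, `acme`, and the two edges are stand-ins; NodeEdgeNodeTriplet
    # is the (source, edge, target) tuple type imported at the top of the file.
    extracted = [(alice, employed_by, acme)]  # triplets from this episode
    existing = [(alice, works_at, acme)]      # triplets already in the graph
    # Duplicates are replaced with their existing counterparts, so the result
    # should carry the works_at edge rather than employed_by.
    return await dedupe_extracted_edges_v2(llm_client, extracted, existing)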
18 changes: 14 additions & 4 deletions core/utils/maintenance/node_operations.py
@@ -108,6 +108,11 @@ async def dedupe_extracted_nodes(
for node in existing_nodes:
node_map[node.name] = node

# Temp hack
new_nodes_map = {}
for node in extracted_nodes:
new_nodes_map[node.name] = node

# Prepare context for LLM
existing_nodes_context = [
{'name': node.name, 'summary': node.summary} for node in existing_nodes
@@ -131,20 +136,25 @@

uuid_map = {}
for duplicate in duplicate_data:
uuid = node_map[duplicate['name']].uuid
uuid = new_nodes_map[duplicate['name']].uuid
uuid_value = node_map[duplicate['duplicate_of']].uuid
uuid_map[uuid] = uuid_value

nodes = []
brand_new_nodes = []
for node in extracted_nodes:
if node.uuid in uuid_map:
existing_name = uuid_map[node.name]
existing_node = node_map[existing_name]
existing_uuid = uuid_map[node.uuid]
# TODO(Preston): This is a bit of a hack I implemented because we were getting incorrect uuids for existing nodes.
# Can you revisit the node dedup function, make it somewhat cleaner, and add more comments/tests please?
# find an existing node by the uuid from the nodes_map (each key is name, so we need to iterate by uuid value)
existing_node = next((v for k, v in node_map.items() if v.uuid == existing_uuid), None)
nodes.append(existing_node)
continue
brand_new_nodes.append(node)
nodes.append(node)

return nodes, uuid_map
return nodes, uuid_map, brand_new_nodes


async def dedupe_node_list(
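The dedupe now returns a third value, brand_new_nodes, alongside the touched nodes and the uuid_map. A toy illustration of the uuid_map contract, with invented uuids: an extracted node that duplicates an existing one is mapped onto the existing node's uuid, so edge pointers written against the extracted node can be rewritten (presumably what helpers like resolve_edge_pointers, imported in core/graphiti.py, consume such maps for).

# Toy illustration of the uuid_map contract; all uuids are invented.
uuid_map = {'extracted-1111': 'existing-9999'}


def resolve_pointer(node_uuid: str) -> str:
    # Fall through to the original uuid when the node was brand new.
    return uuid_map.get(node_uuid, node_uuid)


assert resolve_pointer('extracted-1111') == 'existing-9999'
assert resolve_pointer('extracted-2222') == 'extracted-2222'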
