Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Invalidation updates && improvements #20

Merged
merged 11 commits into from
Aug 22, 2024
57 changes: 43 additions & 14 deletions core/graphiti.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,17 @@
resolve_edge_pointers,
dedupe_edges_bulk,
)
from core.utils.maintenance.edge_operations import extract_edges, dedupe_extracted_edges
from core.utils.maintenance.edge_operations import (
extract_edges,
dedupe_extracted_edges_v2,
dedupe_extracted_edges,
)
from core.utils.maintenance.graph_data_operations import EPISODE_WINDOW_LEN
from core.utils.maintenance.node_operations import dedupe_extracted_nodes, extract_nodes
from core.utils.maintenance.temporal_operations import (
invalidate_edges,
prepare_edges_for_invalidation,
extract_node_and_edge_triplets,
)
from core.utils.search.search_utils import (
edge_similarity_search,
Expand Down Expand Up @@ -119,6 +124,9 @@ async def add_episode(
extracted_nodes = await extract_nodes(
self.llm_client, episode, previous_episodes
)
logger.info(
f"Extracted nodes: {[(n.name, n.uuid) for n in extracted_nodes]}"
)

# Calculate Embeddings

Expand All @@ -129,16 +137,16 @@ async def add_episode(
logger.info(
f"Extracted nodes: {[(n.name, n.uuid) for n in extracted_nodes]}"
)
new_nodes, _ = await dedupe_extracted_nodes(
touched_nodes, _, brand_new_nodes = await dedupe_extracted_nodes(
self.llm_client, extracted_nodes, existing_nodes
)
logger.info(
f"Deduped touched nodes: {[(n.name, n.uuid) for n in new_nodes]}"
f"Adjusted touched nodes: {[(n.name, n.uuid) for n in touched_nodes]}"
)
nodes.extend(new_nodes)
nodes.extend(touched_nodes)

extracted_edges = await extract_edges(
self.llm_client, episode, new_nodes, previous_episodes
self.llm_client, episode, touched_nodes, previous_episodes
)

await asyncio.gather(
Expand All @@ -151,10 +159,23 @@ async def add_episode(
f"Extracted edges: {[(e.name, e.uuid) for e in extracted_edges]}"
)

# deduped_edges = await dedupe_extracted_edges_v2(
# self.llm_client,
# extract_node_and_edge_triplets(extracted_edges, nodes),
# extract_node_and_edge_triplets(existing_edges, nodes),
# )

deduped_edges = await dedupe_extracted_edges(
self.llm_client, extracted_edges, existing_edges
self.llm_client,
extracted_edges,
existing_edges,
)

edge_touched_node_uuids = [n.uuid for n in brand_new_nodes]
for edge in deduped_edges:
edge_touched_node_uuids.append(edge.source_node_uuid)
edge_touched_node_uuids.append(edge.target_node_uuid)

(
old_edges_with_nodes_pending_invalidation,
new_edges_with_nodes,
Expand All @@ -166,28 +187,36 @@ async def add_episode(
self.llm_client,
old_edges_with_nodes_pending_invalidation,
new_edges_with_nodes,
episode,
previous_episodes,
)

for edge in invalidated_edges:
edge_touched_node_uuids.append(edge.source_node_uuid)
edge_touched_node_uuids.append(edge.target_node_uuid)

entity_edges.extend(invalidated_edges)

edge_touched_node_uuids = list(set(edge_touched_node_uuids))
involved_nodes = [
node for node in nodes if node.uuid in edge_touched_node_uuids
]

logger.info(
f"Edge touched nodes: {[(n.name, n.uuid) for n in involved_nodes]}"
)

logger.info(
f"Invalidated edges: {[(e.name, e.uuid) for e in invalidated_edges]}"
)

logger.info(f"Deduped edges: {[(e.name, e.uuid) for e in deduped_edges]}")
entity_edges.extend(deduped_edges)

new_edges = await dedupe_extracted_edges(
self.llm_client, extracted_edges, existing_edges
)

logger.info(f"Deduped edges: {[(e.name, e.uuid) for e in new_edges]}")

entity_edges.extend(new_edges)
episodic_edges.extend(
build_episodic_edges(
# There may be an overlap between new_nodes and affected_nodes, so we're deduplicating them
nodes,
involved_nodes,
episode,
now,
)
Expand Down
46 changes: 45 additions & 1 deletion core/prompts/dedupe_edges.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,13 @@

class Prompt(Protocol):
    """Structural interface for the dedupe-edges prompt bundle: one PromptVersion per prompt variant."""

    v1: PromptVersion
    v2: PromptVersion
    edge_list: PromptVersion


class Versions(TypedDict):
    """Maps each prompt-version name to its builder function; keys mirror the Prompt protocol above."""

    v1: PromptFunction
    v2: PromptFunction
    edge_list: PromptFunction


Expand Down Expand Up @@ -54,6 +56,48 @@ def v1(context: dict[str, any]) -> list[Message]:
]


def v2(context: dict[str, any]) -> list[Message]:
    """Build the v2 dedupe-edges prompt.

    Asks the LLM to start from the new edges and replace any duplicate of an
    existing edge with that existing edge, matching on both triplet name and
    fact text.

    Expects ``context['existing_edges']`` and ``context['extracted_edges']`` to
    be JSON-serializable lists of ``{"triplet": ..., "fact": ...}`` dicts.
    Returns the system + user message pair to send to the LLM.
    """
    return [
        Message(
            role="system",
            # Fixed grammar: "relationship" -> "relationships".
            content="You are a helpful assistant that de-duplicates relationships from edge lists.",
        ),
        Message(
            role="user",
            content=f"""
        Given the following context, deduplicate edges from a list of new edges given a list of existing edges:

        Existing Edges:
        {json.dumps(context['existing_edges'], indent=2)}

        New Edges:
        {json.dumps(context['extracted_edges'], indent=2)}

        Task:
        1. start with the list of edges from New Edges
        2. If any edge in New Edges is a duplicate of an edge in Existing Edges, replace the new edge with the existing
        edge in the list
        3. Respond with the resulting list of edges

        Guidelines:
        1. Use both the triplet name and fact of edges to determine if they are duplicates,
        duplicate edges may have different names meaning the same thing and slight variations in the facts.
        2. If you encounter facts that are semantically equivalent or very similar, keep the original edge

        Respond with a JSON object in the following format:
        {{
            "new_edges": [
                {{
                    "triplet": "source_node_name-edge_name-target_node_name",
                    "fact": "one sentence description of the fact"
                }}
            ]
        }}
        """,
        ),
    ]


def edge_list(context: dict[str, any]) -> list[Message]:
return [
Message(
Expand Down Expand Up @@ -90,4 +134,4 @@ def edge_list(context: dict[str, any]) -> list[Message]:
]


versions: Versions = {"v1": v1, "edge_list": edge_list}
# Registry of all dedupe-edges prompt builders, keyed by version name.
versions: Versions = {"v1": v1, "v2": v2, "edge_list": edge_list}
18 changes: 12 additions & 6 deletions core/prompts/invalidate_edges.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,34 +14,40 @@ def v1(context: dict[str, any]) -> list[Message]:
return [
Message(
role="system",
content="You are an AI assistant that helps determine which relationships in a knowledge graph should be invalidated based on newer information.",
content="You are an AI assistant that helps determine which relationships in a knowledge graph should be invalidated based solely on explicit contradictions in newer information.",
),
Message(
role="user",
content=f"""
Based on the provided existing edges and new edges with their timestamps, determine which existing relationships, if any, should be invalidated due to contradictions or updates in the new edges.
Based on the provided existing edges and new edges with their timestamps, determine which existing relationships, if any, should be invalidated due to contradictions or updates in the new edges.
Only mark a relationship as invalid if there is clear evidence from new edges that the relationship is no longer true.
Do not invalidate relationships merely because they weren't mentioned in new edges.
Do not invalidate relationships merely because they weren't mentioned in new edges. You may use the current episode and previous episodes as well as the facts of each edge to understand the context of the relationships.

Previous Episodes:
{context['previous_episodes']}

Current Episode:
{context['current_episode']}

Existing Edges (sorted by timestamp, newest first):
{context['existing_edges']}

New Edges:
{context['new_edges']}

Each edge is formatted as: "UUID | SOURCE_NODE - EDGE_NAME - TARGET_NODE (TIMESTAMP)"
Each edge is formatted as: "UUID | SOURCE_NODE - EDGE_NAME - TARGET_NODE (fact: EDGE_FACT), TIMESTAMP)"

For each existing edge that should be invalidated, respond with a JSON object in the following format:
{{
"invalidated_edges": [
{{
"edge_uuid": "The UUID of the edge to be invalidated (the part before the | character)",
"reason": "Brief explanation of why this edge is being invalidated"
"fact": "Updated fact of the edge"
}}
]
}}

If no relationships need to be invalidated, return an empty list for "invalidated_edges".
If no relationships need to be invalidated based on these strict criteria, return an empty list for "invalidated_edges".
""",
),
]
Expand Down
2 changes: 1 addition & 1 deletion core/utils/bulk_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ async def dedupe_nodes_bulk(

existing_nodes = await get_relevant_nodes(compressed_nodes, driver)

nodes, partial_uuid_map = await dedupe_extracted_nodes(
nodes, partial_uuid_map, _ = await dedupe_extracted_nodes(
llm_client, compressed_nodes, existing_nodes
)

Expand Down
51 changes: 49 additions & 2 deletions core/utils/maintenance/edge_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from core.nodes import EntityNode, EpisodicNode
from core.edges import EpisodicEdge, EntityEdge
import logging

from core.utils.maintenance.temporal_operations import NodeEdgeNodeTriplet
from core.prompts import prompt_library
from core.llm_client import LLMClient

Expand Down Expand Up @@ -196,6 +196,53 @@ async def extract_edges(
return edges


def create_edge_identifier(
    source_node: EntityNode, edge: EntityEdge, target_node: EntityNode
) -> str:
    """Return the human-readable triplet key "source-edge-target" used to identify an edge for de-duplication."""
    parts = (source_node.name, edge.name, target_node.name)
    return "-".join(parts)


async def dedupe_extracted_edges_v2(
    llm_client: LLMClient,
    extracted_edges: list[NodeEdgeNodeTriplet],
    existing_edges: list[NodeEdgeNodeTriplet],
) -> list[NodeEdgeNodeTriplet]:
    """De-duplicate newly extracted edge triplets against existing ones via the LLM.

    Serializes both edge lists as (triplet-identifier, fact) pairs, asks the
    dedupe_edges.v2 prompt to return the de-duplicated set, and maps the
    returned triplet identifiers back to full edge objects.

    NOTE(review): despite the annotated return type, this returns bare edge
    objects (like dedupe_extracted_edges), not triplets — confirm intended.
    """
    # Map triplet identifier -> edge so LLM responses (which reference
    # triplets) can be resolved back to edge objects. Existing edges are
    # inserted first so an extracted duplicate never shadows the original.
    edge_map = {}
    for source, edge, target in existing_edges:
        edge_map[create_edge_identifier(source, edge, target)] = edge
    for source, edge, target in extracted_edges:
        identifier = create_edge_identifier(source, edge, target)
        if identifier not in edge_map:
            edge_map[identifier] = edge

    # Prepare context for the LLM.
    # Bug fix: "existing_edges" was previously built by iterating
    # extracted_edges, so the LLM never saw the actual existing edges.
    context = {
        "extracted_edges": [
            {"triplet": create_edge_identifier(source, edge, target), "fact": edge.fact}
            for source, edge, target in extracted_edges
        ],
        "existing_edges": [
            {"triplet": create_edge_identifier(source, edge, target), "fact": edge.fact}
            for source, edge, target in existing_edges
        ],
    }
    logger.info(prompt_library.dedupe_edges.v2(context))
    llm_response = await llm_client.generate_response(
        prompt_library.dedupe_edges.v2(context)
    )
    new_edges_data = llm_response.get("new_edges", [])
    logger.info(f"Extracted new edges: {new_edges_data}")

    # Resolve triplet identifiers back to full edge objects. Skip (instead of
    # raising KeyError on) any triplet the LLM invented that we can't map back.
    edges = []
    for edge_data in new_edges_data:
        edge = edge_map.get(edge_data["triplet"])
        if edge is None:
            logger.warning(f"Unknown triplet in LLM response: {edge_data['triplet']}")
            continue
        edges.append(edge)

    return edges


async def dedupe_extracted_edges(
llm_client: LLMClient,
extracted_edges: list[EntityEdge],
Expand All @@ -206,7 +253,7 @@ async def dedupe_extracted_edges(
for edge in existing_edges:
edge_map[edge.fact] = edge
for edge in extracted_edges:
if edge.fact in edge_map.keys():
if edge.fact in edge_map:
continue
edge_map[edge.fact] = edge

Expand Down
20 changes: 16 additions & 4 deletions core/utils/maintenance/node_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,11 @@ async def dedupe_extracted_nodes(
for node in existing_nodes:
node_map[node.name] = node

# Temp hack
new_nodes_map = {}
for node in extracted_nodes:
new_nodes_map[node.name] = node

# Prepare context for LLM
existing_nodes_context = [
{"name": node.name, "summary": node.summary} for node in existing_nodes
Expand All @@ -145,20 +150,27 @@ async def dedupe_extracted_nodes(

uuid_map = {}
for duplicate in duplicate_data:
uuid = node_map[duplicate["name"]].uuid
uuid = new_nodes_map[duplicate["name"]].uuid
uuid_value = node_map[duplicate["duplicate_of"]].uuid
uuid_map[uuid] = uuid_value

nodes = []
brand_new_nodes = []
for node in extracted_nodes:
if node.uuid in uuid_map:
existing_name = uuid_map[node.name]
existing_node = node_map[existing_name]
existing_uuid = uuid_map[node.uuid]
# TODO(Preston): This is a bit of a hack I implemented because we were getting incorrect uuids for existing nodes,
# can you revisit the node dedup function and make it somewhat cleaner and add more comments/tests please?
# find an existing node by the uuid from the nodes_map (each key is name, so we need to iterate by uuid value)
existing_node = next(
(v for k, v in node_map.items() if v.uuid == existing_uuid), None
)
nodes.append(existing_node)
continue
brand_new_nodes.append(node)
nodes.append(node)

return nodes, uuid_map
return nodes, uuid_map, brand_new_nodes


async def dedupe_node_list(
Expand Down
Loading