Skip to content

Commit

Permalink
Add episode refactor (#85)
Browse files Browse the repository at this point in the history
* temp commit while moving

* fix name embedding bug

* invalidation

* format

* tests on runner examples

* format

* ellipsis

* ruff

* fix

* format

* minor prompt change
  • Loading branch information
prasmussen15 authored Sep 5, 2024
1 parent 1d31442 commit 2990211
Show file tree
Hide file tree
Showing 8 changed files with 261 additions and 106 deletions.
11 changes: 9 additions & 2 deletions examples/ecommerce/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ async def main():

async def ingest_products_data(client: Graphiti):
script_dir = Path(__file__).parent
json_file_path = script_dir / 'allbirds_products.json'
json_file_path = script_dir / '../data/manybirds_products.json'

with open(json_file_path) as file:
products = json.load(file)['products']
Expand All @@ -110,7 +110,14 @@ async def ingest_products_data(client: Graphiti):
for i, product in enumerate(products)
]

await client.add_episode_bulk(episodes)
for episode in episodes:
await client.add_episode(
episode.name,
episode.content,
episode.source_description,
episode.reference_time,
episode.source,
)


asyncio.run(main())
2 changes: 1 addition & 1 deletion examples/podcast/podcast_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,4 +86,4 @@ async def main(use_bulk: bool = True):
await client.add_episode_bulk(episodes)


asyncio.run(main(True))
asyncio.run(main(False))
120 changes: 56 additions & 64 deletions graphiti_core/graphiti.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,11 +59,6 @@
extract_nodes,
resolve_extracted_nodes,
)
from graphiti_core.utils.maintenance.temporal_operations import (
extract_edge_dates,
invalidate_edges,
prepare_edges_for_invalidation,
)

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -293,7 +288,7 @@ async def add_episode_endpoint(episode_data: EpisodeData):
*[node.generate_name_embedding(embedder) for node in extracted_nodes]
)

# Resolve extracted nodes with nodes already in the graph
# Resolve extracted nodes with nodes already in the graph and extract facts
existing_nodes_lists: list[list[EntityNode]] = list(
await asyncio.gather(
*[get_relevant_nodes([node], self.driver) for node in extracted_nodes]
Expand All @@ -302,22 +297,27 @@ async def add_episode_endpoint(episode_data: EpisodeData):

logger.info(f'Extracted nodes: {[(n.name, n.uuid) for n in extracted_nodes]}')

mentioned_nodes, _ = await resolve_extracted_nodes(
self.llm_client, extracted_nodes, existing_nodes_lists
(mentioned_nodes, uuid_map), extracted_edges = await asyncio.gather(
resolve_extracted_nodes(self.llm_client, extracted_nodes, existing_nodes_lists),
extract_edges(self.llm_client, episode, extracted_nodes, previous_episodes),
)
logger.info(f'Adjusted mentioned nodes: {[(n.name, n.uuid) for n in mentioned_nodes]}')
nodes.extend(mentioned_nodes)

# Extract facts as edges given entity nodes
extracted_edges = await extract_edges(
self.llm_client, episode, mentioned_nodes, previous_episodes
extracted_edges_with_resolved_pointers = resolve_edge_pointers(
extracted_edges, uuid_map
)

# calculate embeddings
await asyncio.gather(*[edge.generate_embedding(embedder) for edge in extracted_edges])
await asyncio.gather(
*[
edge.generate_embedding(embedder)
for edge in extracted_edges_with_resolved_pointers
]
)

# Resolve extracted edges with edges already in the graph
existing_edges_list: list[list[EntityEdge]] = list(
# Resolve extracted edges with related edges already in the graph
related_edges_list: list[list[EntityEdge]] = list(
await asyncio.gather(
*[
get_relevant_edges(
Expand All @@ -327,74 +327,66 @@ async def add_episode_endpoint(episode_data: EpisodeData):
edge.target_node_uuid,
RELEVANT_SCHEMA_LIMIT,
)
for edge in extracted_edges
for edge in extracted_edges_with_resolved_pointers
]
)
)
logger.info(
f'Existing edges lists: {[(e.name, e.uuid) for edges_lst in existing_edges_list for e in edges_lst]}'
f'Related edges lists: {[(e.name, e.uuid) for edges_lst in related_edges_list for e in edges_lst]}'
)
logger.info(f'Extracted edges: {[(e.name, e.uuid) for e in extracted_edges]}')

deduped_edges: list[EntityEdge] = await resolve_extracted_edges(
self.llm_client, extracted_edges, existing_edges_list
logger.info(
f'Extracted edges: {[(e.name, e.uuid) for e in extracted_edges_with_resolved_pointers]}'
)

# Extract dates for the newly extracted edges
edge_dates = await asyncio.gather(
*[
extract_edge_dates(
self.llm_client,
edge,
episode,
previous_episodes,
)
for edge in deduped_edges
]
existing_source_edges_list: list[list[EntityEdge]] = list(
await asyncio.gather(
*[
get_relevant_edges(
self.driver,
[edge],
edge.source_node_uuid,
None,
RELEVANT_SCHEMA_LIMIT,
)
for edge in extracted_edges_with_resolved_pointers
]
)
)

for i, edge in enumerate(deduped_edges):
valid_at = edge_dates[i][0]
invalid_at = edge_dates[i][1]

edge.valid_at = valid_at
edge.invalid_at = invalid_at
if edge.invalid_at is not None:
edge.expired_at = now

entity_edges.extend(deduped_edges)
existing_target_edges_list: list[list[EntityEdge]] = list(
await asyncio.gather(
*[
get_relevant_edges(
self.driver,
[edge],
None,
edge.target_node_uuid,
RELEVANT_SCHEMA_LIMIT,
)
for edge in extracted_edges_with_resolved_pointers
]
)
)

existing_edges: list[EntityEdge] = [
e for edge_lst in existing_edges_list for e in edge_lst
existing_edges_list: list[list[EntityEdge]] = [
source_lst + target_lst
for source_lst, target_lst in zip(
existing_source_edges_list, existing_target_edges_list
)
]

(
old_edges_with_nodes_pending_invalidation,
new_edges_with_nodes,
) = prepare_edges_for_invalidation(
existing_edges=existing_edges, new_edges=deduped_edges, nodes=nodes
)

invalidated_edges = await invalidate_edges(
resolved_edges, invalidated_edges = await resolve_extracted_edges(
self.llm_client,
old_edges_with_nodes_pending_invalidation,
new_edges_with_nodes,
extracted_edges_with_resolved_pointers,
related_edges_list,
existing_edges_list,
episode,
previous_episodes,
)

for edge in invalidated_edges:
for existing_edge in existing_edges:
if existing_edge.uuid == edge.uuid:
existing_edge.expired_at = edge.expired_at
for deduped_edge in deduped_edges:
if deduped_edge.uuid == edge.uuid:
deduped_edge.expired_at = edge.expired_at
logger.info(f'Invalidated edges: {[(e.name, e.uuid) for e in invalidated_edges]}')

entity_edges.extend(existing_edges)
entity_edges.extend(resolved_edges + invalidated_edges)

logger.info(f'Deduped edges: {[(e.name, e.uuid) for e in deduped_edges]}')
logger.info(f'Resolved edges: {[(e.name, e.uuid) for e in resolved_edges]}')

episodic_edges: list[EpisodicEdge] = build_episodic_edges(
mentioned_nodes,
Expand Down
2 changes: 1 addition & 1 deletion graphiti_core/prompts/dedupe_edges.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ def v3(context: dict[str, Any]) -> list[Message]:
Given the following context, determine whether the New Edge represents any of the edges in the list of Existing Edges.
Existing Edges:
{json.dumps(context['existing_edges'], indent=2)}
{json.dumps(context['related_edges'], indent=2)}
New Edge:
{json.dumps(context['extracted_edges'], indent=2)}
Expand Down
38 changes: 37 additions & 1 deletion graphiti_core/prompts/invalidate_edges.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,12 @@

class Prompt(Protocol):
    """Structural interface of this prompt module: one builder function per prompt version."""

    v1: PromptVersion
    v2: PromptVersion


class Versions(TypedDict):
    """Typed registry shape mapping version labels to prompt-building functions."""

    v1: PromptFunction
    v2: PromptFunction


def v1(context: dict[str, Any]) -> list[Message]:
Expand Down Expand Up @@ -71,4 +73,38 @@ def v1(context: dict[str, Any]) -> list[Message]:
]


versions: Versions = {'v1': v1}
def v2(context: dict[str, Any]) -> list[Message]:
    """Build the v2 edge-invalidation prompt.

    Asks the LLM to mark existing edges as invalidated only when they are
    explicitly contradicted by the New Edge, and to answer with a JSON object
    containing an ``invalidated_edges`` list of ``{uuid, fact}`` entries.

    Args:
        context: Must contain 'existing_edges' and 'new_edge'; both are
            interpolated directly into the prompt text (presumably already
            serialized to strings by the caller — TODO confirm).

    Returns:
        A two-message conversation (system + user) for the LLM client.
    """
    return [
        Message(
            role='system',
            content='You are an AI assistant that helps determine which relationships in a knowledge graph should be invalidated based solely on explicit contradictions in newer information.',
        ),
        Message(
            role='user',
            content=f"""
Based on the provided Existing Edges and a New Edge, determine which existing edges, if any, should be marked as invalidated due to contradictions with the New Edge.
Existing Edges:
{context['existing_edges']}
New Edge:
{context['new_edge']}
For each existing edge that should be invalidated, respond with a JSON object in the following format:
{{
"invalidated_edges": [
{{
"uuid": "The UUID of the edge to be invalidated",
"fact": "Updated fact of the edge"
}}
]
}}
If no relationships need to be invalidated based on these strict criteria, return an empty list for "invalidated_edges".
""",
        ),
    ]


versions: Versions = {'v1': v1, 'v2': v2}
52 changes: 26 additions & 26 deletions graphiti_core/search/search_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,11 +96,11 @@ async def bfs(node_ids: list[str], driver: AsyncDriver):


async def edge_similarity_search(
driver: AsyncDriver,
search_vector: list[float],
source_node_uuid: str | None,
target_node_uuid: str | None,
limit: int = RELEVANT_SCHEMA_LIMIT,
driver: AsyncDriver,
search_vector: list[float],
source_node_uuid: str | None,
target_node_uuid: str | None,
limit: int = RELEVANT_SCHEMA_LIMIT,
) -> list[EntityEdge]:
# vector similarity search over embedded facts
query = Query("""
Expand Down Expand Up @@ -211,7 +211,7 @@ async def edge_similarity_search(


async def entity_similarity_search(
search_vector: list[float], driver: AsyncDriver, limit=RELEVANT_SCHEMA_LIMIT
search_vector: list[float], driver: AsyncDriver, limit=RELEVANT_SCHEMA_LIMIT
) -> list[EntityNode]:
# vector similarity search over entity names
records, _, _ = await driver.execute_query(
Expand Down Expand Up @@ -247,7 +247,7 @@ async def entity_similarity_search(


async def entity_fulltext_search(
query: str, driver: AsyncDriver, limit=RELEVANT_SCHEMA_LIMIT
query: str, driver: AsyncDriver, limit=RELEVANT_SCHEMA_LIMIT
) -> list[EntityNode]:
# BM25 search to get top nodes
fuzzy_query = re.sub(r'[^\w\s]', '', query) + '~'
Expand Down Expand Up @@ -284,11 +284,11 @@ async def entity_fulltext_search(


async def edge_fulltext_search(
driver: AsyncDriver,
query: str,
source_node_uuid: str | None,
target_node_uuid: str | None,
limit=RELEVANT_SCHEMA_LIMIT,
driver: AsyncDriver,
query: str,
source_node_uuid: str | None,
target_node_uuid: str | None,
limit=RELEVANT_SCHEMA_LIMIT,
) -> list[EntityEdge]:
# fulltext search over facts
cypher_query = Query("""
Expand Down Expand Up @@ -401,10 +401,10 @@ async def edge_fulltext_search(


async def hybrid_node_search(
queries: list[str],
embeddings: list[list[float]],
driver: AsyncDriver,
limit: int = RELEVANT_SCHEMA_LIMIT,
queries: list[str],
embeddings: list[list[float]],
driver: AsyncDriver,
limit: int = RELEVANT_SCHEMA_LIMIT,
) -> list[EntityNode]:
"""
Perform a hybrid search for nodes using both text queries and embeddings.
Expand Down Expand Up @@ -466,8 +466,8 @@ async def hybrid_node_search(


async def get_relevant_nodes(
nodes: list[EntityNode],
driver: AsyncDriver,
nodes: list[EntityNode],
driver: AsyncDriver,
) -> list[EntityNode]:
"""
Retrieve relevant nodes based on the provided list of EntityNodes.
Expand Down Expand Up @@ -503,11 +503,11 @@ async def get_relevant_nodes(


async def get_relevant_edges(
driver: AsyncDriver,
edges: list[EntityEdge],
source_node_uuid: str | None,
target_node_uuid: str | None,
limit: int = RELEVANT_SCHEMA_LIMIT,
driver: AsyncDriver,
edges: list[EntityEdge],
source_node_uuid: str | None,
target_node_uuid: str | None,
limit: int = RELEVANT_SCHEMA_LIMIT,
) -> list[EntityEdge]:
start = time()
relevant_edges: list[EntityEdge] = []
Expand Down Expand Up @@ -557,7 +557,7 @@ def rrf(results: list[list[str]], rank_const=1) -> list[str]:


async def node_distance_reranker(
driver: AsyncDriver, results: list[list[str]], center_node_uuid: str
driver: AsyncDriver, results: list[list[str]], center_node_uuid: str
) -> list[str]:
# use rrf as a preliminary ranker
sorted_uuids = rrf(results)
Expand All @@ -579,8 +579,8 @@ async def node_distance_reranker(

for record in records:
if (
record['source_uuid'] == center_node_uuid
or record['target_uuid'] == center_node_uuid
record['source_uuid'] == center_node_uuid
or record['target_uuid'] == center_node_uuid
):
continue
distance = record['score']
Expand Down
Loading

0 comments on commit 2990211

Please sign in to comment.