Skip to content

Commit

Permalink
tests on runner examples
Browse files Browse the repository at this point in the history
  • Loading branch information
prasmussen15 committed Sep 5, 2024
1 parent cff363f commit 0a01c54
Show file tree
Hide file tree
Showing 6 changed files with 183 additions and 130 deletions.
5 changes: 3 additions & 2 deletions examples/ecommerce/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ async def main():

async def ingest_products_data(client: Graphiti):
script_dir = Path(__file__).parent
json_file_path = script_dir / 'allbirds_products.json'
json_file_path = script_dir / '../data/manybirds_products.json'

with open(json_file_path) as file:
products = json.load(file)['products']
Expand All @@ -110,7 +110,8 @@ async def ingest_products_data(client: Graphiti):
for i, product in enumerate(products)
]

await client.add_episode_bulk(episodes)
[await client.add_episode(episode.name, episode.content, episode.source_description, episode.reference_time,
episode.source) for episode in episodes]


asyncio.run(main())
2 changes: 1 addition & 1 deletion examples/podcast/podcast_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,4 +86,4 @@ async def main(use_bulk: bool = True):
await client.add_episode_bulk(episodes)


asyncio.run(main(True))
asyncio.run(main(False))
148 changes: 65 additions & 83 deletions graphiti_core/graphiti.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,9 +180,9 @@ async def build_indices_and_constraints(self):
await build_indices_and_constraints(self.driver)

async def retrieve_episodes(
self,
reference_time: datetime,
last_n: int = EPISODE_WINDOW_LEN,
self,
reference_time: datetime,
last_n: int = EPISODE_WINDOW_LEN,
) -> list[EpisodicNode]:
"""
Retrieve the last n episodic nodes from the graph.
Expand Down Expand Up @@ -210,14 +210,14 @@ async def retrieve_episodes(
return await retrieve_episodes(self.driver, reference_time, last_n)

async def add_episode(
self,
name: str,
episode_body: str,
source_description: str,
reference_time: datetime,
source: EpisodeType = EpisodeType.message,
success_callback: Callable | None = None,
error_callback: Callable | None = None,
self,
name: str,
episode_body: str,
source_description: str,
reference_time: datetime,
source: EpisodeType = EpisodeType.message,
success_callback: Callable | None = None,
error_callback: Callable | None = None,
):
"""
Process an episode and update the graph.
Expand Down Expand Up @@ -293,7 +293,7 @@ async def add_episode_endpoint(episode_data: EpisodeData):
*[node.generate_name_embedding(embedder) for node in extracted_nodes]
)

# Resolve extracted nodes with nodes already in the graph
# Resolve extracted nodes with nodes already in the graph and extract facts
existing_nodes_lists: list[list[EntityNode]] = list(
await asyncio.gather(
*[get_relevant_nodes([node], self.driver) for node in extracted_nodes]
Expand All @@ -302,22 +302,22 @@ async def add_episode_endpoint(episode_data: EpisodeData):

logger.info(f'Extracted nodes: {[(n.name, n.uuid) for n in extracted_nodes]}')

mentioned_nodes, _ = await resolve_extracted_nodes(
self.llm_client, extracted_nodes, existing_nodes_lists
(mentioned_nodes, uuid_map), extracted_edges = await asyncio.gather(resolve_extracted_nodes(
self.llm_client, extracted_nodes, existing_nodes_lists), extract_edges(
self.llm_client, episode, extracted_nodes, previous_episodes
)
)
logger.info(f'Adjusted mentioned nodes: {[(n.name, n.uuid) for n in mentioned_nodes]}')
nodes.extend(mentioned_nodes)

# Extract facts as edges given entity nodes
extracted_edges = await extract_edges(
self.llm_client, episode, mentioned_nodes, previous_episodes
)
extracted_edges_with_resolved_pointers = resolve_edge_pointers(extracted_edges, uuid_map)

# calculate embeddings
await asyncio.gather(*[edge.generate_embedding(embedder) for edge in extracted_edges])
await asyncio.gather(
*[edge.generate_embedding(embedder) for edge in extracted_edges_with_resolved_pointers])

# Resolve extracted edges with edges already in the graph
existing_edges_list: list[list[EntityEdge]] = list(
# Resolve extracted edges with related edges already in the graph
related_edges_list: list[list[EntityEdge]] = list(
await asyncio.gather(
*[
get_relevant_edges(
Expand All @@ -327,74 +327,56 @@ async def add_episode_endpoint(episode_data: EpisodeData):
edge.target_node_uuid,
RELEVANT_SCHEMA_LIMIT,
)
for edge in extracted_edges
for edge in extracted_edges_with_resolved_pointers
]
)
)
logger.info(
f'Existing edges lists: {[(e.name, e.uuid) for edges_lst in existing_edges_list for e in edges_lst]}'
f'Related edges lists: {[(e.name, e.uuid) for edges_lst in related_edges_list for e in edges_lst]}'
)
logger.info(f'Extracted edges: {[(e.name, e.uuid) for e in extracted_edges]}')
logger.info(f'Extracted edges: {[(e.name, e.uuid) for e in extracted_edges_with_resolved_pointers]}')

deduped_edges: list[EntityEdge] = await resolve_extracted_edges(
self.llm_client, extracted_edges, existing_edges_list
existing_source_edges_list: list[list[EntityEdge]] = list(
await asyncio.gather(
*[
get_relevant_edges(
self.driver,
[edge],
edge.source_node_uuid,
None,
RELEVANT_SCHEMA_LIMIT,
)
for edge in extracted_edges_with_resolved_pointers
]
)
)

# Extract dates for the newly extracted edges
edge_dates = await asyncio.gather(
*[
extract_edge_dates(
self.llm_client,
edge,
episode,
previous_episodes,
)
for edge in deduped_edges
]
existing_target_edges_list: list[list[EntityEdge]] = list(
await asyncio.gather(
*[
get_relevant_edges(
self.driver,
[edge],
None,
edge.target_node_uuid,
RELEVANT_SCHEMA_LIMIT,
)
for edge in extracted_edges_with_resolved_pointers
]
)
)

for i, edge in enumerate(deduped_edges):
valid_at = edge_dates[i][0]
invalid_at = edge_dates[i][1]

edge.valid_at = valid_at
edge.invalid_at = invalid_at
if edge.invalid_at is not None:
edge.expired_at = now
existing_edges_list: list[list[EntityEdge]] = [source_lst + target_lst for source_lst, target_lst in
zip(existing_source_edges_list, existing_target_edges_list)]

entity_edges.extend(deduped_edges)

existing_edges: list[EntityEdge] = [
e for edge_lst in existing_edges_list for e in edge_lst
]

(
old_edges_with_nodes_pending_invalidation,
new_edges_with_nodes,
) = prepare_edges_for_invalidation(
existing_edges=existing_edges, new_edges=deduped_edges, nodes=nodes
resolved_edges, invalidated_edges = await resolve_extracted_edges(
self.llm_client, extracted_edges_with_resolved_pointers, related_edges_list, existing_edges_list,
episode, previous_episodes
)

invalidated_edges = await invalidate_edges(
self.llm_client,
old_edges_with_nodes_pending_invalidation,
new_edges_with_nodes,
episode,
previous_episodes,
)

for edge in invalidated_edges:
for existing_edge in existing_edges:
if existing_edge.uuid == edge.uuid:
existing_edge.expired_at = edge.expired_at
for deduped_edge in deduped_edges:
if deduped_edge.uuid == edge.uuid:
deduped_edge.expired_at = edge.expired_at
logger.info(f'Invalidated edges: {[(e.name, e.uuid) for e in invalidated_edges]}')

entity_edges.extend(existing_edges)
entity_edges.extend(resolved_edges + invalidated_edges)

logger.info(f'Deduped edges: {[(e.name, e.uuid) for e in deduped_edges]}')
logger.info(f'Resolved edges: {[(e.name, e.uuid) for e in resolved_edges]}')

episodic_edges: list[EpisodicEdge] = build_episodic_edges(
mentioned_nodes,
Expand Down Expand Up @@ -422,8 +404,8 @@ async def add_episode_endpoint(episode_data: EpisodeData):
raise e

async def add_episode_bulk(
self,
bulk_episodes: list[RawEpisode],
self,
bulk_episodes: list[RawEpisode],
):
"""
Process multiple episodes in bulk and update the graph.
Expand Down Expand Up @@ -587,18 +569,18 @@ async def search(self, query: str, center_node_uuid: str | None = None, num_resu
return edges

async def _search(
self,
query: str,
timestamp: datetime,
config: SearchConfig,
center_node_uuid: str | None = None,
self,
query: str,
timestamp: datetime,
config: SearchConfig,
center_node_uuid: str | None = None,
):
return await hybrid_search(
self.driver, self.llm_client.get_embedder(), query, timestamp, config, center_node_uuid
)

async def get_nodes_by_query(
self, query: str, limit: int = RELEVANT_SCHEMA_LIMIT
self, query: str, limit: int = RELEVANT_SCHEMA_LIMIT
) -> list[EntityNode]:
"""
Retrieve nodes from the graph database based on a text query.
Expand Down
2 changes: 1 addition & 1 deletion graphiti_core/prompts/dedupe_edges.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ def v3(context: dict[str, Any]) -> list[Message]:
Given the following context, determine whether the New Edge represents any of the edges in the list of Existing Edges.
Existing Edges:
{json.dumps(context['existing_edges'], indent=2)}
{json.dumps(context['related_edges'], indent=2)}
New Edge:
{json.dumps(context['extracted_edges'], indent=2)}
Expand Down
87 changes: 63 additions & 24 deletions graphiti_core/utils/maintenance/edge_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from graphiti_core.llm_client import LLMClient
from graphiti_core.nodes import EntityNode, EpisodicNode
from graphiti_core.prompts import prompt_library
from graphiti_core.utils.maintenance.temporal_operations import extract_edge_dates
from graphiti_core.utils.maintenance.temporal_operations import extract_edge_dates, get_edge_contradictions

logger = logging.getLogger(__name__)

Check failure on line 29 in graphiti_core/utils/maintenance/edge_operations.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (I001)

graphiti_core/utils/maintenance/edge_operations.py:17:1: I001 Import block is un-sorted or un-formatted

Expand Down Expand Up @@ -154,17 +154,13 @@ async def resolve_extracted_edges(
existing_edges_lists: list[list[EntityEdge]],
current_episode: EpisodicNode,
previous_episodes: list[EpisodicNode],
) -> list[EntityEdge]:
) -> tuple[list[EntityEdge], list[EntityEdge]]:
# resolve edges with related edges in the graph, extract temporal information, and find invalidation candidates
results: list[tuple[EntityEdge, tuple[datetime | None, datetime | None]]] = list(
results: list[tuple[EntityEdge, list[EntityEdge]]] = list(
await asyncio.gather(
*[
(
resolve_extracted_edge(llm_client, extracted_edge, related_edges),
extract_edge_dates(
llm_client, extracted_edge, current_episode, previous_episodes
)
)
resolve_extracted_edge(llm_client, extracted_edge, related_edges, existing_edges, current_episode,
previous_episodes)
for extracted_edge, related_edges, existing_edges in zip(
extracted_edges, related_edges_lists, existing_edges_lists
)
Expand All @@ -173,26 +169,69 @@ async def resolve_extracted_edges(
)

resolved_edges: list[EntityEdge] = []
invalidated_edges: list[EntityEdge] = []
for result in results:
resolved_edge = result[0]
valid_at, invalid_at = result[1]
resolved_edge.valid_at = valid_at if valid_at is not None else resolved_edge.valid_at
resolved_edge.invalid_at = (
invalid_at if invalid_at is not None else resolved_edge.invalid_at
)
if invalid_at is not None and resolved_edge.expired_at is None:
resolved_edge.expired_at = datetime.now()

return resolved_edges


async def resolve_extracted_edge(
invalidated_edge_chunk = result[1]

resolved_edges.append(resolved_edge)
invalidated_edges.extend(invalidated_edge_chunk)

return resolved_edges, invalidated_edges


async def resolve_extracted_edge(
    llm_client: LLMClient,
    extracted_edge: EntityEdge,
    related_edges: list[EntityEdge],
    existing_edges: list[EntityEdge],
    current_episode: EpisodicNode,
    previous_episodes: list[EpisodicNode],
) -> tuple[EntityEdge, list[EntityEdge]]:
    """Resolve one newly extracted edge against the graph.

    Runs three LLM tasks concurrently:
      * dedupe the extracted edge against `related_edges`,
      * extract valid_at / invalid_at dates from the episodes,
      * find `existing_edges` the new edge contradicts (invalidation candidates).

    Returns:
        A tuple of (resolved edge, list of existing edges that the resolved
        edge invalidates — each with `invalid_at`/`expired_at` updated).
    """
    resolved_edge, (valid_at, invalid_at), invalidation_candidates = await asyncio.gather(
        dedupe_extracted_edge(llm_client, extracted_edge, related_edges),
        extract_edge_dates(llm_client, extracted_edge, current_episode, previous_episodes),
        get_edge_contradictions(llm_client, extracted_edge, existing_edges),
    )

    now = datetime.now()

    # Apply extracted temporal data, keeping any dates already on the edge.
    resolved_edge.valid_at = valid_at if valid_at is not None else resolved_edge.valid_at
    resolved_edge.invalid_at = invalid_at if invalid_at is not None else resolved_edge.invalid_at
    if invalid_at is not None and resolved_edge.expired_at is None:
        resolved_edge.expired_at = now

    # Determine if the new edge itself needs to be expired: if a contradicting
    # edge became valid after the new edge did, the graph already has more
    # recent information and the new edge is stale on arrival.
    if resolved_edge.expired_at is None:
        # Sort candidates so dated edges come first, earliest valid_at first.
        invalidation_candidates.sort(key=lambda c: (c.valid_at is None, c.valid_at))
        for candidate in invalidation_candidates:
            if (
                candidate.valid_at is not None
                and resolved_edge.valid_at is not None
                and candidate.valid_at > resolved_edge.valid_at
            ):
                # Expire new edge since we have information about more recent events
                resolved_edge.invalid_at = candidate.valid_at
                resolved_edge.expired_at = now
                break

    # Determine which contradictory edges need to be expired.
    # BUG FIX: the original iterated over the just-initialized, always-empty
    # `invalidated_edges` list, so no contradicted edge was ever expired or
    # returned; iterate the contradiction candidates instead.
    invalidated_edges: list[EntityEdge] = []
    for edge in invalidation_candidates:
        # Candidate became invalid before the new edge became valid — no overlap.
        if (
            edge.invalid_at is not None
            and resolved_edge.valid_at is not None
            and edge.invalid_at < resolved_edge.valid_at
        ):
            continue
        # New edge became invalid before the candidate became valid — no overlap.
        if (
            edge.valid_at is not None
            and resolved_edge.invalid_at is not None
            and resolved_edge.invalid_at < edge.valid_at
        ):
            continue
        # New edge supersedes the candidate: expire it at the new edge's valid time.
        if (
            edge.valid_at is not None
            and resolved_edge.valid_at is not None
            and edge.valid_at < resolved_edge.valid_at
        ):
            edge.invalid_at = resolved_edge.valid_at
            edge.expired_at = edge.expired_at if edge.expired_at is not None else now
            invalidated_edges.append(edge)

    return resolved_edge, invalidated_edges


async def dedupe_extracted_edge(
llm_client: LLMClient, extracted_edge: EntityEdge, related_edges: list[EntityEdge]
) -> EntityEdge:
start = time()

# Prepare context for LLM
existing_edges_context = [
related_edges_context = [
{'uuid': edge.uuid, 'name': edge.name, 'fact': edge.fact} for edge in related_edges
]

Expand All @@ -203,7 +242,7 @@ async def resolve_extracted_edge(
}

context = {
'existing_edges': existing_edges_context,
'related_edges': related_edges_context,
'extracted_edges': extracted_edge_context,
}

Expand All @@ -221,7 +260,7 @@ async def resolve_extracted_edge(

end = time()
logger.info(
f'Resolved node: {extracted_edge.name} is {edge.name}, in {(end - start) * 1000} ms'
f'Resolved Edge: {extracted_edge.name} is {edge.name}, in {(end - start) * 1000} ms'
)

return edge
Expand Down
Loading

0 comments on commit 0a01c54

Please sign in to comment.