Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add real world dates extraction #26

Merged
merged 9 commits into from
Aug 23, 2024
10 changes: 9 additions & 1 deletion core/graphiti.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@
)
from core.utils.maintenance.node_operations import dedupe_extracted_nodes, extract_nodes
from core.utils.maintenance.temporal_operations import (
extract_edge_dates,
extract_node_edge_node_triplet,
invalidate_edges,
prepare_edges_for_invalidation,
)
Expand Down Expand Up @@ -174,7 +176,13 @@ async def add_episode(
for deduped_edge in deduped_edges:
if deduped_edge.uuid not in [edge.uuid for edge in invalidated_edges]:
edges_to_save.append(deduped_edge)

for deduped_edge in deduped_edges:
triplet = extract_node_edge_node_triplet(deduped_edge, nodes)
valid_at, invalid_at, _ = await extract_edge_dates(
self.llm_client, triplet, episode.valid_at, episode, previous_episodes
)
deduped_edge.valid_at = valid_at
deduped_edge.invalid_at = invalid_at
entity_edges.extend(edges_to_save)

edge_touched_node_uuids = list(set(edge_touched_node_uuids))
Expand Down
62 changes: 62 additions & 0 deletions core/prompts/extract_edge_dates.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
from typing import Any, Protocol, TypedDict

from .models import Message, PromptFunction, PromptVersion


class Prompt(Protocol):
    """Structural type for the edge-date extraction prompt: one attribute per prompt version."""

    v1: PromptVersion


class Versions(TypedDict):
    """Shape of this module's version registry: one prompt function per version key."""

    v1: PromptFunction


def v1(context: dict[str, Any]) -> list[Message]:
    """Build the chat messages asking the LLM for an edge's valid_at / invalid_at dates.

    Args:
        context: Values interpolated into the user message. The keys read are
            ``source_node``, ``edge_name``, ``target_node``, ``edge_fact``,
            ``current_episode``, ``previous_episodes`` and ``reference_timestamp``.

    Returns:
        A two-element list: a system message describing the assistant's role,
        then a user message carrying the edge, the episodes, the reference
        timestamp and the JSON response format.

    Raises:
        KeyError: If ``context`` is missing any of the keys listed above.
    """
    return [
        Message(
            role='system',
            content='You are an AI assistant that extracts datetime information for graph edges, focusing only on dates directly related to the establishment or change of the relationship described in the edge fact.',
        ),
        Message(
            role='user',
            # Guidelines are numbered contiguously (1-8); a gap in the list can
            # read to the model as a missing instruction.
            content=f"""
Edge:
Source Node: {context['source_node']}
Edge Name: {context['edge_name']}
Target Node: {context['target_node']}
Fact: {context['edge_fact']}

Current Episode: {context['current_episode']}
Previous Episodes: {context['previous_episodes']}
Reference Timestamp: {context['reference_timestamp']}

IMPORTANT: Only extract dates that are part of the provided fact. Otherwise ignore the date.

Definitions:
- valid_at: The date and time when the relationship described by the edge fact became true or was established.
- invalid_at: The date and time when the relationship described by the edge fact stopped being true or ended.

Task:
Analyze the conversation and determine if there are dates that are part of the edge fact. Only set dates if they explicitly relate to the formation or alteration of the relationship itself.

Guidelines:
1. Use ISO 8601 format (YYYY-MM-DDTHH:MM:SSZ) for datetimes.
2. Use the reference timestamp as the current time when determining the valid_at and invalid_at dates.
3. If no relevant dates are found that explicitly establish or change the relationship, leave the fields as null.
4. Do not infer dates from related events. Only use dates that are directly stated to establish or change the relationship.
5. For relative time mentions directly related to the relationship, calculate the actual datetime based on the reference timestamp.
6. If only a date is mentioned without a specific time, use 00:00:00 (midnight) for that date.
7. If only a year is mentioned, use January 1st of that year at 00:00:00.
8. Always include the time zone offset (use Z for UTC if no specific time zone is mentioned).
Respond with a JSON object:
{{
"valid_at": "YYYY-MM-DDTHH:MM:SSZ or null",
"invalid_at": "YYYY-MM-DDTHH:MM:SSZ or null",
"explanation": "Brief explanation of why these dates were chosen or why they were set to null"
}}
""",
        ),
    ]


# Registry mapping each prompt version name to the function that builds it.
versions: Versions = {'v1': v1}
1 change: 1 addition & 0 deletions core/prompts/extract_nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ def v3(context: dict[str, Any]) -> list[Message]:
2. Extract other significant entities, concepts, or actors mentioned in the conversation.
3. Provide concise but informative summaries for each extracted node.
4. Avoid creating nodes for relationships or actions.
5. Avoid creating nodes for temporal information like dates, times or years (these will be added to edges later).

Respond with a JSON object in the following format:
{{
Expand Down
12 changes: 12 additions & 0 deletions core/prompts/lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,15 @@
from .dedupe_nodes import (
versions as dedupe_nodes_versions,
)
from .extract_edge_dates import (
Prompt as ExtractEdgeDatesPrompt,
)
from .extract_edge_dates import (
Versions as ExtractEdgeDatesVersions,
)
from .extract_edge_dates import (
versions as extract_edge_dates_versions,
)
from .extract_edges import (
Prompt as ExtractEdgesPrompt,
)
Expand Down Expand Up @@ -54,6 +63,7 @@ class PromptLibrary(Protocol):
extract_edges: ExtractEdgesPrompt
dedupe_edges: DedupeEdgesPrompt
invalidate_edges: InvalidateEdgesPrompt
extract_edge_dates: ExtractEdgeDatesPrompt


class PromptLibraryImpl(TypedDict):
Expand All @@ -62,6 +72,7 @@ class PromptLibraryImpl(TypedDict):
extract_edges: ExtractEdgesVersions
dedupe_edges: DedupeEdgesVersions
invalidate_edges: InvalidateEdgesVersions
extract_edge_dates: ExtractEdgeDatesVersions


class VersionWrapper:
Expand Down Expand Up @@ -90,5 +101,6 @@ def __init__(self, library: PromptLibraryImpl):
'extract_edges': extract_edges_versions,
'dedupe_edges': dedupe_edges_versions,
'invalidate_edges': invalidate_edges_versions,
'extract_edge_dates': extract_edge_dates_versions,
}
prompt_library: PromptLibrary = PromptLibraryWrapper(PROMPT_LIBRARY_IMPL) # type: ignore[assignment]
40 changes: 40 additions & 0 deletions core/utils/maintenance/temporal_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,3 +119,43 @@ def process_edge_invalidation_llm_response(
f"Invalidated edge: {edge_to_update.name} (UUID: {edge_to_update.uuid}). Updated Fact: {edge_to_invalidate['fact']}"
)
return invalidated_edges


async def extract_edge_dates(
    llm_client: LLMClient,
    edge_triplet: NodeEdgeNodeTriplet,
    reference_time: datetime,
    current_episode: EpisodicNode,
    previous_episodes: List[EpisodicNode],
) -> tuple[datetime | None, datetime | None, str]:
    """Ask the LLM when the relationship described by an edge started and ended.

    Args:
        llm_client: Client used to run the ``extract_edge_dates`` v1 prompt.
        edge_triplet: ``(source_node, edge, target_node)`` for the edge being dated.
        reference_time: Timestamp sent to the prompt as the reference timestamp.
        current_episode: Episode whose content is sent as the current episode.
        previous_episodes: Earlier episodes whose contents are sent as context.

    Returns:
        ``(valid_at, invalid_at, explanation)``. Each date is ``None`` when the
        response omits it, leaves it empty, returns a "null" placeholder, or
        returns a value that cannot be parsed as an ISO 8601 timestamp.
    """
    source_node, edge, target_node = edge_triplet

    context = {
        'source_node': source_node.name,
        'edge_name': edge.name,
        'target_node': target_node.name,
        'edge_fact': edge.fact,
        'current_episode': current_episode.content,
        'previous_episodes': [ep.content for ep in previous_episodes],
        'reference_timestamp': reference_time.isoformat(),
    }
    llm_response = await llm_client.generate_response(prompt_library.extract_edge_dates.v1(context))

    explanation = llm_response.get('explanation', '')
    valid_at_datetime = _parse_llm_datetime(llm_response.get('valid_at'), 'valid_at')
    invalid_at_datetime = _parse_llm_datetime(llm_response.get('invalid_at'), 'invalid_at')

    logger.info(f'Edge date extraction explanation: {explanation}')

    return valid_at_datetime, invalid_at_datetime, explanation


def _parse_llm_datetime(raw_value: object, field_name: str) -> datetime | None:
    """Convert one timestamp field of the LLM response to a datetime, or None if unusable."""
    if raw_value is None:
        return None
    if not isinstance(raw_value, str):
        logger.warning(f'Ignoring non-string {field_name} from edge date extraction: {raw_value!r}')
        return None

    text = raw_value.strip()
    # The prompt's JSON template shows "... or null" inside quotes, so the model
    # may answer with the literal string "null" instead of a JSON null.
    if not text or text.lower() in ('null', 'none'):
        return None

    # datetime.fromisoformat only accepts a trailing "Z" from Python 3.11 on;
    # rewrite it as an explicit UTC offset so older interpreters parse it too.
    if text[-1] in 'Zz':
        text = f'{text[:-1]}+00:00'

    try:
        return datetime.fromisoformat(text)
    except ValueError:
        # One malformed date should not abort ingestion of the whole episode.
        logger.warning(f'Ignoring unparseable {field_name} from edge date extraction: {raw_value!r}')
        return None
97 changes: 73 additions & 24 deletions runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,36 +37,85 @@ def setup_logging():
return logger


# Sample buyer/salesman dialogue, one episode per conversational turn, wrapped
# in the same {'episode_body': ...} shape as the other fixtures in this script.
bmw_sales = [
    {'episode_body': turn}
    for turn in (
        'Paul (buyer): Hi, I would like to buy a new car',
        'Dan The Salesman (salesman): Sure, I can help you with that. What kind of car are you looking for?',
        'Paul (buyer): I am looking for a new BMW',
        'Dan The Salesman (salesman): Great choice! What kind of BMW are you looking for?',
        'Paul (buyer): I am considering a BMW 3 series',
        'Dan The Salesman (salesman): Great choice, we currently have a 2024 BMW 3 series in stock, it is a great car and costs $50,000',
        "Paul (buyer): Actually I am interested in something cheaper, I won't consider anything over $30,000",
    )
]

# Sample episodes for exercising date extraction. Each list covers one way a
# message can mention time: explicit years, a relative time, an explicit range,
# and a relative range.

# Explicit calendar years.
dates_mentioned = [
    {
        'episode_body': 'Paul (user): I have graduated from University of Toronto in 2022',
    },
    {
        'episode_body': 'Jane (user): How cool, I graduated from the same school in 1999',
    },
]

# A time relative to the reference timestamp.
times_mentioned = [
    {
        'episode_body': 'Paul (user): 15 minutes ago we put a deposit on our new house',
    },
]

# An explicit start/end range.
time_range_mentioned = [
    {
        'episode_body': 'Paul (user): I served as a US Marine in 2015-2019',
    },
]

# A range given as a duration plus a relative end point.
relative_time_range_mentioned = [
    {
        'episode_body': 'Paul (user): I served as a US Marine for 20 years, until retiring last month',
    },
]


async def main():
setup_logging()
client = Graphiti(neo4j_uri, neo4j_user, neo4j_password)
await clear_data(client.driver)
await client.build_indices_and_constraints()

# await client.build_indices()
await client.add_episode(
name='Message 3',
episode_body='Jane: I am married to Paul',
source_description='WhatsApp Message',
reference_time=datetime.now(),
)
await client.add_episode(
name='Message 4',
episode_body='Paul: I have divorced Jane',
source_description='WhatsApp Message',
reference_time=datetime.now(),
)
await client.add_episode(
name='Message 5',
episode_body='Jane: I miss Paul',
source_description='WhatsApp Message',
reference_time=datetime.now(),
)
await client.add_episode(
name='Message 6',
episode_body='Jane: I dont miss Paul anymore, I hate him',
source_description='WhatsApp Message',
reference_time=datetime.now(),
)
for i, message in enumerate(relative_time_range_mentioned):
await client.add_episode(
name=f'Message {i}',
episode_body=message['episode_body'],
source_description='',
# reference_time=datetime.now() - timedelta(days=365 * 3),
reference_time=datetime.now(),
)
# await client.add_episode(
# name='Message 5',
# episode_body='Jane: I miss Paul',
# source_description='WhatsApp Message',
# reference_time=datetime.now(),
# )
# await client.add_episode(
# name='Message 6',
# episode_body='Jane: I dont miss Paul anymore, I hate him',
# source_description='WhatsApp Message',
# reference_time=datetime.now(),
# )

# await client.add_episode(
# name="Message 3",
Expand Down