From 8c2d86e3779fe2df1e29c10a02ab3051000f5b38 Mon Sep 17 00:00:00 2001
From: prestonrasmussen
Date: Wed, 11 Sep 2024 11:31:43 -0400
Subject: [PATCH] add extract nodes from text prompt

---
 graphiti_core/prompts/extract_nodes.py        | 44 ++++++++++++++++++-
 .../utils/maintenance/node_operations.py      | 27 +++++++++++-
 2 files changed, 69 insertions(+), 2 deletions(-)

diff --git a/graphiti_core/prompts/extract_nodes.py b/graphiti_core/prompts/extract_nodes.py
index abd47d6b..fd52159d 100644
--- a/graphiti_core/prompts/extract_nodes.py
+++ b/graphiti_core/prompts/extract_nodes.py
@@ -24,12 +24,14 @@ class Prompt(Protocol):
     v1: PromptVersion
     v2: PromptVersion
     extract_json: PromptVersion
+    extract_text: PromptVersion
 
 
 class Versions(TypedDict):
     v1: PromptFunction
     v2: PromptFunction
     extract_json: PromptFunction
+    extract_text: PromptFunction
 
 
 def v1(context: dict[str, Any]) -> list[Message]:
@@ -144,4 +146,44 @@ def extract_json(context: dict[str, Any]) -> list[Message]:
     ]
 
 
-versions: Versions = {'v1': v1, 'v2': v2, 'extract_json': extract_json}
+def extract_text(context: dict[str, Any]) -> list[Message]:
+    sys_prompt = """You are an AI assistant that extracts entity nodes from conversational text. Your primary task is to identify and extract the speaker and other significant entities mentioned in the conversation."""
+
+    user_prompt = f"""
+Given the following conversation, extract entity nodes from the CURRENT MESSAGE that are explicitly or implicitly mentioned:
+
+Conversation:
+{json.dumps([ep['content'] for ep in context['previous_episodes']], indent=2)}
+
+{context["episode_content"]}
+
+Guidelines:
+2. Extract significant entities, concepts, or actors mentioned in the conversation.
+3. Provide concise but informative summaries for each extracted node.
+4. Avoid creating nodes for relationships or actions.
+5. Avoid creating nodes for temporal information like dates, times or years (these will be added to edges later).
+6. Be as explicit as possible in your node names, using full names and avoiding abbreviations.
+
+Respond with a JSON object in the following format:
+{{
+    "extracted_nodes": [
+        {{
+            "name": "Unique identifier for the node (use the speaker's name for speaker nodes)",
+            "labels": ["Entity", "OptionalAdditionalLabel"],
+            "summary": "Brief summary of the node's role or significance"
+        }}
+    ]
+}}
+"""
+    return [
+        Message(role='system', content=sys_prompt),
+        Message(role='user', content=user_prompt),
+    ]
+
+
+versions: Versions = {
+    'v1': v1,
+    'v2': v2,
+    'extract_json': extract_json,
+    'extract_text': extract_text,
+}
diff --git a/graphiti_core/utils/maintenance/node_operations.py b/graphiti_core/utils/maintenance/node_operations.py
index 1aa6c757..2da68be1 100644
--- a/graphiti_core/utils/maintenance/node_operations.py
+++ b/graphiti_core/utils/maintenance/node_operations.py
@@ -48,6 +48,29 @@ async def extract_message_nodes(
     return extracted_node_data
 
 
+async def extract_text_nodes(
+    llm_client: LLMClient, episode: EpisodicNode, previous_episodes: list[EpisodicNode]
+) -> list[dict[str, Any]]:
+    # Prepare context for LLM
+    context = {
+        'episode_content': episode.content,
+        'episode_timestamp': episode.valid_at.isoformat(),
+        'previous_episodes': [
+            {
+                'content': ep.content,
+                'timestamp': ep.valid_at.isoformat(),
+            }
+            for ep in previous_episodes
+        ],
+    }
+
+    llm_response = await llm_client.generate_response(
+        prompt_library.extract_nodes.extract_text(context)
+    )
+    extracted_node_data = llm_response.get('extracted_nodes', [])
+    return extracted_node_data
+
+
 async def extract_json_nodes(
     llm_client: LLMClient,
     episode: EpisodicNode,
@@ -73,8 +96,10 @@ async def extract_nodes(
 ) -> list[EntityNode]:
     start = time()
     extracted_node_data: list[dict[str, Any]] = []
-    if episode.source in [EpisodeType.message, EpisodeType.text]:
+    if episode.source == EpisodeType.message:
         extracted_node_data = await extract_message_nodes(llm_client, episode, previous_episodes)
+    elif episode.source == EpisodeType.text:
+        extracted_node_data = await extract_text_nodes(llm_client, episode, previous_episodes)
     elif episode.source == EpisodeType.json:
         extracted_node_data = await extract_json_nodes(llm_client, episode)
 