Skip to content

Commit

Permalink
add extract nodes from text prompt (#106)
Browse files Browse the repository at this point in the history
  • Loading branch information
prasmussen15 authored Sep 11, 2024
1 parent b214baa commit 4122d35
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 2 deletions.
44 changes: 43 additions & 1 deletion graphiti_core/prompts/extract_nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,14 @@ class Prompt(Protocol):
v1: PromptVersion
v2: PromptVersion
extract_json: PromptVersion
extract_text: PromptVersion


class Versions(TypedDict):
v1: PromptFunction
v2: PromptFunction
extract_json: PromptFunction
extract_text: PromptFunction


def v1(context: dict[str, Any]) -> list[Message]:
Expand Down Expand Up @@ -144,4 +146,44 @@ def extract_json(context: dict[str, Any]) -> list[Message]:
]


versions: Versions = {'v1': v1, 'v2': v2, 'extract_json': extract_json}
def extract_text(context: dict[str, Any]) -> list[Message]:
sys_prompt = """You are an AI assistant that extracts entity nodes from conversational text. Your primary task is to identify and extract the speaker and other significant entities mentioned in the conversation."""

user_prompt = f"""
Given the following conversation, extract entity nodes from the CURRENT MESSAGE that are explicitly or implicitly mentioned:
Conversation:
{json.dumps([ep['content'] for ep in context['previous_episodes']], indent=2)}
<CURRENT MESSAGE>
{context["episode_content"]}
Guidelines:
2. Extract significant entities, concepts, or actors mentioned in the conversation.
3. Provide concise but informative summaries for each extracted node.
4. Avoid creating nodes for relationships or actions.
5. Avoid creating nodes for temporal information like dates, times or years (these will be added to edges later).
6. Be as explicit as possible in your node names, using full names and avoiding abbreviations.
Respond with a JSON object in the following format:
{{
"extracted_nodes": [
{{
"name": "Unique identifier for the node (use the speaker's name for speaker nodes)",
"labels": ["Entity", "OptionalAdditionalLabel"],
"summary": "Brief summary of the node's role or significance"
}}
]
}}
"""
return [
Message(role='system', content=sys_prompt),
Message(role='user', content=user_prompt),
]


versions: Versions = {
'v1': v1,
'v2': v2,
'extract_json': extract_json,
'extract_text': extract_text,
}
27 changes: 26 additions & 1 deletion graphiti_core/utils/maintenance/node_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,29 @@ async def extract_message_nodes(
return extracted_node_data


async def extract_text_nodes(
llm_client: LLMClient, episode: EpisodicNode, previous_episodes: list[EpisodicNode]
) -> list[dict[str, Any]]:
# Prepare context for LLM
context = {
'episode_content': episode.content,
'episode_timestamp': episode.valid_at.isoformat(),
'previous_episodes': [
{
'content': ep.content,
'timestamp': ep.valid_at.isoformat(),
}
for ep in previous_episodes
],
}

llm_response = await llm_client.generate_response(
prompt_library.extract_nodes.extract_text(context)
)
extracted_node_data = llm_response.get('extracted_nodes', [])
return extracted_node_data


async def extract_json_nodes(
llm_client: LLMClient,
episode: EpisodicNode,
Expand All @@ -73,8 +96,10 @@ async def extract_nodes(
) -> list[EntityNode]:
start = time()
extracted_node_data: list[dict[str, Any]] = []
if episode.source in [EpisodeType.message, EpisodeType.text]:
if episode.source == EpisodeType.message:
extracted_node_data = await extract_message_nodes(llm_client, episode, previous_episodes)
elif episode.source == EpisodeType.text:
extracted_node_data = await extract_text_nodes(llm_client, episode, previous_episodes)
elif episode.source == EpisodeType.json:
extracted_node_data = await extract_json_nodes(llm_client, episode)

Expand Down

0 comments on commit 4122d35

Please sign in to comment.