add extract nodes from text prompt #106

Merged · 1 commit, merged on Sep 11, 2024
44 changes: 43 additions & 1 deletion graphiti_core/prompts/extract_nodes.py
@@ -24,12 +24,14 @@ class Prompt(Protocol):
v1: PromptVersion
v2: PromptVersion
extract_json: PromptVersion
extract_text: PromptVersion


class Versions(TypedDict):
v1: PromptFunction
v2: PromptFunction
extract_json: PromptFunction
extract_text: PromptFunction


def v1(context: dict[str, Any]) -> list[Message]:
@@ -144,4 +146,44 @@ def extract_json(context: dict[str, Any]) -> list[Message]:
]


versions: Versions = {'v1': v1, 'v2': v2, 'extract_json': extract_json}
def extract_text(context: dict[str, Any]) -> list[Message]:
sys_prompt = """You are an AI assistant that extracts entity nodes from conversational text. Your primary task is to identify and extract the speaker and other significant entities mentioned in the conversation."""

user_prompt = f"""
Given the following conversation, extract entity nodes from the CURRENT MESSAGE that are explicitly or implicitly mentioned:

Conversation:
{json.dumps([ep['content'] for ep in context['previous_episodes']], indent=2)}
<CURRENT MESSAGE>
{context["episode_content"]}

Guidelines:
2. Extract significant entities, concepts, or actors mentioned in the conversation.
3. Provide concise but informative summaries for each extracted node.
4. Avoid creating nodes for relationships or actions.
5. Avoid creating nodes for temporal information like dates, times or years (these will be added to edges later).
6. Be as explicit as possible in your node names, using full names and avoiding abbreviations.

Respond with a JSON object in the following format:
{{
"extracted_nodes": [
{{
"name": "Unique identifier for the node (use the speaker's name for speaker nodes)",
"labels": ["Entity", "OptionalAdditionalLabel"],
"summary": "Brief summary of the node's role or significance"
}}
]
}}
"""
return [
Message(role='system', content=sys_prompt),
Message(role='user', content=user_prompt),
]


versions: Versions = {
'v1': v1,
'v2': v2,
'extract_json': extract_json,
'extract_text': extract_text,
}
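For reference, a minimal sketch (not part of the PR) of exercising the new extract_text prompt function on its own. The sample conversation content is made up, and the prompt_library import path is assumed from its usage in node_operations.py below:

# Illustrative only: build the context dict the prompt expects and render
# the system/user messages. The sample content is hypothetical.
from graphiti_core.prompts import prompt_library  # assumed import path

context = {
    'episode_content': 'Jane Doe presented the quarterly report to the Acme board.',
    'previous_episodes': [
        {'content': 'The Acme board scheduled a quarterly review.'},
    ],
}

messages = prompt_library.extract_nodes.extract_text(context)
for message in messages:
    print(message.role, message.content[:80])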
27 changes: 26 additions & 1 deletion graphiti_core/utils/maintenance/node_operations.py
@@ -48,6 +48,29 @@ async def extract_message_nodes(
return extracted_node_data


async def extract_text_nodes(
llm_client: LLMClient, episode: EpisodicNode, previous_episodes: list[EpisodicNode]
) -> list[dict[str, Any]]:
# Prepare context for LLM
context = {
'episode_content': episode.content,
'episode_timestamp': episode.valid_at.isoformat(),
'previous_episodes': [
{
'content': ep.content,
'timestamp': ep.valid_at.isoformat(),
}
for ep in previous_episodes
],
}

llm_response = await llm_client.generate_response(
prompt_library.extract_nodes.extract_text(context)
)
extracted_node_data = llm_response.get('extracted_nodes', [])
return extracted_node_data


async def extract_json_nodes(
llm_client: LLMClient,
episode: EpisodicNode,
@@ -73,8 +96,10 @@ async def extract_nodes(
) -> list[EntityNode]:
start = time()
extracted_node_data: list[dict[str, Any]] = []
if episode.source in [EpisodeType.message, EpisodeType.text]:
if episode.source == EpisodeType.message:
extracted_node_data = await extract_message_nodes(llm_client, episode, previous_episodes)
elif episode.source == EpisodeType.text:
extracted_node_data = await extract_text_nodes(llm_client, episode, previous_episodes)
elif episode.source == EpisodeType.json:
extracted_node_data = await extract_json_nodes(llm_client, episode)

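As a rough illustration (not from the PR), the context that extract_text_nodes assembles and the response shape requested by the extract_text prompt look roughly like this; the episode content and extracted nodes are hypothetical:

# Hypothetical data, shown only to illustrate the shapes involved.
context = {
    'episode_content': 'Alice founded Acme Robotics in Berlin.',
    'episode_timestamp': '2024-09-11T00:00:00+00:00',
    'previous_episodes': [
        {'content': 'Alice moved to Berlin.', 'timestamp': '2024-09-10T00:00:00+00:00'},
    ],
}

# A well-formed LLM response, per the JSON format the prompt asks for,
# would parse to something like:
llm_response = {
    'extracted_nodes': [
        {'name': 'Alice', 'labels': ['Entity'], 'summary': 'Founder of Acme Robotics'},
        {
            'name': 'Acme Robotics',
            'labels': ['Entity', 'Organization'],
            'summary': 'Robotics company founded by Alice in Berlin',
        },
    ]
}
extracted_node_data = llm_response.get('extracted_nodes', [])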