Skip to content

Commit

Permalink
fix(node): handle empty text segments gracefully
Browse files Browse the repository at this point in the history
Ensure that messages are only created from non-empty text segments, preventing potential issues with empty content.

test: add scenario for file variable handling

Introduce a test case for scenarios involving prompt templates with file variables, particularly images, to improve reliability and test coverage. Updated `LLMNodeTestScenario` to use `Sequence` and `Mapping` for more flexible configurations.

Closes #123, relates to #456.
  • Loading branch information
laipz8200 committed Nov 18, 2024
1 parent d195380 commit 1a8b058
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 6 deletions.
6 changes: 4 additions & 2 deletions api/core/workflow/nodes/llm/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -868,8 +868,10 @@ def _handle_list_messages(
image_contents.append(image_content)

# Create message with text from all segments
prompt_message = _combine_text_message_with_role(text=segment_group.text, role=message.role)
prompt_messages.append(prompt_message)
plain_text = segment_group.text
if plain_text:
prompt_message = _combine_text_message_with_role(text=plain_text, role=message.role)
prompt_messages.append(prompt_message)

if image_contents:
# Create message with image contents
Expand Down
38 changes: 38 additions & 0 deletions api/tests/unit_tests/core/workflow/nodes/llm/test_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,11 +363,49 @@ def test_fetch_prompt_messages__basic(faker, llm_node, model_config):
),
],
),
LLMNodeTestScenario(
description="Prompt template with variable selector of File",
user_query=fake_query,
user_files=[],
vision_enabled=True,
vision_detail=fake_vision_detail,
features=[ModelFeature.VISION],
window_size=fake_window_size,
prompt_template=[
LLMNodeChatModelMessage(
text="{{#input.image#}}",
role=PromptMessageRole.USER,
edition_type="basic",
),
],
expected_messages=[
UserPromptMessage(
content=[
ImagePromptMessageContent(data=fake_remote_url, detail=fake_vision_detail),
]
),
]
+ mock_history[fake_window_size * -2 :]
+ [UserPromptMessage(content=fake_query)],
file_variables={
"input.image": File(
tenant_id="test",
type=FileType.IMAGE,
filename="test1.jpg",
transfer_method=FileTransferMethod.REMOTE_URL,
remote_url=fake_remote_url,
)
},
),
]

for scenario in test_scenarios:
model_config.model_schema.features = scenario.features

for k, v in scenario.file_variables.items():
selector = k.split(".")
llm_node.graph_runtime_state.variable_pool.add(selector, v)

# Call the method under test
prompt_messages, _ = llm_node._fetch_prompt_messages(
user_query=scenario.user_query,
Expand Down
13 changes: 9 additions & 4 deletions api/tests/unit_tests/core/workflow/nodes/llm/test_scenarios.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from collections.abc import Mapping, Sequence

from pydantic import BaseModel, Field

from core.file import File
Expand All @@ -11,10 +13,13 @@ class LLMNodeTestScenario(BaseModel):

description: str = Field(..., description="Description of the test scenario")
user_query: str = Field(..., description="User query input")
user_files: list[File] = Field(default_factory=list, description="List of user files")
user_files: Sequence[File] = Field(default_factory=list, description="List of user files")
vision_enabled: bool = Field(default=False, description="Whether vision is enabled")
vision_detail: str | None = Field(None, description="Vision detail level if vision is enabled")
features: list[ModelFeature] = Field(default_factory=list, description="List of model features")
features: Sequence[ModelFeature] = Field(default_factory=list, description="List of model features")
window_size: int = Field(..., description="Window size for memory")
prompt_template: list[LLMNodeChatModelMessage] = Field(..., description="Template for prompt messages")
expected_messages: list[PromptMessage] = Field(..., description="Expected messages after processing")
prompt_template: Sequence[LLMNodeChatModelMessage] = Field(..., description="Template for prompt messages")
file_variables: Mapping[str, File | Sequence[File]] = Field(
default_factory=dict, description="List of file variables"
)
expected_messages: Sequence[PromptMessage] = Field(..., description="Expected messages after processing")

0 comments on commit 1a8b058

Please sign in to comment.