Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor(tool-node): introduce specific exceptions for tool node errors #10357

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions api/core/workflow/nodes/tool/exc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
class ToolNodeError(ValueError):
"""Base exception for tool node errors."""

pass


class ToolParameterError(ToolNodeError):
"""Exception raised for errors in tool parameters."""

pass


class ToolFileError(ToolNodeError):
"""Exception raised for errors related to tool files."""

pass
24 changes: 15 additions & 9 deletions api/core/workflow/nodes/tool/tool_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from sqlalchemy.orm import Session

from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler
from core.file.models import File, FileTransferMethod, FileType
from core.file import File, FileTransferMethod, FileType
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
from core.tools.tool_engine import ToolEngine
from core.tools.tool_manager import ToolManager
Expand All @@ -15,12 +15,18 @@
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.tool.entities import ToolNodeData
from core.workflow.utils.variable_template_parser import VariableTemplateParser
from extensions.ext_database import db
from models import ToolFile
from models.workflow import WorkflowNodeExecutionStatus

from .entities import ToolNodeData
from .exc import (
ToolFileError,
ToolNodeError,
ToolParameterError,
)


class ToolNode(BaseNode[ToolNodeData]):
"""
Expand All @@ -42,7 +48,7 @@ def _run(self) -> NodeRunResult:
tool_runtime = ToolManager.get_workflow_tool_runtime(
self.tenant_id, self.app_id, self.node_id, self.node_data, self.invoke_from
)
except Exception as e:
except ToolNodeError as e:
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs={},
Expand Down Expand Up @@ -75,7 +81,7 @@ def _run(self) -> NodeRunResult:
workflow_call_depth=self.workflow_call_depth,
thread_pool_id=self.thread_pool_id,
)
except Exception as e:
except ToolNodeError as e:
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs=parameters_for_log,
Expand Down Expand Up @@ -133,13 +139,13 @@ def _generate_parameters(
if tool_input.type == "variable":
variable = variable_pool.get(tool_input.value)
if variable is None:
raise ValueError(f"variable {tool_input.value} not exists")
raise ToolParameterError(f"Variable {tool_input.value} does not exist")
parameter_value = variable.value
elif tool_input.type in {"mixed", "constant"}:
segment_group = variable_pool.convert_template(str(tool_input.value))
parameter_value = segment_group.log if for_log else segment_group.text
else:
raise ValueError(f"unknown tool input type '{tool_input.type}'")
raise ToolParameterError(f"Unknown tool input type '{tool_input.type}'")
result[parameter_name] = parameter_value

return result
Expand Down Expand Up @@ -181,7 +187,7 @@ def _extract_tool_response_binary(self, tool_response: list[ToolInvokeMessage])
stmt = select(ToolFile).where(ToolFile.id == tool_file_id)
tool_file = session.scalar(stmt)
if tool_file is None:
raise ValueError(f"tool file {tool_file_id} not exists")
raise ToolFileError(f"Tool file {tool_file_id} does not exist")

result.append(
File(
Expand All @@ -203,7 +209,7 @@ def _extract_tool_response_binary(self, tool_response: list[ToolInvokeMessage])
stmt = select(ToolFile).where(ToolFile.id == tool_file_id)
tool_file = session.scalar(stmt)
if tool_file is None:
raise ValueError(f"tool file {tool_file_id} not exists")
raise ToolFileError(f"Tool file {tool_file_id} does not exist")
result.append(
File(
tenant_id=self.tenant_id,
Expand All @@ -224,7 +230,7 @@ def _extract_tool_response_binary(self, tool_response: list[ToolInvokeMessage])
stmt = select(ToolFile).where(ToolFile.id == tool_file_id)
tool_file = session.scalar(stmt)
if tool_file is None:
raise ValueError(f"tool file {tool_file_id} not exists")
raise ToolFileError(f"Tool file {tool_file_id} does not exist")
if "." in url:
extension = "." + url.split("/")[-1].split(".")[1]
else:
Expand Down