diff --git a/api/core/workflow/nodes/tool/exc.py b/api/core/workflow/nodes/tool/exc.py new file mode 100644 index 0000000000..7212e8bfc0 --- /dev/null +++ b/api/core/workflow/nodes/tool/exc.py @@ -0,0 +1,16 @@ +class ToolNodeError(ValueError): + """Base exception for tool node errors.""" + + pass + + +class ToolParameterError(ToolNodeError): + """Exception raised for errors in tool parameters.""" + + pass + + +class ToolFileError(ToolNodeError): + """Exception raised for errors related to tool files.""" + + pass diff --git a/api/core/workflow/nodes/tool/tool_node.py b/api/core/workflow/nodes/tool/tool_node.py index 0994ccaedb..42e870c46c 100644 --- a/api/core/workflow/nodes/tool/tool_node.py +++ b/api/core/workflow/nodes/tool/tool_node.py @@ -6,7 +6,7 @@ from sqlalchemy import select from sqlalchemy.orm import Session from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler -from core.file.models import File, FileTransferMethod, FileType +from core.file import File, FileTransferMethod, FileType from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter from core.tools.tool_engine import ToolEngine from core.tools.tool_manager import ToolManager @@ -15,12 +15,18 @@ from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResu from core.workflow.entities.variable_pool import VariablePool from core.workflow.nodes.base import BaseNode from core.workflow.nodes.enums import NodeType -from core.workflow.nodes.tool.entities import ToolNodeData from core.workflow.utils.variable_template_parser import VariableTemplateParser from extensions.ext_database import db from models import ToolFile from models.workflow import WorkflowNodeExecutionStatus +from .entities import ToolNodeData +from .exc import ( + ToolFileError, + ToolNodeError, + ToolParameterError, +) + class ToolNode(BaseNode[ToolNodeData]): """ @@ -42,7 +48,7 @@ class ToolNode(BaseNode[ToolNodeData]): tool_runtime = ToolManager.get_workflow_tool_runtime( self.tenant_id, self.app_id, self.node_id, self.node_data, self.invoke_from ) - except Exception as e: + except ToolNodeError as e: return NodeRunResult( status=WorkflowNodeExecutionStatus.FAILED, inputs={}, @@ -75,7 +81,7 @@ class ToolNode(BaseNode[ToolNodeData]): workflow_call_depth=self.workflow_call_depth, thread_pool_id=self.thread_pool_id, ) - except Exception as e: + except ToolNodeError as e: return NodeRunResult( status=WorkflowNodeExecutionStatus.FAILED, inputs=parameters_for_log, @@ -133,13 +139,13 @@ class ToolNode(BaseNode[ToolNodeData]): if tool_input.type == "variable": variable = variable_pool.get(tool_input.value) if variable is None: - raise ValueError(f"variable {tool_input.value} not exists") + raise ToolParameterError(f"Variable {tool_input.value} does not exist") parameter_value = variable.value elif tool_input.type in {"mixed", "constant"}: segment_group = variable_pool.convert_template(str(tool_input.value)) parameter_value = segment_group.log if for_log else segment_group.text else: - raise ValueError(f"unknown tool input type '{tool_input.type}'") + raise ToolParameterError(f"Unknown tool input type '{tool_input.type}'") result[parameter_name] = parameter_value return result @@ -181,7 +187,7 @@ class ToolNode(BaseNode[ToolNodeData]): stmt = select(ToolFile).where(ToolFile.id == tool_file_id) tool_file = session.scalar(stmt) if tool_file is None: - raise ValueError(f"tool file {tool_file_id} not exists") + raise ToolFileError(f"Tool file {tool_file_id} does not exist") result.append( File( @@ -203,7 +209,7 @@ class ToolNode(BaseNode[ToolNodeData]): stmt = select(ToolFile).where(ToolFile.id == tool_file_id) tool_file = session.scalar(stmt) if tool_file is None: - raise ValueError(f"tool file {tool_file_id} not exists") + raise ToolFileError(f"Tool file {tool_file_id} does not exist") result.append( File( tenant_id=self.tenant_id, @@ -224,7 +230,7 @@ class ToolNode(BaseNode[ToolNodeData]): stmt = select(ToolFile).where(ToolFile.id == tool_file_id) tool_file = session.scalar(stmt) if tool_file is None: - raise ValueError(f"tool file {tool_file_id} not exists") + raise ToolFileError(f"Tool file {tool_file_id} does not exist") if "." in url: extension = "." + url.split("/")[-1].split(".")[1] else: