refactor(tool-node): introduce specific exceptions for tool node errors (#10357)

This commit is contained in:
-LAN- 2024-11-07 14:02:38 +08:00 committed by Joel
parent 47f638e5aa
commit 39fdcfd7e9
2 changed files with 31 additions and 9 deletions

View File

@ -0,0 +1,16 @@
class ToolNodeError(ValueError):
"""Base exception for tool node errors."""
pass
class ToolParameterError(ToolNodeError):
"""Exception raised for errors in tool parameters."""
pass
class ToolFileError(ToolNodeError):
"""Exception raised for errors related to tool files."""
pass

View File

@ -6,7 +6,7 @@ from sqlalchemy import select
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler
from core.file.models import File, FileTransferMethod, FileType from core.file import File, FileTransferMethod, FileType
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
from core.tools.tool_engine import ToolEngine from core.tools.tool_engine import ToolEngine
from core.tools.tool_manager import ToolManager from core.tools.tool_manager import ToolManager
@ -15,12 +15,18 @@ from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResu
from core.workflow.entities.variable_pool import VariablePool from core.workflow.entities.variable_pool import VariablePool
from core.workflow.nodes.base import BaseNode from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.enums import NodeType from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.tool.entities import ToolNodeData
from core.workflow.utils.variable_template_parser import VariableTemplateParser from core.workflow.utils.variable_template_parser import VariableTemplateParser
from extensions.ext_database import db from extensions.ext_database import db
from models import ToolFile from models import ToolFile
from models.workflow import WorkflowNodeExecutionStatus from models.workflow import WorkflowNodeExecutionStatus
from .entities import ToolNodeData
from .exc import (
ToolFileError,
ToolNodeError,
ToolParameterError,
)
class ToolNode(BaseNode[ToolNodeData]): class ToolNode(BaseNode[ToolNodeData]):
""" """
@ -42,7 +48,7 @@ class ToolNode(BaseNode[ToolNodeData]):
tool_runtime = ToolManager.get_workflow_tool_runtime( tool_runtime = ToolManager.get_workflow_tool_runtime(
self.tenant_id, self.app_id, self.node_id, self.node_data, self.invoke_from self.tenant_id, self.app_id, self.node_id, self.node_data, self.invoke_from
) )
except Exception as e: except ToolNodeError as e:
return NodeRunResult( return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED, status=WorkflowNodeExecutionStatus.FAILED,
inputs={}, inputs={},
@ -75,7 +81,7 @@ class ToolNode(BaseNode[ToolNodeData]):
workflow_call_depth=self.workflow_call_depth, workflow_call_depth=self.workflow_call_depth,
thread_pool_id=self.thread_pool_id, thread_pool_id=self.thread_pool_id,
) )
except Exception as e: except ToolNodeError as e:
return NodeRunResult( return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED, status=WorkflowNodeExecutionStatus.FAILED,
inputs=parameters_for_log, inputs=parameters_for_log,
@ -133,13 +139,13 @@ class ToolNode(BaseNode[ToolNodeData]):
if tool_input.type == "variable": if tool_input.type == "variable":
variable = variable_pool.get(tool_input.value) variable = variable_pool.get(tool_input.value)
if variable is None: if variable is None:
raise ValueError(f"variable {tool_input.value} not exists") raise ToolParameterError(f"Variable {tool_input.value} does not exist")
parameter_value = variable.value parameter_value = variable.value
elif tool_input.type in {"mixed", "constant"}: elif tool_input.type in {"mixed", "constant"}:
segment_group = variable_pool.convert_template(str(tool_input.value)) segment_group = variable_pool.convert_template(str(tool_input.value))
parameter_value = segment_group.log if for_log else segment_group.text parameter_value = segment_group.log if for_log else segment_group.text
else: else:
raise ValueError(f"unknown tool input type '{tool_input.type}'") raise ToolParameterError(f"Unknown tool input type '{tool_input.type}'")
result[parameter_name] = parameter_value result[parameter_name] = parameter_value
return result return result
@ -181,7 +187,7 @@ class ToolNode(BaseNode[ToolNodeData]):
stmt = select(ToolFile).where(ToolFile.id == tool_file_id) stmt = select(ToolFile).where(ToolFile.id == tool_file_id)
tool_file = session.scalar(stmt) tool_file = session.scalar(stmt)
if tool_file is None: if tool_file is None:
raise ValueError(f"tool file {tool_file_id} not exists") raise ToolFileError(f"Tool file {tool_file_id} does not exist")
result.append( result.append(
File( File(
@ -203,7 +209,7 @@ class ToolNode(BaseNode[ToolNodeData]):
stmt = select(ToolFile).where(ToolFile.id == tool_file_id) stmt = select(ToolFile).where(ToolFile.id == tool_file_id)
tool_file = session.scalar(stmt) tool_file = session.scalar(stmt)
if tool_file is None: if tool_file is None:
raise ValueError(f"tool file {tool_file_id} not exists") raise ToolFileError(f"Tool file {tool_file_id} does not exist")
result.append( result.append(
File( File(
tenant_id=self.tenant_id, tenant_id=self.tenant_id,
@ -224,7 +230,7 @@ class ToolNode(BaseNode[ToolNodeData]):
stmt = select(ToolFile).where(ToolFile.id == tool_file_id) stmt = select(ToolFile).where(ToolFile.id == tool_file_id)
tool_file = session.scalar(stmt) tool_file = session.scalar(stmt)
if tool_file is None: if tool_file is None:
raise ValueError(f"tool file {tool_file_id} not exists") raise ToolFileError(f"Tool file {tool_file_id} does not exist")
if "." in url: if "." in url:
extension = "." + url.split("/")[-1].split(".")[1] extension = "." + url.split("/")[-1].split(".")[1]
else: else: