refactor(core): improve type annotations and file handling consistency

- Use more precise type annotations with Sequence and Mapping for task entities.
- Ensure raw_prompt is assigned properly after replacement in advanced prompt transform.
- Remove unused generator return type from _fetch_context method.
- Refactor tool node file handling to retrieve more comprehensive file attributes, ensuring file existence validation in the database.
This commit is contained in:
-LAN- 2024-10-15 15:34:14 +08:00
parent dba67cd87a
commit a36ef8430e
4 changed files with 40 additions and 20 deletions

View File

@ -213,7 +213,7 @@ class WorkflowFinishStreamResponse(StreamResponse):
created_by: Optional[dict] = None created_by: Optional[dict] = None
created_at: int created_at: int
finished_at: int finished_at: int
files: Optional[list[dict]] = [] files: Optional[Sequence[Mapping[str, Any]]] = []
event: StreamEvent = StreamEvent.WORKFLOW_FINISHED event: StreamEvent = StreamEvent.WORKFLOW_FINISHED
workflow_run_id: str workflow_run_id: str
@ -298,7 +298,7 @@ class NodeFinishStreamResponse(StreamResponse):
execution_metadata: Optional[dict] = None execution_metadata: Optional[dict] = None
created_at: int created_at: int
finished_at: int finished_at: int
files: Optional[list[dict]] = [] files: Optional[Sequence[Mapping[str, Any]]] = []
parallel_id: Optional[str] = None parallel_id: Optional[str] = None
parallel_start_node_id: Optional[str] = None parallel_start_node_id: Optional[str] = None
parent_parallel_id: Optional[str] = None parent_parallel_id: Optional[str] = None

View File

@ -150,7 +150,7 @@ class AdvancedPromptTransform(PromptTransform):
for k, v in inputs.items(): for k, v in inputs.items():
if k.startswith("#"): if k.startswith("#"):
vp.add(k[1:-1].split("."), v) vp.add(k[1:-1].split("."), v)
raw_prompt.replace("{{#context#}}", context or "") raw_prompt = raw_prompt.replace("{{#context#}}", context or "")
prompt = vp.convert_template(raw_prompt).text prompt = vp.convert_template(raw_prompt).text
else: else:
parser = PromptTemplateParser(template=raw_prompt, with_variable_tmpl=self.with_variable_tmpl) parser = PromptTemplateParser(template=raw_prompt, with_variable_tmpl=self.with_variable_tmpl)

View File

@ -359,7 +359,7 @@ class LLMNode(BaseNode[LLMNodeData]):
return [] return []
raise ValueError(f"Invalid variable type: {type(variable)}") raise ValueError(f"Invalid variable type: {type(variable)}")
def _fetch_context(self, node_data: LLMNodeData) -> Generator[RunEvent, None, None]: def _fetch_context(self, node_data: LLMNodeData):
if not node_data.context.enabled: if not node_data.context.enabled:
return return

View File

@ -2,6 +2,9 @@ from collections.abc import Mapping, Sequence
from os import path from os import path
from typing import Any from typing import Any
from sqlalchemy import select
from sqlalchemy.orm import Session
from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler
from core.file.models import File, FileTransferMethod, FileType from core.file.models import File, FileTransferMethod, FileType
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
@ -14,6 +17,8 @@ from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.tool.entities import ToolNodeData from core.workflow.nodes.tool.entities import ToolNodeData
from core.workflow.utils.variable_template_parser import VariableTemplateParser from core.workflow.utils.variable_template_parser import VariableTemplateParser
from enums import NodeType from enums import NodeType
from extensions.ext_database import db
from models import ToolFile
from models.workflow import WorkflowNodeExecutionStatus from models.workflow import WorkflowNodeExecutionStatus
@ -167,45 +172,59 @@ class ToolNode(BaseNode[ToolNodeData]):
result = [] result = []
for response in tool_response: for response in tool_response:
if response.type in {ToolInvokeMessage.MessageType.IMAGE_LINK, ToolInvokeMessage.MessageType.IMAGE}: if response.type in {ToolInvokeMessage.MessageType.IMAGE_LINK, ToolInvokeMessage.MessageType.IMAGE}:
url = response.message url = str(response.message) if response.message else None
ext = path.splitext(url)[1] ext = path.splitext(url)[1] if url else ".bin"
mimetype = response.meta.get("mime_type", "image/jpeg") tool_file_id = response.save_as or str(url).split("/")[-1].split(".")[0]
tool_file_id = response.save_as or url.split("/")[-1]
transfer_method = response.meta.get("transfer_method", FileTransferMethod.TOOL_FILE) transfer_method = response.meta.get("transfer_method", FileTransferMethod.TOOL_FILE)
# get tool file id with Session(db.engine) as session:
tool_file_id = str(url).split("/")[-1].split(".")[0] stmt = select(ToolFile).where(ToolFile.id == tool_file_id)
tool_file = session.scalar(stmt)
if tool_file is None:
raise ValueError(f"tool file {tool_file_id} not exists")
result.append( result.append(
File( File(
tenant_id=self.tenant_id, tenant_id=self.tenant_id,
type=FileType.IMAGE, type=FileType.IMAGE,
transfer_method=transfer_method, transfer_method=transfer_method,
remote_url=url, remote_url=url,
related_id=tool_file_id, related_id=tool_file.id,
filename=tool_file_id, filename=tool_file.name,
extension=ext, extension=ext,
mime_type=mimetype, mime_type=tool_file.mimetype,
size=tool_file.size,
) )
) )
elif response.type == ToolInvokeMessage.MessageType.BLOB: elif response.type == ToolInvokeMessage.MessageType.BLOB:
# get tool file id # get tool file id
tool_file_id = str(response.message).split("/")[-1].split(".")[0] tool_file_id = str(response.message).split("/")[-1].split(".")[0]
with Session(db.engine) as session:
stmt = select(ToolFile).where(ToolFile.id == tool_file_id)
tool_file = session.scalar(stmt)
if tool_file is None:
raise ValueError(f"tool file {tool_file_id} not exists")
result.append( result.append(
File( File(
tenant_id=self.tenant_id, tenant_id=self.tenant_id,
type=FileType.IMAGE, type=FileType.IMAGE,
transfer_method=FileTransferMethod.TOOL_FILE, transfer_method=FileTransferMethod.TOOL_FILE,
related_id=tool_file_id, related_id=tool_file.id,
filename=response.save_as, filename=tool_file.name,
extension=path.splitext(response.save_as)[1], extension=path.splitext(response.save_as)[1],
mime_type=response.meta.get("mime_type", "application/octet-stream"), mime_type=tool_file.mimetype,
size=tool_file.size,
) )
) )
elif response.type == ToolInvokeMessage.MessageType.LINK: elif response.type == ToolInvokeMessage.MessageType.LINK:
url = str(response.message) url = str(response.message)
transfer_method = FileTransferMethod.TOOL_FILE transfer_method = FileTransferMethod.TOOL_FILE
mimetype = response.meta.get("mime_type", "application/octet-stream")
tool_file_id = url.split("/")[-1].split(".")[0] tool_file_id = url.split("/")[-1].split(".")[0]
with Session(db.engine) as session:
stmt = select(ToolFile).where(ToolFile.id == tool_file_id)
tool_file = session.scalar(stmt)
if tool_file is None:
raise ValueError(f"tool file {tool_file_id} not exists")
if "." in url: if "." in url:
extension = "." + url.split("/")[-1].split(".")[1] extension = "." + url.split("/")[-1].split(".")[1]
else: else:
@ -215,10 +234,11 @@ class ToolNode(BaseNode[ToolNodeData]):
type=FileType(response.save_as), type=FileType(response.save_as),
transfer_method=transfer_method, transfer_method=transfer_method,
remote_url=url, remote_url=url,
filename=tool_file_id, filename=tool_file.name,
related_id=tool_file_id, related_id=tool_file.id,
extension=extension, extension=extension,
mime_type=mimetype, mime_type=tool_file.mimetype,
size=tool_file.size,
) )
result.append(file) result.append(file)