From 4b2abf8ac233e23b678403f10032c1bea839b55a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9D=9E=E6=B3=95=E6=93=8D=E4=BD=9C?= Date: Fri, 15 Nov 2024 10:38:12 +0800 Subject: [PATCH] fix: create_blob_message of tool will always create image type file (#10701) --- api/core/workflow/nodes/tool/tool_node.py | 9 --------- api/factories/file_factory.py | 17 ++++++++++++++++- 2 files changed, 16 insertions(+), 10 deletions(-) diff --git a/api/core/workflow/nodes/tool/tool_node.py b/api/core/workflow/nodes/tool/tool_node.py index 6870b7467d..5560f26456 100644 --- a/api/core/workflow/nodes/tool/tool_node.py +++ b/api/core/workflow/nodes/tool/tool_node.py @@ -1,5 +1,4 @@ from collections.abc import Mapping, Sequence -from os import path from typing import Any from sqlalchemy import select @@ -180,7 +179,6 @@ class ToolNode(BaseNode[ToolNodeData]): for response in tool_response: if response.type in {ToolInvokeMessage.MessageType.IMAGE_LINK, ToolInvokeMessage.MessageType.IMAGE}: url = str(response.message) if response.message else None - ext = path.splitext(url)[1] if url else ".bin" tool_file_id = str(url).split("/")[-1].split(".")[0] transfer_method = response.meta.get("transfer_method", FileTransferMethod.TOOL_FILE) @@ -202,7 +200,6 @@ class ToolNode(BaseNode[ToolNodeData]): ) result.append(file) elif response.type == ToolInvokeMessage.MessageType.BLOB: - # get tool file id tool_file_id = str(response.message).split("/")[-1].split(".")[0] with Session(db.engine) as session: stmt = select(ToolFile).where(ToolFile.id == tool_file_id) @@ -211,7 +208,6 @@ class ToolNode(BaseNode[ToolNodeData]): raise ValueError(f"tool file {tool_file_id} not exists") mapping = { "tool_file_id": tool_file_id, - "type": FileType.IMAGE, "transfer_method": FileTransferMethod.TOOL_FILE, } file = file_factory.build_from_mapping( @@ -228,13 +224,8 @@ class ToolNode(BaseNode[ToolNodeData]): tool_file = session.scalar(stmt) if tool_file is None: raise ToolFileError(f"Tool file {tool_file_id} does not exist") - if "." in url: - extension = "." + url.split("/")[-1].split(".")[1] - else: - extension = ".bin" mapping = { "tool_file_id": tool_file_id, - "type": FileType.IMAGE, "transfer_method": transfer_method, "url": url, } diff --git a/api/factories/file_factory.py b/api/factories/file_factory.py index 738b2b3478..15e9d7f34f 100644 --- a/api/factories/file_factory.py +++ b/api/factories/file_factory.py @@ -180,6 +180,20 @@ def _get_remote_file_info(url: str): return mime_type, filename, file_size +def _get_file_type_by_mimetype(mime_type: str) -> FileType: + if "image" in mime_type: + file_type = FileType.IMAGE + elif "video" in mime_type: + file_type = FileType.VIDEO + elif "audio" in mime_type: + file_type = FileType.AUDIO + elif "text" in mime_type or "pdf" in mime_type: + file_type = FileType.DOCUMENT + else: + file_type = FileType.CUSTOM + return file_type + + def _build_from_tool_file( *, mapping: Mapping[str, Any], @@ -199,12 +213,13 @@ def _build_from_tool_file( raise ValueError(f"ToolFile {mapping.get('tool_file_id')} not found") extension = "." + tool_file.file_key.split(".")[-1] if "." in tool_file.file_key else ".bin" + file_type = mapping.get("type", _get_file_type_by_mimetype(tool_file.mimetype)) return File( id=mapping.get("id"), tenant_id=tenant_id, filename=tool_file.name, - type=FileType.value_of(mapping.get("type")), + type=file_type, transfer_method=transfer_method, remote_url=tool_file.original_url, related_id=tool_file.id,