move _get_file_type_by_mimetype to file_factory

This commit is contained in:
hejl 2024-11-15 09:57:40 +08:00
parent 7f73eabb41
commit 588dcf38c5
4 changed files with 27 additions and 27 deletions

View File

@ -168,17 +168,3 @@ def _to_url(f: File, /):
return ToolFileParser.get_tool_file_manager().sign_file(tool_file_id=f.related_id, extension=f.extension) return ToolFileParser.get_tool_file_manager().sign_file(tool_file_id=f.related_id, extension=f.extension)
else: else:
raise ValueError(f"Unsupported transfer method: {f.transfer_method}") raise ValueError(f"Unsupported transfer method: {f.transfer_method}")
def get_file_type_by_mimetype(mime_type: str) -> FileType:
if "image" in mime_type:
file_type = FileType.IMAGE
elif "video" in mime_type:
file_type = FileType.VIDEO
elif "audio" in mime_type:
file_type = FileType.AUDIO
elif "text" in mime_type or "pdf" in mime_type:
file_type = FileType.DOCUMENT
else:
file_type = FileType.CUSTOM
return file_type

View File

@ -10,7 +10,7 @@ from yarl import URL
from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.app_invoke_entities import InvokeFrom
from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler
from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler
from core.file.file_manager import get_file_type_by_mimetype from core.file import FileType
from core.file.models import FileTransferMethod from core.file.models import FileTransferMethod
from core.ops.ops_trace_manager import TraceQueueManager from core.ops.ops_trace_manager import TraceQueueManager
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolInvokeMessageBinary, ToolInvokeMeta, ToolParameter from core.tools.entities.tool_entities import ToolInvokeMessage, ToolInvokeMessageBinary, ToolInvokeMeta, ToolParameter
@ -275,7 +275,16 @@ class ToolEngine:
result = [] result = []
for message in tool_messages: for message in tool_messages:
file_type = get_file_type_by_mimetype(message.mimetype) if "image" in message.mimetype:
file_type = FileType.IMAGE
elif "video" in message.mimetype:
file_type = FileType.VIDEO
elif "audio" in message.mimetype:
file_type = FileType.AUDIO
elif "text" in message.mimetype or "pdf" in message.mimetype:
file_type = FileType.DOCUMENT
else:
file_type = FileType.CUSTOM
# extract tool file id from url # extract tool file id from url
tool_file_id = message.url.split("/")[-1].split(".")[0] tool_file_id = message.url.split("/")[-1].split(".")[0]

View File

@ -1,5 +1,4 @@
from collections.abc import Mapping, Sequence from collections.abc import Mapping, Sequence
from os import path
from typing import Any from typing import Any
from sqlalchemy import select from sqlalchemy import select
@ -7,7 +6,6 @@ from sqlalchemy.orm import Session
from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler
from core.file import File, FileTransferMethod, FileType from core.file import File, FileTransferMethod, FileType
from core.file.file_manager import get_file_type_by_mimetype
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
from core.tools.tool_engine import ToolEngine from core.tools.tool_engine import ToolEngine
from core.tools.tool_manager import ToolManager from core.tools.tool_manager import ToolManager
@ -181,7 +179,6 @@ class ToolNode(BaseNode[ToolNodeData]):
for response in tool_response: for response in tool_response:
if response.type in {ToolInvokeMessage.MessageType.IMAGE_LINK, ToolInvokeMessage.MessageType.IMAGE}: if response.type in {ToolInvokeMessage.MessageType.IMAGE_LINK, ToolInvokeMessage.MessageType.IMAGE}:
url = str(response.message) if response.message else None url = str(response.message) if response.message else None
ext = path.splitext(url)[1] if url else ".bin"
tool_file_id = str(url).split("/")[-1].split(".")[0] tool_file_id = str(url).split("/")[-1].split(".")[0]
transfer_method = response.meta.get("transfer_method", FileTransferMethod.TOOL_FILE) transfer_method = response.meta.get("transfer_method", FileTransferMethod.TOOL_FILE)
@ -203,7 +200,6 @@ class ToolNode(BaseNode[ToolNodeData]):
) )
result.append(file) result.append(file)
elif response.type == ToolInvokeMessage.MessageType.BLOB: elif response.type == ToolInvokeMessage.MessageType.BLOB:
# get tool file id
tool_file_id = str(response.message).split("/")[-1].split(".")[0] tool_file_id = str(response.message).split("/")[-1].split(".")[0]
with Session(db.engine) as session: with Session(db.engine) as session:
stmt = select(ToolFile).where(ToolFile.id == tool_file_id) stmt = select(ToolFile).where(ToolFile.id == tool_file_id)
@ -212,7 +208,6 @@ class ToolNode(BaseNode[ToolNodeData]):
raise ValueError(f"tool file {tool_file_id} not exists") raise ValueError(f"tool file {tool_file_id} not exists")
mapping = { mapping = {
"tool_file_id": tool_file_id, "tool_file_id": tool_file_id,
"type": get_file_type_by_mimetype(response.meta.get("mime_type")),
"transfer_method": FileTransferMethod.TOOL_FILE, "transfer_method": FileTransferMethod.TOOL_FILE,
} }
file = file_factory.build_from_mapping( file = file_factory.build_from_mapping(
@ -229,13 +224,8 @@ class ToolNode(BaseNode[ToolNodeData]):
tool_file = session.scalar(stmt) tool_file = session.scalar(stmt)
if tool_file is None: if tool_file is None:
raise ToolFileError(f"Tool file {tool_file_id} does not exist") raise ToolFileError(f"Tool file {tool_file_id} does not exist")
if "." in url:
extension = "." + url.split("/")[-1].split(".")[1]
else:
extension = ".bin"
mapping = { mapping = {
"tool_file_id": tool_file_id, "tool_file_id": tool_file_id,
"type": get_file_type_by_mimetype(response.meta.get("mime_type")),
"transfer_method": transfer_method, "transfer_method": transfer_method,
"url": url, "url": url,
} }

View File

@ -180,6 +180,20 @@ def _get_remote_file_info(url: str):
return mime_type, filename, file_size return mime_type, filename, file_size
def _get_file_type_by_mimetype(mime_type: str) -> FileType:
if "image" in mime_type:
file_type = FileType.IMAGE
elif "video" in mime_type:
file_type = FileType.VIDEO
elif "audio" in mime_type:
file_type = FileType.AUDIO
elif "text" in mime_type or "pdf" in mime_type:
file_type = FileType.DOCUMENT
else:
file_type = FileType.CUSTOM
return file_type
def _build_from_tool_file( def _build_from_tool_file(
*, *,
mapping: Mapping[str, Any], mapping: Mapping[str, Any],
@ -199,12 +213,13 @@ def _build_from_tool_file(
raise ValueError(f"ToolFile {mapping.get('tool_file_id')} not found") raise ValueError(f"ToolFile {mapping.get('tool_file_id')} not found")
extension = "." + tool_file.file_key.split(".")[-1] if "." in tool_file.file_key else ".bin" extension = "." + tool_file.file_key.split(".")[-1] if "." in tool_file.file_key else ".bin"
file_type = mapping.get("type", _get_file_type_by_mimetype(tool_file.mimetype))
return File( return File(
id=mapping.get("id"), id=mapping.get("id"),
tenant_id=tenant_id, tenant_id=tenant_id,
filename=tool_file.name, filename=tool_file.name,
type=FileType.value_of(mapping.get("type")), type=file_type,
transfer_method=transfer_method, transfer_method=transfer_method,
remote_url=tool_file.original_url, remote_url=tool_file.original_url,
related_id=tool_file.id, related_id=tool_file.id,