mirror of
https://github.com/langgenius/dify.git
synced 2024-11-16 03:32:23 +08:00
move _get_file_type_by_mimetype to file_factory
This commit is contained in:
parent
7f73eabb41
commit
588dcf38c5
|
@ -168,17 +168,3 @@ def _to_url(f: File, /):
|
||||||
return ToolFileParser.get_tool_file_manager().sign_file(tool_file_id=f.related_id, extension=f.extension)
|
return ToolFileParser.get_tool_file_manager().sign_file(tool_file_id=f.related_id, extension=f.extension)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported transfer method: {f.transfer_method}")
|
raise ValueError(f"Unsupported transfer method: {f.transfer_method}")
|
||||||
|
|
||||||
|
|
||||||
def get_file_type_by_mimetype(mime_type: str) -> FileType:
|
|
||||||
if "image" in mime_type:
|
|
||||||
file_type = FileType.IMAGE
|
|
||||||
elif "video" in mime_type:
|
|
||||||
file_type = FileType.VIDEO
|
|
||||||
elif "audio" in mime_type:
|
|
||||||
file_type = FileType.AUDIO
|
|
||||||
elif "text" in mime_type or "pdf" in mime_type:
|
|
||||||
file_type = FileType.DOCUMENT
|
|
||||||
else:
|
|
||||||
file_type = FileType.CUSTOM
|
|
||||||
return file_type
|
|
||||||
|
|
|
@ -10,7 +10,7 @@ from yarl import URL
|
||||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||||
from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler
|
from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler
|
||||||
from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler
|
from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler
|
||||||
from core.file.file_manager import get_file_type_by_mimetype
|
from core.file import FileType
|
||||||
from core.file.models import FileTransferMethod
|
from core.file.models import FileTransferMethod
|
||||||
from core.ops.ops_trace_manager import TraceQueueManager
|
from core.ops.ops_trace_manager import TraceQueueManager
|
||||||
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolInvokeMessageBinary, ToolInvokeMeta, ToolParameter
|
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolInvokeMessageBinary, ToolInvokeMeta, ToolParameter
|
||||||
|
@ -275,7 +275,16 @@ class ToolEngine:
|
||||||
result = []
|
result = []
|
||||||
|
|
||||||
for message in tool_messages:
|
for message in tool_messages:
|
||||||
file_type = get_file_type_by_mimetype(message.mimetype)
|
if "image" in message.mimetype:
|
||||||
|
file_type = FileType.IMAGE
|
||||||
|
elif "video" in message.mimetype:
|
||||||
|
file_type = FileType.VIDEO
|
||||||
|
elif "audio" in message.mimetype:
|
||||||
|
file_type = FileType.AUDIO
|
||||||
|
elif "text" in message.mimetype or "pdf" in message.mimetype:
|
||||||
|
file_type = FileType.DOCUMENT
|
||||||
|
else:
|
||||||
|
file_type = FileType.CUSTOM
|
||||||
|
|
||||||
# extract tool file id from url
|
# extract tool file id from url
|
||||||
tool_file_id = message.url.split("/")[-1].split(".")[0]
|
tool_file_id = message.url.split("/")[-1].split(".")[0]
|
||||||
|
|
|
@ -1,5 +1,4 @@
|
||||||
from collections.abc import Mapping, Sequence
|
from collections.abc import Mapping, Sequence
|
||||||
from os import path
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
|
@ -7,7 +6,6 @@ from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler
|
from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler
|
||||||
from core.file import File, FileTransferMethod, FileType
|
from core.file import File, FileTransferMethod, FileType
|
||||||
from core.file.file_manager import get_file_type_by_mimetype
|
|
||||||
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
|
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
|
||||||
from core.tools.tool_engine import ToolEngine
|
from core.tools.tool_engine import ToolEngine
|
||||||
from core.tools.tool_manager import ToolManager
|
from core.tools.tool_manager import ToolManager
|
||||||
|
@ -181,7 +179,6 @@ class ToolNode(BaseNode[ToolNodeData]):
|
||||||
for response in tool_response:
|
for response in tool_response:
|
||||||
if response.type in {ToolInvokeMessage.MessageType.IMAGE_LINK, ToolInvokeMessage.MessageType.IMAGE}:
|
if response.type in {ToolInvokeMessage.MessageType.IMAGE_LINK, ToolInvokeMessage.MessageType.IMAGE}:
|
||||||
url = str(response.message) if response.message else None
|
url = str(response.message) if response.message else None
|
||||||
ext = path.splitext(url)[1] if url else ".bin"
|
|
||||||
tool_file_id = str(url).split("/")[-1].split(".")[0]
|
tool_file_id = str(url).split("/")[-1].split(".")[0]
|
||||||
transfer_method = response.meta.get("transfer_method", FileTransferMethod.TOOL_FILE)
|
transfer_method = response.meta.get("transfer_method", FileTransferMethod.TOOL_FILE)
|
||||||
|
|
||||||
|
@ -203,7 +200,6 @@ class ToolNode(BaseNode[ToolNodeData]):
|
||||||
)
|
)
|
||||||
result.append(file)
|
result.append(file)
|
||||||
elif response.type == ToolInvokeMessage.MessageType.BLOB:
|
elif response.type == ToolInvokeMessage.MessageType.BLOB:
|
||||||
# get tool file id
|
|
||||||
tool_file_id = str(response.message).split("/")[-1].split(".")[0]
|
tool_file_id = str(response.message).split("/")[-1].split(".")[0]
|
||||||
with Session(db.engine) as session:
|
with Session(db.engine) as session:
|
||||||
stmt = select(ToolFile).where(ToolFile.id == tool_file_id)
|
stmt = select(ToolFile).where(ToolFile.id == tool_file_id)
|
||||||
|
@ -212,7 +208,6 @@ class ToolNode(BaseNode[ToolNodeData]):
|
||||||
raise ValueError(f"tool file {tool_file_id} not exists")
|
raise ValueError(f"tool file {tool_file_id} not exists")
|
||||||
mapping = {
|
mapping = {
|
||||||
"tool_file_id": tool_file_id,
|
"tool_file_id": tool_file_id,
|
||||||
"type": get_file_type_by_mimetype(response.meta.get("mime_type")),
|
|
||||||
"transfer_method": FileTransferMethod.TOOL_FILE,
|
"transfer_method": FileTransferMethod.TOOL_FILE,
|
||||||
}
|
}
|
||||||
file = file_factory.build_from_mapping(
|
file = file_factory.build_from_mapping(
|
||||||
|
@ -229,13 +224,8 @@ class ToolNode(BaseNode[ToolNodeData]):
|
||||||
tool_file = session.scalar(stmt)
|
tool_file = session.scalar(stmt)
|
||||||
if tool_file is None:
|
if tool_file is None:
|
||||||
raise ToolFileError(f"Tool file {tool_file_id} does not exist")
|
raise ToolFileError(f"Tool file {tool_file_id} does not exist")
|
||||||
if "." in url:
|
|
||||||
extension = "." + url.split("/")[-1].split(".")[1]
|
|
||||||
else:
|
|
||||||
extension = ".bin"
|
|
||||||
mapping = {
|
mapping = {
|
||||||
"tool_file_id": tool_file_id,
|
"tool_file_id": tool_file_id,
|
||||||
"type": get_file_type_by_mimetype(response.meta.get("mime_type")),
|
|
||||||
"transfer_method": transfer_method,
|
"transfer_method": transfer_method,
|
||||||
"url": url,
|
"url": url,
|
||||||
}
|
}
|
||||||
|
|
|
@ -180,6 +180,20 @@ def _get_remote_file_info(url: str):
|
||||||
return mime_type, filename, file_size
|
return mime_type, filename, file_size
|
||||||
|
|
||||||
|
|
||||||
|
def _get_file_type_by_mimetype(mime_type: str) -> FileType:
|
||||||
|
if "image" in mime_type:
|
||||||
|
file_type = FileType.IMAGE
|
||||||
|
elif "video" in mime_type:
|
||||||
|
file_type = FileType.VIDEO
|
||||||
|
elif "audio" in mime_type:
|
||||||
|
file_type = FileType.AUDIO
|
||||||
|
elif "text" in mime_type or "pdf" in mime_type:
|
||||||
|
file_type = FileType.DOCUMENT
|
||||||
|
else:
|
||||||
|
file_type = FileType.CUSTOM
|
||||||
|
return file_type
|
||||||
|
|
||||||
|
|
||||||
def _build_from_tool_file(
|
def _build_from_tool_file(
|
||||||
*,
|
*,
|
||||||
mapping: Mapping[str, Any],
|
mapping: Mapping[str, Any],
|
||||||
|
@ -199,12 +213,13 @@ def _build_from_tool_file(
|
||||||
raise ValueError(f"ToolFile {mapping.get('tool_file_id')} not found")
|
raise ValueError(f"ToolFile {mapping.get('tool_file_id')} not found")
|
||||||
|
|
||||||
extension = "." + tool_file.file_key.split(".")[-1] if "." in tool_file.file_key else ".bin"
|
extension = "." + tool_file.file_key.split(".")[-1] if "." in tool_file.file_key else ".bin"
|
||||||
|
file_type = mapping.get("type", _get_file_type_by_mimetype(tool_file.mimetype))
|
||||||
|
|
||||||
return File(
|
return File(
|
||||||
id=mapping.get("id"),
|
id=mapping.get("id"),
|
||||||
tenant_id=tenant_id,
|
tenant_id=tenant_id,
|
||||||
filename=tool_file.name,
|
filename=tool_file.name,
|
||||||
type=FileType.value_of(mapping.get("type")),
|
type=file_type,
|
||||||
transfer_method=transfer_method,
|
transfer_method=transfer_method,
|
||||||
remote_url=tool_file.original_url,
|
remote_url=tool_file.original_url,
|
||||||
related_id=tool_file.id,
|
related_id=tool_file.id,
|
||||||
|
|
Loading…
Reference in New Issue
Block a user