From 1450e5d5cb7038abd4c39cdb443eca3c62630c81 Mon Sep 17 00:00:00 2001 From: Yeuoly Date: Tue, 22 Oct 2024 17:26:00 +0800 Subject: [PATCH] feat: add supports for multimodal --- api/core/entities/parameter_entities.py | 2 ++ api/core/tools/entities/file_entities.py | 12 ++++++++++ api/core/tools/entities/tool_entities.py | 4 ++-- api/core/tools/plugin_tool/tool.py | 22 +++++++++++++++++++ api/core/tools/utils/message_transformer.py | 2 +- api/core/workflow/nodes/tool/tool_node.py | 12 ++++++++++ ...ase_max_length_of_builtin_tool_provider.py | 2 +- api/models/model.py | 2 +- api/models/tools.py | 4 ++++ api/services/account_service.py | 4 ++-- 10 files changed, 59 insertions(+), 7 deletions(-) create mode 100644 api/core/tools/entities/file_entities.py diff --git a/api/core/entities/parameter_entities.py b/api/core/entities/parameter_entities.py index 0045fbf2b4..74d052ad11 100644 --- a/api/core/entities/parameter_entities.py +++ b/api/core/entities/parameter_entities.py @@ -8,6 +8,8 @@ class CommonParameterType(Enum): STRING = "string" NUMBER = "number" FILE = "file" + FILES = "files" + SYSTEM_FILES = "system-files" BOOLEAN = "boolean" APP_SELECTOR = "app-selector" MODEL_CONFIG = "model-config" diff --git a/api/core/tools/entities/file_entities.py b/api/core/tools/entities/file_entities.py new file mode 100644 index 0000000000..7482378028 --- /dev/null +++ b/api/core/tools/entities/file_entities.py @@ -0,0 +1,12 @@ +from pydantic import BaseModel + +from core.file.constants import FILE_MODEL_IDENTITY + + +class PluginFileEntity(BaseModel): + """ + File entity for plugin tool. + """ + + dify_model_identity: str = FILE_MODEL_IDENTITY + url: str diff --git a/api/core/tools/entities/tool_entities.py b/api/core/tools/entities/tool_entities.py index 5f6c593cc0..dd1c7d8b7b 100644 --- a/api/core/tools/entities/tool_entities.py +++ b/api/core/tools/entities/tool_entities.py @@ -208,10 +208,10 @@ class ToolParameter(BaseModel): SELECT = CommonParameterType.SELECT.value SECRET_INPUT = CommonParameterType.SECRET_INPUT.value FILE = CommonParameterType.FILE.value - FILES = "files" + FILES = CommonParameterType.FILES.value # deprecated, should not use. - SYSTEM_FILES = "systme-files" + SYSTEM_FILES = CommonParameterType.SYSTEM_FILES.value def as_normal_type(self): if self in { diff --git a/api/core/tools/plugin_tool/tool.py b/api/core/tools/plugin_tool/tool.py index 7c3962a540..1ac8c75202 100644 --- a/api/core/tools/plugin_tool/tool.py +++ b/api/core/tools/plugin_tool/tool.py @@ -4,7 +4,9 @@ from typing import Any, Optional from core.plugin.manager.tool import PluginToolManager from core.tools.__base.tool import Tool from core.tools.__base.tool_runtime import ToolRuntime +from core.tools.entities.file_entities import PluginFileEntity from core.tools.entities.tool_entities import ToolEntity, ToolInvokeMessage, ToolParameter, ToolProviderType +from models.model import File class PluginTool(Tool): @@ -29,6 +31,23 @@ class PluginTool(Tool): message_id: Optional[str] = None, ) -> Generator[ToolInvokeMessage, None, None]: manager = PluginToolManager() + + # convert tool parameters with File type to PluginFileEntity + for parameter_name, parameter in tool_parameters.items(): + if isinstance(parameter, File): + url = parameter.generate_url() + if url is None: + raise ValueError(f"File {parameter.id} does not have a valid URL") + tool_parameters[parameter_name] = PluginFileEntity(url=url).model_dump() + elif isinstance(parameter, list) and all(isinstance(p, File) for p in parameter): + tool_parameters[parameter_name] = [] + for p in parameter: + assert isinstance(p, File) + url = p.generate_url() + if url is None: + raise ValueError(f"File {p.id} does not have a valid URL") + tool_parameters[parameter_name].append(PluginFileEntity(url=url)).model_dump() + return manager.invoke( tenant_id=self.tenant_id, user_id=user_id, @@ -36,6 +55,9 @@ class PluginTool(Tool): tool_name=self.entity.identity.name, credentials=self.runtime.credentials, tool_parameters=tool_parameters, + conversation_id=conversation_id, + app_id=app_id, + message_id=message_id, ) def fork_tool_runtime(self, runtime: ToolRuntime) -> "PluginTool": diff --git a/api/core/tools/utils/message_transformer.py b/api/core/tools/utils/message_transformer.py index 5bcd2ec61b..e4196095b0 100644 --- a/api/core/tools/utils/message_transformer.py +++ b/api/core/tools/utils/message_transformer.py @@ -69,7 +69,7 @@ class ToolFileMessageTransformer: raise ValueError("unexpected message type") # FIXME: should do a type check here. - assert isinstance(message.message, bytes) + assert isinstance(message.message.blob, bytes) file = ToolFileManager.create_file_by_raw( user_id=user_id, tenant_id=tenant_id, diff --git a/api/core/workflow/nodes/tool/tool_node.py b/api/core/workflow/nodes/tool/tool_node.py index becf11c3d4..7082b1a168 100644 --- a/api/core/workflow/nodes/tool/tool_node.py +++ b/api/core/workflow/nodes/tool/tool_node.py @@ -138,7 +138,19 @@ class ToolNode(BaseNode[ToolNodeData]): parameter_value = segment_group.log if for_log else segment_group.text else: raise ValueError(f"unknown tool input type '{tool_input.type}'") + result[parameter_name] = parameter_value + # HACK: + result["file"] = File( + tenant_id="9a80db54-1557-46da-81fe-f0c4fd3df066", + type=FileType.IMAGE, + transfer_method=FileTransferMethod.TOOL_FILE, + remote_url="https://example.com/image.png", + related_id="67f4eb5d-3419-4faf-b147-f77d8d69c6b6", + filename="image.png", + extension=".png", + mime_type="image/png", + ) return result diff --git a/api/migrations/versions/2024_09_29_0835-ddcc8bbef391_increase_max_length_of_builtin_tool_provider.py b/api/migrations/versions/2024_09_29_0835-ddcc8bbef391_increase_max_length_of_builtin_tool_provider.py index 4b16fe7f31..0d2db233c4 100644 --- a/api/migrations/versions/2024_09_29_0835-ddcc8bbef391_increase_max_length_of_builtin_tool_provider.py +++ b/api/migrations/versions/2024_09_29_0835-ddcc8bbef391_increase_max_length_of_builtin_tool_provider.py @@ -12,7 +12,7 @@ import sqlalchemy as sa # revision identifiers, used by Alembic. revision = 'ddcc8bbef391' -down_revision = 'd57ba9ebb251' +down_revision = 'bbadea11becb' branch_labels = None depends_on = None diff --git a/api/models/model.py b/api/models/model.py index 0da55cb9de..06947b043e 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -38,7 +38,7 @@ class FileUploadConfig(BaseModel): number_limits: int = Field(default=0, gt=0, le=10) -class DifySetup(BaseModel): +class DifySetup(Base): __tablename__ = "dify_setups" __table_args__ = (db.PrimaryKeyConstraint("version", name="dify_setup_pkey"),) diff --git a/api/models/tools.py b/api/models/tools.py index 8712f37946..869dd0201f 100644 --- a/api/models/tools.py +++ b/api/models/tools.py @@ -283,6 +283,10 @@ class ToolFile(Base): mimetype: Mapped[str] = mapped_column(db.String(255), nullable=False) # original url original_url: Mapped[str] = mapped_column(db.String(2048), nullable=True) + # name + name: Mapped[str] = mapped_column(default="") + # size + size: Mapped[int] = mapped_column(default=-1) @deprecated diff --git a/api/services/account_service.py b/api/services/account_service.py index 529b716773..fed9ae5a26 100644 --- a/api/services/account_service.py +++ b/api/services/account_service.py @@ -94,7 +94,7 @@ class AccountService: @staticmethod def load_user(user_id: str) -> None | Account: - account = Account.query.filter_by(id=user_id).first() + account = db.session.query(Account).filter_by(id=user_id).first() if not account: return None @@ -139,7 +139,7 @@ class AccountService: def authenticate(email: str, password: str, invite_token: Optional[str] = None) -> Account: """authenticate account with email and password""" - account = Account.query.filter_by(email=email).first() + account = db.session.query(Account).filter_by(email=email).first() if not account: raise AccountNotFoundError()