From d25e79e794a12eff1d73ad1f4d3f18e5324a096c Mon Sep 17 00:00:00 2001 From: Yeuoly Date: Thu, 14 Nov 2024 18:32:51 +0800 Subject: [PATCH] feat: support uploading images through plugin --- api/controllers/files/__init__.py | 2 +- api/controllers/files/upload.py | 64 ++++++++++++++++ api/controllers/inner_api/plugin/plugin.py | 89 +++++++++++++--------- api/controllers/inner_api/plugin/wraps.py | 27 ++++++- api/core/file/helpers.py | 35 +++++++++ api/core/plugin/entities/request.py | 9 +++ api/models/account.py | 1 + 7 files changed, 186 insertions(+), 41 deletions(-) create mode 100644 api/controllers/files/upload.py diff --git a/api/controllers/files/__init__.py b/api/controllers/files/__init__.py index 97d5c3f88f..d4c3245708 100644 --- a/api/controllers/files/__init__.py +++ b/api/controllers/files/__init__.py @@ -6,4 +6,4 @@ bp = Blueprint("files", __name__) api = ExternalApi(bp) -from . import image_preview, tool_files +from . import image_preview, tool_files, upload diff --git a/api/controllers/files/upload.py b/api/controllers/files/upload.py new file mode 100644 index 0000000000..6820cc7b75 --- /dev/null +++ b/api/controllers/files/upload.py @@ -0,0 +1,64 @@ +from flask import request +from flask_restful import Resource, marshal_with +from werkzeug.exceptions import Forbidden + +import services +from controllers.console.wraps import setup_required +from controllers.files import api +from controllers.files.error import UnsupportedFileTypeError +from controllers.inner_api.plugin.wraps import get_user +from controllers.service_api.app.error import FileTooLargeError +from core.file.helpers import verify_plugin_file_signature +from fields.file_fields import file_fields +from services.file_service import FileService + + +class PluginUploadFileApi(Resource): + @setup_required + @marshal_with(file_fields) + def post(self): + # get file from request + file = request.files["file"] + + timestamp = request.args.get("timestamp") + nonce = request.args.get("nonce") + sign = request.args.get("sign") + user_id = request.args.get("user_id") + user = get_user(user_id) + + filename = file.filename + mimetype = file.mimetype + + if not filename or not mimetype: + raise Forbidden("Invalid request.") + + if not timestamp or not nonce or not sign: + raise Forbidden("Invalid request.") + + if not verify_plugin_file_signature( + filename=filename, + mimetype=mimetype, + user_id=user_id, + timestamp=timestamp, + nonce=nonce, + sign=sign, + ): + raise Forbidden("Invalid request.") + + try: + upload_file = FileService.upload_file( + filename=filename, + content=file.read(), + mimetype=mimetype, + user=user, + source=None, + ) + except services.errors.file.FileTooLargeError as file_too_large_error: + raise FileTooLargeError(file_too_large_error.description) + except services.errors.file.UnsupportedFileTypeError: + raise UnsupportedFileTypeError() + + return upload_file, 201 + + +api.add_resource(PluginUploadFileApi, "/files/upload/for-plugin") diff --git a/api/controllers/inner_api/plugin/plugin.py b/api/controllers/inner_api/plugin/plugin.py index e507c084a9..5ea2af8e84 100644 --- a/api/controllers/inner_api/plugin/plugin.py +++ b/api/controllers/inner_api/plugin/plugin.py @@ -2,8 +2,9 @@ from flask_restful import Resource from controllers.console.wraps import setup_required from controllers.inner_api import api -from controllers.inner_api.plugin.wraps import get_tenant, plugin_data +from controllers.inner_api.plugin.wraps import get_user_tenant, plugin_data from controllers.inner_api.wraps import plugin_inner_api_only +from core.file.helpers import get_signed_file_url_for_plugin from core.plugin.backwards_invocation.app import PluginAppBackwardsInvocation from core.plugin.backwards_invocation.base import BaseBackwardsInvocationResponse from core.plugin.backwards_invocation.encrypt import PluginEncrypter @@ -23,20 +24,22 @@ from core.plugin.entities.request import ( RequestInvokeTextEmbedding, RequestInvokeTool, RequestInvokeTTS, + RequestRequestUploadFile, ) from core.tools.entities.tool_entities import ToolProviderType from libs.helper import compact_generate_response -from models.account import Tenant +from models.account import Account, Tenant +from models.model import EndUser class PluginInvokeLLMApi(Resource): @setup_required @plugin_inner_api_only - @get_tenant + @get_user_tenant @plugin_data(payload_type=RequestInvokeLLM) - def post(self, user_id: str, tenant_model: Tenant, payload: RequestInvokeLLM): + def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestInvokeLLM): def generator(): - response = PluginModelBackwardsInvocation.invoke_llm(user_id, tenant_model, payload) + response = PluginModelBackwardsInvocation.invoke_llm(user_model.id, tenant_model, payload) return PluginModelBackwardsInvocation.convert_to_event_stream(response) return compact_generate_response(generator()) @@ -45,13 +48,13 @@ class PluginInvokeLLMApi(Resource): class PluginInvokeTextEmbeddingApi(Resource): @setup_required @plugin_inner_api_only - @get_tenant + @get_user_tenant @plugin_data(payload_type=RequestInvokeTextEmbedding) - def post(self, user_id: str, tenant_model: Tenant, payload: RequestInvokeTextEmbedding): + def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestInvokeTextEmbedding): try: return BaseBackwardsInvocationResponse( data=PluginModelBackwardsInvocation.invoke_text_embedding( - user_id=user_id, + user_id=user_model.id, tenant=tenant_model, payload=payload, ) @@ -63,13 +66,13 @@ class PluginInvokeTextEmbeddingApi(Resource): class PluginInvokeRerankApi(Resource): @setup_required @plugin_inner_api_only - @get_tenant + @get_user_tenant @plugin_data(payload_type=RequestInvokeRerank) - def post(self, user_id: str, tenant_model: Tenant, payload: RequestInvokeRerank): + def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestInvokeRerank): try: return BaseBackwardsInvocationResponse( data=PluginModelBackwardsInvocation.invoke_rerank( - user_id=user_id, + user_id=user_model.id, tenant=tenant_model, payload=payload, ) @@ -81,12 +84,12 @@ class PluginInvokeRerankApi(Resource): class PluginInvokeTTSApi(Resource): @setup_required @plugin_inner_api_only - @get_tenant + @get_user_tenant @plugin_data(payload_type=RequestInvokeTTS) - def post(self, user_id: str, tenant_model: Tenant, payload: RequestInvokeTTS): + def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestInvokeTTS): def generator(): response = PluginModelBackwardsInvocation.invoke_tts( - user_id=user_id, + user_id=user_model.id, tenant=tenant_model, payload=payload, ) @@ -98,13 +101,13 @@ class PluginInvokeTTSApi(Resource): class PluginInvokeSpeech2TextApi(Resource): @setup_required @plugin_inner_api_only - @get_tenant + @get_user_tenant @plugin_data(payload_type=RequestInvokeSpeech2Text) - def post(self, user_id: str, tenant_model: Tenant, payload: RequestInvokeSpeech2Text): + def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestInvokeSpeech2Text): try: return BaseBackwardsInvocationResponse( data=PluginModelBackwardsInvocation.invoke_speech2text( - user_id=user_id, + user_id=user_model.id, tenant=tenant_model, payload=payload, ) @@ -116,13 +119,13 @@ class PluginInvokeSpeech2TextApi(Resource): class PluginInvokeModerationApi(Resource): @setup_required @plugin_inner_api_only - @get_tenant + @get_user_tenant @plugin_data(payload_type=RequestInvokeModeration) - def post(self, user_id: str, tenant_model: Tenant, payload: RequestInvokeModeration): + def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestInvokeModeration): try: return BaseBackwardsInvocationResponse( data=PluginModelBackwardsInvocation.invoke_moderation( - user_id=user_id, + user_id=user_model.id, tenant=tenant_model, payload=payload, ) @@ -134,14 +137,14 @@ class PluginInvokeModerationApi(Resource): class PluginInvokeToolApi(Resource): @setup_required @plugin_inner_api_only - @get_tenant + @get_user_tenant @plugin_data(payload_type=RequestInvokeTool) - def post(self, user_id: str, tenant_model: Tenant, payload: RequestInvokeTool): + def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestInvokeTool): def generator(): return PluginToolBackwardsInvocation.convert_to_event_stream( PluginToolBackwardsInvocation.invoke_tool( tenant_id=tenant_model.id, - user_id=user_id, + user_id=user_model.id, tool_type=ToolProviderType.value_of(payload.tool_type), provider=payload.provider, tool_name=payload.tool, @@ -155,14 +158,14 @@ class PluginInvokeToolApi(Resource): class PluginInvokeParameterExtractorNodeApi(Resource): @setup_required @plugin_inner_api_only - @get_tenant + @get_user_tenant @plugin_data(payload_type=RequestInvokeParameterExtractorNode) - def post(self, user_id: str, tenant_model: Tenant, payload: RequestInvokeParameterExtractorNode): + def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestInvokeParameterExtractorNode): try: return BaseBackwardsInvocationResponse( data=PluginNodeBackwardsInvocation.invoke_parameter_extractor( tenant_id=tenant_model.id, - user_id=user_id, + user_id=user_model.id, parameters=payload.parameters, model_config=payload.model, instruction=payload.instruction, @@ -176,14 +179,14 @@ class PluginInvokeParameterExtractorNodeApi(Resource): class PluginInvokeQuestionClassifierNodeApi(Resource): @setup_required @plugin_inner_api_only - @get_tenant + @get_user_tenant @plugin_data(payload_type=RequestInvokeQuestionClassifierNode) - def post(self, user_id: str, tenant_model: Tenant, payload: RequestInvokeQuestionClassifierNode): + def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestInvokeQuestionClassifierNode): try: return BaseBackwardsInvocationResponse( data=PluginNodeBackwardsInvocation.invoke_question_classifier( tenant_id=tenant_model.id, - user_id=user_id, + user_id=user_model.id, query=payload.query, model_config=payload.model, classes=payload.classes, @@ -197,12 +200,12 @@ class PluginInvokeQuestionClassifierNodeApi(Resource): class PluginInvokeAppApi(Resource): @setup_required @plugin_inner_api_only - @get_tenant + @get_user_tenant @plugin_data(payload_type=RequestInvokeApp) - def post(self, user_id: str, tenant_model: Tenant, payload: RequestInvokeApp): + def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestInvokeApp): response = PluginAppBackwardsInvocation.invoke_app( app_id=payload.app_id, - user_id=user_id, + user_id=user_model.id, tenant_id=tenant_model.id, conversation_id=payload.conversation_id, query=payload.query, @@ -217,9 +220,9 @@ class PluginInvokeAppApi(Resource): class PluginInvokeEncryptApi(Resource): @setup_required @plugin_inner_api_only - @get_tenant + @get_user_tenant @plugin_data(payload_type=RequestInvokeEncrypt) - def post(self, user_id: str, tenant_model: Tenant, payload: RequestInvokeEncrypt): + def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestInvokeEncrypt): """ encrypt or decrypt data """ @@ -234,13 +237,13 @@ class PluginInvokeEncryptApi(Resource): class PluginInvokeSummaryApi(Resource): @setup_required @plugin_inner_api_only - @get_tenant + @get_user_tenant @plugin_data(payload_type=RequestInvokeSummary) - def post(self, user_id: str, tenant_model: Tenant, payload: RequestInvokeSummary): + def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestInvokeSummary): try: return BaseBackwardsInvocationResponse( data=PluginModelBackwardsInvocation.invoke_summary( - user_id=user_id, + user_id=user_model.id, tenant=tenant_model, payload=payload, ) @@ -249,6 +252,17 @@ class PluginInvokeSummaryApi(Resource): return BaseBackwardsInvocationResponse(error=str(e)).model_dump() +class PluginUploadFileRequestApi(Resource): + @setup_required + @plugin_inner_api_only + @get_user_tenant + @plugin_data(payload_type=RequestRequestUploadFile) + def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestRequestUploadFile): + # generate signed url + url = get_signed_file_url_for_plugin(payload.filename, payload.mimetype, user_model.id) + return BaseBackwardsInvocationResponse(data={"url": url}).model_dump() + + api.add_resource(PluginInvokeLLMApi, "/invoke/llm") api.add_resource(PluginInvokeTextEmbeddingApi, "/invoke/text-embedding") api.add_resource(PluginInvokeRerankApi, "/invoke/rerank") @@ -261,3 +275,4 @@ api.add_resource(PluginInvokeQuestionClassifierNodeApi, "/invoke/question-classi api.add_resource(PluginInvokeAppApi, "/invoke/app") api.add_resource(PluginInvokeEncryptApi, "/invoke/encrypt") api.add_resource(PluginInvokeSummaryApi, "/invoke/summary") +api.add_resource(PluginUploadFileRequestApi, "/upload/file/request") diff --git a/api/controllers/inner_api/plugin/wraps.py b/api/controllers/inner_api/plugin/wraps.py index 07249013f9..f8a2c71e80 100644 --- a/api/controllers/inner_api/plugin/wraps.py +++ b/api/controllers/inner_api/plugin/wraps.py @@ -7,10 +7,31 @@ from flask_restful import reqparse from pydantic import BaseModel from extensions.ext_database import db -from models.account import Tenant +from models.account import Account, Tenant +from models.model import EndUser +from services.account_service import AccountService -def get_tenant(view: Optional[Callable] = None): +def get_user(user_id: str | None) -> Account | EndUser: + try: + if not user_id: + user_id = "DEFAULT-USER" + + if user_id == "DEFAULT-USER": + user_model = db.session.query(EndUser).filter(EndUser.session_id == "DEFAULT-USER").first() + else: + user_model = AccountService.load_user(user_id) + if not user_model: + user_model = db.session.query(EndUser).filter(EndUser.id == user_id).first() + if not user_model: + raise ValueError("user not found") + except Exception: + raise ValueError("user not found") + + return user_model + + +def get_user_tenant(view: Optional[Callable] = None): def decorator(view_func): @wraps(view_func) def decorated_view(*args, **kwargs): @@ -42,7 +63,7 @@ def get_tenant(view: Optional[Callable] = None): raise ValueError("tenant not found") kwargs["tenant_model"] = tenant_model - kwargs["user_id"] = user_id + kwargs["user_model"] = get_user(user_id) return view_func(*args, **kwargs) diff --git a/api/core/file/helpers.py b/api/core/file/helpers.py index 12123cf3f7..34fcdef5f6 100644 --- a/api/core/file/helpers.py +++ b/api/core/file/helpers.py @@ -20,6 +20,41 @@ def get_signed_file_url(upload_file_id: str) -> str: return f"{url}?timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}" +def get_signed_file_url_for_plugin(filename: str, mimetype: str, user_id: str) -> str: + url = f"{dify_config.FILES_URL}/files/upload/for-plugin" + + if user_id is None: + user_id = "DEFAULT-USER" + + timestamp = str(int(time.time())) + nonce = os.urandom(16).hex() + key = dify_config.SECRET_KEY.encode() + msg = f"upload|{filename}|{mimetype}|{user_id}|{timestamp}|{nonce}" + sign = hmac.new(key, msg.encode(), hashlib.sha256).digest() + encoded_sign = base64.urlsafe_b64encode(sign).decode() + + return f"{url}?timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}&user_id={user_id}" + + +def verify_plugin_file_signature( + *, filename: str, mimetype: str, user_id: str | None, timestamp: str, nonce: str, sign: str +) -> bool: + if user_id is None: + user_id = "DEFAULT-USER" + + data_to_sign = f"upload|{filename}|{mimetype}|{user_id}|{timestamp}|{nonce}" + secret_key = dify_config.SECRET_KEY.encode() + recalculated_sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest() + recalculated_encoded_sign = base64.urlsafe_b64encode(recalculated_sign).decode() + + # verify signature + if sign != recalculated_encoded_sign: + return False + + current_time = int(time.time()) + return current_time - int(timestamp) <= dify_config.FILES_ACCESS_TIMEOUT + + def verify_image_signature(*, upload_file_id: str, timestamp: str, nonce: str, sign: str) -> bool: data_to_sign = f"image-preview|{upload_file_id}|{timestamp}|{nonce}" secret_key = dify_config.SECRET_KEY.encode() diff --git a/api/core/plugin/entities/request.py b/api/core/plugin/entities/request.py index d98b80ee43..9a0e569d4d 100644 --- a/api/core/plugin/entities/request.py +++ b/api/core/plugin/entities/request.py @@ -195,3 +195,12 @@ class RequestInvokeSummary(BaseModel): text: str instruction: str + + +class RequestRequestUploadFile(BaseModel): + """ + Request to upload file + """ + + filename: str + mimetype: str diff --git a/api/models/account.py b/api/models/account.py index 99464865dd..0a2c492ae6 100644 --- a/api/models/account.py +++ b/api/models/account.py @@ -55,6 +55,7 @@ class Account(UserMixin, Base): tenant.current_role = ta.role else: tenant = None + self._current_tenant = tenant @property