From 12a9e2972a097ccec5608cc8fe1a188af549de90 Mon Sep 17 00:00:00 2001
From: powerfool
Date: Thu, 7 Nov 2024 13:22:09 +0800
Subject: [PATCH 1/9] Adjusted docker manifests and environment variables for OceanBase vector database (#10395)

---
 .gitignore                                     |  1 +
 api/.env.example                               |  4 ++--
 docker/.env.example                            |  6 +++---
 docker/docker-compose.yaml                     | 12 +++++++++---
 docker/volumes/oceanbase/init.d/vec_memory.sql |  1 +
 5 files changed, 16 insertions(+), 8 deletions(-)
 create mode 100644 docker/volumes/oceanbase/init.d/vec_memory.sql

diff --git a/.gitignore b/.gitignore
index 60b5781733..1423bfee56 100644
--- a/.gitignore
+++ b/.gitignore
@@ -175,6 +175,7 @@ docker/volumes/pgvector/data/*
 docker/volumes/pgvecto_rs/data/*
 docker/volumes/couchbase/*
 docker/volumes/oceanbase/*
+!docker/volumes/oceanbase/init.d
 docker/nginx/conf.d/default.conf
 docker/nginx/ssl/*

diff --git a/api/.env.example b/api/.env.example
index 6fc58263c4..a92490608f 100644
--- a/api/.env.example
+++ b/api/.env.example
@@ -121,7 +121,7 @@ WEB_API_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,*

 CONSOLE_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,*

-# Vector database configuration, support: weaviate, qdrant, milvus, myscale, relyt, pgvecto_rs, pgvector, pgvector, chroma, opensearch, tidb_vector, couchbase, vikingdb, upstash, lindorm
+# Vector database configuration, support: weaviate, qdrant, milvus, myscale, relyt, pgvecto_rs, pgvector, pgvector, chroma, opensearch, tidb_vector, couchbase, vikingdb, upstash, lindorm, oceanbase
 VECTOR_STORE=weaviate

 # Weaviate configuration
@@ -273,7 +273,7 @@ LINDORM_PASSWORD=admin
 OCEANBASE_VECTOR_HOST=127.0.0.1
 OCEANBASE_VECTOR_PORT=2881
 OCEANBASE_VECTOR_USER=root@test
-OCEANBASE_VECTOR_PASSWORD=
+OCEANBASE_VECTOR_PASSWORD=difyai123456
 OCEANBASE_VECTOR_DATABASE=test
 OCEANBASE_MEMORY_LIMIT=6G

diff --git a/docker/.env.example b/docker/.env.example
index aa5e102bd0..9a178dc44c 100644
--- a/docker/.env.example
+++ b/docker/.env.example
@@ -374,7 +374,7 @@ SUPABASE_URL=your-server-url
 # ------------------------------

 # The type of vector store to use.
-# Supported values are `weaviate`, `qdrant`, `milvus`, `myscale`, `relyt`, `pgvector`, `pgvecto-rs`, `chroma`, `opensearch`, `tidb_vector`, `oracle`, `tencent`, `elasticsearch`, `analyticdb`, `couchbase`, `vikingdb`.
+# Supported values are `weaviate`, `qdrant`, `milvus`, `myscale`, `relyt`, `pgvector`, `pgvecto-rs`, `chroma`, `opensearch`, `tidb_vector`, `oracle`, `tencent`, `elasticsearch`, `analyticdb`, `couchbase`, `vikingdb`, `oceanbase`.
 VECTOR_STORE=weaviate

 # The Weaviate endpoint URL. Only available when VECTOR_STORE is `weaviate`.
@@ -537,10 +537,10 @@ LINDORM_USERNAME=username
 LINDORM_PASSWORD=password

 # OceanBase Vector configuration, only available when VECTOR_STORE is `oceanbase`
-OCEANBASE_VECTOR_HOST=oceanbase-vector
+OCEANBASE_VECTOR_HOST=oceanbase
 OCEANBASE_VECTOR_PORT=2881
 OCEANBASE_VECTOR_USER=root@test
-OCEANBASE_VECTOR_PASSWORD=
+OCEANBASE_VECTOR_PASSWORD=difyai123456
 OCEANBASE_VECTOR_DATABASE=test
 OCEANBASE_MEMORY_LIMIT=6G

diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml
index cdcc62e127..a7cb8576fd 100644
--- a/docker/docker-compose.yaml
+++ b/docker/docker-compose.yaml
@@ -266,8 +266,9 @@ x-shared-env: &shared-api-worker-env
   OCEANBASE_VECTOR_HOST: ${OCEANBASE_VECTOR_HOST:-http://oceanbase-vector}
   OCEANBASE_VECTOR_PORT: ${OCEANBASE_VECTOR_PORT:-2881}
   OCEANBASE_VECTOR_USER: ${OCEANBASE_VECTOR_USER:-root@test}
-  OCEANBASE_VECTOR_PASSWORD: ${OCEANBASE_VECTOR_PASSWORD:-""}
+  OCEANBASE_VECTOR_PASSWORD: ${OCEANBASE_VECTOR_PASSWORD:-difyai123456}
   OCEANBASE_VECTOR_DATABASE: ${OCEANBASE_VECTOR_DATABASE:-test}
+  OCEANBASE_CLUSTER_NAME: ${OCEANBASE_CLUSTER_NAME:-difyai}
   OCEANBASE_MEMORY_LIMIT: ${OCEANBASE_MEMORY_LIMIT:-6G}

 services:
@@ -597,16 +598,21 @@ services:
       IS_PERSISTENT: ${CHROMA_IS_PERSISTENT:-TRUE}

   # OceanBase vector database
-  oceanbase-vector:
+  oceanbase:
     image: quay.io/oceanbase/oceanbase-ce:4.3.3.0-100000142024101215
     profiles:
-      - oceanbase-vector
+      - oceanbase
     restart: always
     volumes:
       - ./volumes/oceanbase/data:/root/ob
       - ./volumes/oceanbase/conf:/root/.obd/cluster
+      - ./volumes/oceanbase/init.d:/root/boot/init.d
     environment:
       OB_MEMORY_LIMIT: ${OCEANBASE_MEMORY_LIMIT:-6G}
+      OB_SYS_PASSWORD: ${OCEANBASE_VECTOR_PASSWORD:-difyai123456}
+      OB_TENANT_PASSWORD: ${OCEANBASE_VECTOR_PASSWORD:-difyai123456}
+      OB_CLUSTER_NAME: ${OCEANBASE_CLUSTER_NAME:-difyai}
+      OB_SERVER_IP: '127.0.0.1'

   # Oracle vector database
   oracle:
diff --git a/docker/volumes/oceanbase/init.d/vec_memory.sql b/docker/volumes/oceanbase/init.d/vec_memory.sql
new file mode 100644
index 0000000000..f4c283fdf4
--- /dev/null
+++ b/docker/volumes/oceanbase/init.d/vec_memory.sql
@@ -0,0 +1 @@
+ALTER SYSTEM SET ob_vector_memory_limit_percentage = 30;
\ No newline at end of file

From 1ccca7cc68dab0f2ae0fed48561eda46778ca094 Mon Sep 17 00:00:00 2001
From: luckylhb90
Date: Thu, 7 Nov 2024 08:55:19 +0300
Subject: [PATCH 2/9] fixed: web api remote urls error (#10383)

Co-authored-by: hobo.l
---
 api/controllers/web/remote_files.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/api/controllers/web/remote_files.py b/api/controllers/web/remote_files.py
index 0b8a586d0c..cf36ae302d 100644
--- a/api/controllers/web/remote_files.py
+++ b/api/controllers/web/remote_files.py
@@ -12,7 +12,7 @@ from services.file_service import FileService

 class RemoteFileInfoApi(WebApiResource):
     @marshal_with(remote_file_info_fields)
-    def get(self, url):
+    def get(self, app_model, end_user, url):
         decoded_url = urllib.parse.unquote(url)
         try:
             response = ssrf_proxy.head(decoded_url)

From d3e9930235ac800a397b2554cb378c6792f882e3 Mon Sep 17 00:00:00 2001
From: -LAN-
Date: Thu, 7 Nov 2024 14:02:30 +0800
Subject: [PATCH 3/9] refactor(question_classifier): improve error handling with custom exceptions (#10365)

---
 api/core/workflow/nodes/question_classifier/exc.py        | 6 ++++++
 .../nodes/question_classifier/question_classifier_node.py | 6 ++++--
 api/libs/json_in_md_parser.py                             | 2 +-
 3 files changed, 11 insertions(+), 3 deletions(-)
 create mode 100644 api/core/workflow/nodes/question_classifier/exc.py

diff --git
a/api/core/workflow/nodes/question_classifier/exc.py b/api/core/workflow/nodes/question_classifier/exc.py new file mode 100644 index 0000000000..2c6354e2a7 --- /dev/null +++ b/api/core/workflow/nodes/question_classifier/exc.py @@ -0,0 +1,6 @@ +class QuestionClassifierNodeError(ValueError): + """Base class for QuestionClassifierNode errors.""" + + +class InvalidModelTypeError(QuestionClassifierNodeError): + """Raised when the model is not a Large Language Model.""" diff --git a/api/core/workflow/nodes/question_classifier/question_classifier_node.py b/api/core/workflow/nodes/question_classifier/question_classifier_node.py index ee160e7c69..0489020e5e 100644 --- a/api/core/workflow/nodes/question_classifier/question_classifier_node.py +++ b/api/core/workflow/nodes/question_classifier/question_classifier_node.py @@ -4,6 +4,7 @@ from collections.abc import Mapping, Sequence from typing import TYPE_CHECKING, Any, Optional, cast from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity +from core.llm_generator.output_parser.errors import OutputParserError from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelInstance from core.model_runtime.entities import LLMUsage, ModelPropertyKey, PromptMessageRole @@ -24,6 +25,7 @@ from libs.json_in_md_parser import parse_and_check_json_markdown from models.workflow import WorkflowNodeExecutionStatus from .entities import QuestionClassifierNodeData +from .exc import InvalidModelTypeError from .template_prompts import ( QUESTION_CLASSIFIER_ASSISTANT_PROMPT_1, QUESTION_CLASSIFIER_ASSISTANT_PROMPT_2, @@ -124,7 +126,7 @@ class QuestionClassifierNode(LLMNode): category_name = classes_map[category_id_result] category_id = category_id_result - except Exception: + except OutputParserError: logging.error(f"Failed to parse result text: {result_text}") try: process_data = { @@ -309,4 +311,4 @@ class QuestionClassifierNode(LLMNode): ) else: - raise ValueError(f"Model mode {model_mode} not support.") + raise InvalidModelTypeError(f"Model mode {model_mode} not support.") diff --git a/api/libs/json_in_md_parser.py b/api/libs/json_in_md_parser.py index 9131408817..41c5d20c4b 100644 --- a/api/libs/json_in_md_parser.py +++ b/api/libs/json_in_md_parser.py @@ -9,6 +9,7 @@ def parse_json_markdown(json_string: str) -> dict: starts = ["```json", "```", "``", "`", "{"] ends = ["```", "``", "`", "}"] end_index = -1 + start_index = 0 for s in starts: start_index = json_string.find(s) if start_index != -1: @@ -24,7 +25,6 @@ def parse_json_markdown(json_string: str) -> dict: break if start_index != -1 and end_index != -1 and start_index < end_index: extracted_content = json_string[start_index:end_index].strip() - print("content:", extracted_content, start_index, end_index) parsed = json.loads(extracted_content) else: raise Exception("Could not find JSON block in the output.") From 35d3da96971a21ccf8367db610af8916b73ce352 Mon Sep 17 00:00:00 2001 From: -LAN- Date: Thu, 7 Nov 2024 14:02:38 +0800 Subject: [PATCH 4/9] refactor(tool-node): introduce specific exceptions for tool node errors (#10357) --- api/core/workflow/nodes/tool/exc.py | 16 +++++++++++++++ api/core/workflow/nodes/tool/tool_node.py | 24 ++++++++++++++--------- 2 files changed, 31 insertions(+), 9 deletions(-) create mode 100644 api/core/workflow/nodes/tool/exc.py diff --git a/api/core/workflow/nodes/tool/exc.py b/api/core/workflow/nodes/tool/exc.py new file mode 100644 index 0000000000..7212e8bfc0 --- /dev/null +++ b/api/core/workflow/nodes/tool/exc.py @@ 
-0,0 +1,16 @@ +class ToolNodeError(ValueError): + """Base exception for tool node errors.""" + + pass + + +class ToolParameterError(ToolNodeError): + """Exception raised for errors in tool parameters.""" + + pass + + +class ToolFileError(ToolNodeError): + """Exception raised for errors related to tool files.""" + + pass diff --git a/api/core/workflow/nodes/tool/tool_node.py b/api/core/workflow/nodes/tool/tool_node.py index 0994ccaedb..42e870c46c 100644 --- a/api/core/workflow/nodes/tool/tool_node.py +++ b/api/core/workflow/nodes/tool/tool_node.py @@ -6,7 +6,7 @@ from sqlalchemy import select from sqlalchemy.orm import Session from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler -from core.file.models import File, FileTransferMethod, FileType +from core.file import File, FileTransferMethod, FileType from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter from core.tools.tool_engine import ToolEngine from core.tools.tool_manager import ToolManager @@ -15,12 +15,18 @@ from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResu from core.workflow.entities.variable_pool import VariablePool from core.workflow.nodes.base import BaseNode from core.workflow.nodes.enums import NodeType -from core.workflow.nodes.tool.entities import ToolNodeData from core.workflow.utils.variable_template_parser import VariableTemplateParser from extensions.ext_database import db from models import ToolFile from models.workflow import WorkflowNodeExecutionStatus +from .entities import ToolNodeData +from .exc import ( + ToolFileError, + ToolNodeError, + ToolParameterError, +) + class ToolNode(BaseNode[ToolNodeData]): """ @@ -42,7 +48,7 @@ class ToolNode(BaseNode[ToolNodeData]): tool_runtime = ToolManager.get_workflow_tool_runtime( self.tenant_id, self.app_id, self.node_id, self.node_data, self.invoke_from ) - except Exception as e: + except ToolNodeError as e: return NodeRunResult( status=WorkflowNodeExecutionStatus.FAILED, inputs={}, @@ -75,7 +81,7 @@ class ToolNode(BaseNode[ToolNodeData]): workflow_call_depth=self.workflow_call_depth, thread_pool_id=self.thread_pool_id, ) - except Exception as e: + except ToolNodeError as e: return NodeRunResult( status=WorkflowNodeExecutionStatus.FAILED, inputs=parameters_for_log, @@ -133,13 +139,13 @@ class ToolNode(BaseNode[ToolNodeData]): if tool_input.type == "variable": variable = variable_pool.get(tool_input.value) if variable is None: - raise ValueError(f"variable {tool_input.value} not exists") + raise ToolParameterError(f"Variable {tool_input.value} does not exist") parameter_value = variable.value elif tool_input.type in {"mixed", "constant"}: segment_group = variable_pool.convert_template(str(tool_input.value)) parameter_value = segment_group.log if for_log else segment_group.text else: - raise ValueError(f"unknown tool input type '{tool_input.type}'") + raise ToolParameterError(f"Unknown tool input type '{tool_input.type}'") result[parameter_name] = parameter_value return result @@ -181,7 +187,7 @@ class ToolNode(BaseNode[ToolNodeData]): stmt = select(ToolFile).where(ToolFile.id == tool_file_id) tool_file = session.scalar(stmt) if tool_file is None: - raise ValueError(f"tool file {tool_file_id} not exists") + raise ToolFileError(f"Tool file {tool_file_id} does not exist") result.append( File( @@ -203,7 +209,7 @@ class ToolNode(BaseNode[ToolNodeData]): stmt = select(ToolFile).where(ToolFile.id == tool_file_id) tool_file = session.scalar(stmt) if tool_file is None: - raise ValueError(f"tool 
file {tool_file_id} not exists") + raise ToolFileError(f"Tool file {tool_file_id} does not exist") result.append( File( tenant_id=self.tenant_id, @@ -224,7 +230,7 @@ class ToolNode(BaseNode[ToolNodeData]): stmt = select(ToolFile).where(ToolFile.id == tool_file_id) tool_file = session.scalar(stmt) if tool_file is None: - raise ValueError(f"tool file {tool_file_id} not exists") + raise ToolFileError(f"Tool file {tool_file_id} does not exist") if "." in url: extension = "." + url.split("/")[-1].split(".")[1] else: From 25785d8c3f6857a215481b09a38f5abecc99abe9 Mon Sep 17 00:00:00 2001 From: -LAN- Date: Thu, 7 Nov 2024 14:02:46 +0800 Subject: [PATCH 5/9] refactor(knowledge-retrieval): improve error handling with custom exceptions (#10385) --- .../workflow/nodes/knowledge_retrieval/exc.py | 18 +++++++++++++ .../knowledge_retrieval_node.py | 27 ++++++++++++------- 2 files changed, 35 insertions(+), 10 deletions(-) create mode 100644 api/core/workflow/nodes/knowledge_retrieval/exc.py diff --git a/api/core/workflow/nodes/knowledge_retrieval/exc.py b/api/core/workflow/nodes/knowledge_retrieval/exc.py new file mode 100644 index 0000000000..0c3b6e86fa --- /dev/null +++ b/api/core/workflow/nodes/knowledge_retrieval/exc.py @@ -0,0 +1,18 @@ +class KnowledgeRetrievalNodeError(ValueError): + """Base class for KnowledgeRetrievalNode errors.""" + + +class ModelNotExistError(KnowledgeRetrievalNodeError): + """Raised when the model does not exist.""" + + +class ModelCredentialsNotInitializedError(KnowledgeRetrievalNodeError): + """Raised when the model credentials are not initialized.""" + + +class ModelNotSupportedError(KnowledgeRetrievalNodeError): + """Raised when the model is not supported.""" + + +class ModelQuotaExceededError(KnowledgeRetrievalNodeError): + """Raised when the model provider quota is exceeded.""" diff --git a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py index 2a5795a3ed..8c5a9b5ecb 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py +++ b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py @@ -8,7 +8,6 @@ from core.app.app_config.entities import DatasetRetrieveConfigEntity from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.entities.agent_entities import PlanningStrategy from core.entities.model_entities import ModelStatus -from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.model_manager import ModelInstance, ModelManager from core.model_runtime.entities.model_entities import ModelFeature, ModelType from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel @@ -18,11 +17,19 @@ from core.variables import StringSegment from core.workflow.entities.node_entities import NodeRunResult from core.workflow.nodes.base import BaseNode from core.workflow.nodes.enums import NodeType -from core.workflow.nodes.knowledge_retrieval.entities import KnowledgeRetrievalNodeData from extensions.ext_database import db from models.dataset import Dataset, Document, DocumentSegment from models.workflow import WorkflowNodeExecutionStatus +from .entities import KnowledgeRetrievalNodeData +from .exc import ( + KnowledgeRetrievalNodeError, + ModelCredentialsNotInitializedError, + ModelNotExistError, + ModelNotSupportedError, + ModelQuotaExceededError, +) + logger = logging.getLogger(__name__) default_retrieval_model = { @@ 
-61,8 +68,8 @@ class KnowledgeRetrievalNode(BaseNode[KnowledgeRetrievalNodeData]): status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, process_data=None, outputs=outputs ) - except Exception as e: - logger.exception("Error when running knowledge retrieval node") + except KnowledgeRetrievalNodeError as e: + logger.warning("Error when running knowledge retrieval node") return NodeRunResult(status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, error=str(e)) def _fetch_dataset_retriever(self, node_data: KnowledgeRetrievalNodeData, query: str) -> list[dict[str, Any]]: @@ -295,14 +302,14 @@ class KnowledgeRetrievalNode(BaseNode[KnowledgeRetrievalNodeData]): ) if provider_model is None: - raise ValueError(f"Model {model_name} not exist.") + raise ModelNotExistError(f"Model {model_name} not exist.") if provider_model.status == ModelStatus.NO_CONFIGURE: - raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.") + raise ModelCredentialsNotInitializedError(f"Model {model_name} credentials is not initialized.") elif provider_model.status == ModelStatus.NO_PERMISSION: - raise ModelCurrentlyNotSupportError(f"Dify Hosted OpenAI {model_name} currently not support.") + raise ModelNotSupportedError(f"Dify Hosted OpenAI {model_name} currently not support.") elif provider_model.status == ModelStatus.QUOTA_EXCEEDED: - raise QuotaExceededError(f"Model provider {provider_name} quota exceeded.") + raise ModelQuotaExceededError(f"Model provider {provider_name} quota exceeded.") # model config completion_params = node_data.single_retrieval_config.model.completion_params @@ -314,12 +321,12 @@ class KnowledgeRetrievalNode(BaseNode[KnowledgeRetrievalNodeData]): # get model mode model_mode = node_data.single_retrieval_config.model.mode if not model_mode: - raise ValueError("LLM mode is required.") + raise ModelNotExistError("LLM mode is required.") model_schema = model_type_instance.get_model_schema(model_name, model_credentials) if not model_schema: - raise ValueError(f"Model {model_name} not exist.") + raise ModelNotExistError(f"Model {model_name} not exist.") return model_instance, ModelConfigWithCredentialsEntity( provider=provider_name, From f8c958a409d5d3d27248ecc78d2a8d161896f9a4 Mon Sep 17 00:00:00 2001 From: -LAN- Date: Thu, 7 Nov 2024 14:02:55 +0800 Subject: [PATCH 6/9] refactor(iteration): introduce specific exceptions for iteration errors (#10366) --- api/core/workflow/nodes/iteration/exc.py | 22 ++++++++++++++ .../nodes/iteration/iteration_node.py | 29 ++++++++++++------- 2 files changed, 41 insertions(+), 10 deletions(-) create mode 100644 api/core/workflow/nodes/iteration/exc.py diff --git a/api/core/workflow/nodes/iteration/exc.py b/api/core/workflow/nodes/iteration/exc.py new file mode 100644 index 0000000000..d9947e09bc --- /dev/null +++ b/api/core/workflow/nodes/iteration/exc.py @@ -0,0 +1,22 @@ +class IterationNodeError(ValueError): + """Base class for iteration node errors.""" + + +class IteratorVariableNotFoundError(IterationNodeError): + """Raised when the iterator variable is not found.""" + + +class InvalidIteratorValueError(IterationNodeError): + """Raised when the iterator value is invalid.""" + + +class StartNodeIdNotFoundError(IterationNodeError): + """Raised when the start node ID is not found.""" + + +class IterationGraphNotFoundError(IterationNodeError): + """Raised when the iteration graph is not found.""" + + +class IterationIndexNotFoundError(IterationNodeError): + """Raised when the iteration index is not found.""" diff --git 
a/api/core/workflow/nodes/iteration/iteration_node.py b/api/core/workflow/nodes/iteration/iteration_node.py index d121b0530a..e1d2b88360 100644 --- a/api/core/workflow/nodes/iteration/iteration_node.py +++ b/api/core/workflow/nodes/iteration/iteration_node.py @@ -38,6 +38,15 @@ from core.workflow.nodes.event import NodeEvent, RunCompletedEvent from core.workflow.nodes.iteration.entities import ErrorHandleMode, IterationNodeData from models.workflow import WorkflowNodeExecutionStatus +from .exc import ( + InvalidIteratorValueError, + IterationGraphNotFoundError, + IterationIndexNotFoundError, + IterationNodeError, + IteratorVariableNotFoundError, + StartNodeIdNotFoundError, +) + if TYPE_CHECKING: from core.workflow.graph_engine.graph_engine import GraphEngine logger = logging.getLogger(__name__) @@ -69,7 +78,7 @@ class IterationNode(BaseNode[IterationNodeData]): iterator_list_segment = self.graph_runtime_state.variable_pool.get(self.node_data.iterator_selector) if not iterator_list_segment: - raise ValueError(f"Iterator variable {self.node_data.iterator_selector} not found") + raise IteratorVariableNotFoundError(f"Iterator variable {self.node_data.iterator_selector} not found") if len(iterator_list_segment.value) == 0: yield RunCompletedEvent( @@ -83,14 +92,14 @@ class IterationNode(BaseNode[IterationNodeData]): iterator_list_value = iterator_list_segment.to_object() if not isinstance(iterator_list_value, list): - raise ValueError(f"Invalid iterator value: {iterator_list_value}, please provide a list.") + raise InvalidIteratorValueError(f"Invalid iterator value: {iterator_list_value}, please provide a list.") inputs = {"iterator_selector": iterator_list_value} graph_config = self.graph_config if not self.node_data.start_node_id: - raise ValueError(f"field start_node_id in iteration {self.node_id} not found") + raise StartNodeIdNotFoundError(f"field start_node_id in iteration {self.node_id} not found") root_node_id = self.node_data.start_node_id @@ -98,7 +107,7 @@ class IterationNode(BaseNode[IterationNodeData]): iteration_graph = Graph.init(graph_config=graph_config, root_node_id=root_node_id) if not iteration_graph: - raise ValueError("iteration graph not found") + raise IterationGraphNotFoundError("iteration graph not found") variable_pool = self.graph_runtime_state.variable_pool @@ -222,9 +231,9 @@ class IterationNode(BaseNode[IterationNodeData]): status=WorkflowNodeExecutionStatus.SUCCEEDED, outputs={"output": jsonable_encoder(outputs)} ) ) - except Exception as e: + except IterationNodeError as e: # iteration run failed - logger.exception("Iteration run failed") + logger.warning("Iteration run failed") yield IterationRunFailedEvent( iteration_id=self.id, iteration_node_id=self.node_id, @@ -272,7 +281,7 @@ class IterationNode(BaseNode[IterationNodeData]): iteration_graph = Graph.init(graph_config=graph_config, root_node_id=node_data.start_node_id) if not iteration_graph: - raise ValueError("iteration graph not found") + raise IterationGraphNotFoundError("iteration graph not found") for sub_node_id, sub_node_config in iteration_graph.node_id_config_mapping.items(): if sub_node_config.get("data", {}).get("iteration_id") != node_id: @@ -357,7 +366,7 @@ class IterationNode(BaseNode[IterationNodeData]): next_index = int(current_index) + 1 if current_index is None: - raise ValueError(f"iteration {self.node_id} current index not found") + raise IterationIndexNotFoundError(f"iteration {self.node_id} current index not found") for event in rst: if isinstance(event, (BaseNodeEvent | 
BaseParallelBranchEvent)) and not event.in_iteration_id: event.in_iteration_id = self.node_id @@ -484,8 +493,8 @@ class IterationNode(BaseNode[IterationNodeData]): pre_iteration_output=jsonable_encoder(current_iteration_output) if current_iteration_output else None, ) - except Exception as e: - logger.exception(f"Iteration run failed:{str(e)}") + except IterationNodeError as e: + logger.warning(f"Iteration run failed:{str(e)}") yield IterationRunFailedEvent( iteration_id=self.id, iteration_node_id=self.node_id, From 823ae03a0884e1fc0c9fbcc17540f97bfce63d20 Mon Sep 17 00:00:00 2001 From: -LAN- Date: Thu, 7 Nov 2024 14:35:58 +0800 Subject: [PATCH 7/9] fix(remote-files): fallback to get when remote server not support head method (#10370) --- api/controllers/console/error.py | 24 ++++++++++ .../console/{files/__init__.py => files.py} | 2 +- api/controllers/console/files/errors.py | 25 ----------- api/controllers/console/remote_files.py | 44 ++++++++++++------- api/controllers/web/remote_files.py | 43 ++++++++++-------- 5 files changed, 77 insertions(+), 61 deletions(-) rename api/controllers/console/{files/__init__.py => files.py} (99%) delete mode 100644 api/controllers/console/files/errors.py diff --git a/api/controllers/console/error.py b/api/controllers/console/error.py index ed6a99a017..e0630ca66c 100644 --- a/api/controllers/console/error.py +++ b/api/controllers/console/error.py @@ -62,3 +62,27 @@ class EmailSendIpLimitError(BaseHTTPException): error_code = "email_send_ip_limit" description = "Too many emails have been sent from this IP address recently. Please try again later." code = 429 + + +class FileTooLargeError(BaseHTTPException): + error_code = "file_too_large" + description = "File size exceeded. {message}" + code = 413 + + +class UnsupportedFileTypeError(BaseHTTPException): + error_code = "unsupported_file_type" + description = "File type not allowed." + code = 415 + + +class TooManyFilesError(BaseHTTPException): + error_code = "too_many_files" + description = "Only one file is allowed." + code = 400 + + +class NoFileUploadedError(BaseHTTPException): + error_code = "no_file_uploaded" + description = "Please upload your file." + code = 400 diff --git a/api/controllers/console/files/__init__.py b/api/controllers/console/files.py similarity index 99% rename from api/controllers/console/files/__init__.py rename to api/controllers/console/files.py index 6c7bd8acfd..946d3db37f 100644 --- a/api/controllers/console/files/__init__.py +++ b/api/controllers/console/files.py @@ -15,7 +15,7 @@ from fields.file_fields import file_fields, upload_config_fields from libs.login import login_required from services.file_service import FileService -from .errors import ( +from .error import ( FileTooLargeError, NoFileUploadedError, TooManyFilesError, diff --git a/api/controllers/console/files/errors.py b/api/controllers/console/files/errors.py deleted file mode 100644 index 1654ef2cf4..0000000000 --- a/api/controllers/console/files/errors.py +++ /dev/null @@ -1,25 +0,0 @@ -from libs.exception import BaseHTTPException - - -class FileTooLargeError(BaseHTTPException): - error_code = "file_too_large" - description = "File size exceeded. {message}" - code = 413 - - -class UnsupportedFileTypeError(BaseHTTPException): - error_code = "unsupported_file_type" - description = "File type not allowed." - code = 415 - - -class TooManyFilesError(BaseHTTPException): - error_code = "too_many_files" - description = "Only one file is allowed." 
- code = 400 - - -class NoFileUploadedError(BaseHTTPException): - error_code = "no_file_uploaded" - description = "Please upload your file." - code = 400 diff --git a/api/controllers/console/remote_files.py b/api/controllers/console/remote_files.py index 42d6e25416..9b899bef64 100644 --- a/api/controllers/console/remote_files.py +++ b/api/controllers/console/remote_files.py @@ -1,9 +1,11 @@ import urllib.parse from typing import cast +import httpx from flask_login import current_user from flask_restful import Resource, marshal_with, reqparse +import services from controllers.common import helpers from core.file import helpers as file_helpers from core.helper import ssrf_proxy @@ -11,19 +13,25 @@ from fields.file_fields import file_fields_with_signed_url, remote_file_info_fie from models.account import Account from services.file_service import FileService +from .error import ( + FileTooLargeError, + UnsupportedFileTypeError, +) + class RemoteFileInfoApi(Resource): @marshal_with(remote_file_info_fields) def get(self, url): decoded_url = urllib.parse.unquote(url) - try: - response = ssrf_proxy.head(decoded_url) - return { - "file_type": response.headers.get("Content-Type", "application/octet-stream"), - "file_length": int(response.headers.get("Content-Length", 0)), - } - except Exception as e: - return {"error": str(e)}, 400 + resp = ssrf_proxy.head(decoded_url) + if resp.status_code != httpx.codes.OK: + # failed back to get method + resp = ssrf_proxy.get(decoded_url, timeout=3) + resp.raise_for_status() + return { + "file_type": resp.headers.get("Content-Type", "application/octet-stream"), + "file_length": int(resp.headers.get("Content-Length", 0)), + } class RemoteFileUploadApi(Resource): @@ -35,17 +43,17 @@ class RemoteFileUploadApi(Resource): url = args["url"] - response = ssrf_proxy.head(url) - response.raise_for_status() + resp = ssrf_proxy.head(url=url) + if resp.status_code != httpx.codes.OK: + resp = ssrf_proxy.get(url=url, timeout=3) + resp.raise_for_status() - file_info = helpers.guess_file_info_from_response(response) + file_info = helpers.guess_file_info_from_response(resp) if not FileService.is_file_size_within_limit(extension=file_info.extension, file_size=file_info.size): - return {"error": "File size exceeded"}, 400 + raise FileTooLargeError - response = ssrf_proxy.get(url) - response.raise_for_status() - content = response.content + content = resp.content if resp.request.method == "GET" else ssrf_proxy.get(url).content try: user = cast(Account, current_user) @@ -56,8 +64,10 @@ class RemoteFileUploadApi(Resource): user=user, source_url=url, ) - except Exception as e: - return {"error": str(e)}, 400 + except services.errors.file.FileTooLargeError as file_too_large_error: + raise FileTooLargeError(file_too_large_error.description) + except services.errors.file.UnsupportedFileTypeError: + raise UnsupportedFileTypeError() return { "id": upload_file.id, diff --git a/api/controllers/web/remote_files.py b/api/controllers/web/remote_files.py index cf36ae302d..d6b8eb2855 100644 --- a/api/controllers/web/remote_files.py +++ b/api/controllers/web/remote_files.py @@ -1,7 +1,9 @@ import urllib.parse +import httpx from flask_restful import marshal_with, reqparse +import services from controllers.common import helpers from controllers.web.wraps import WebApiResource from core.file import helpers as file_helpers @@ -9,19 +11,22 @@ from core.helper import ssrf_proxy from fields.file_fields import file_fields_with_signed_url, remote_file_info_fields from services.file_service import 
FileService +from .error import FileTooLargeError, UnsupportedFileTypeError + class RemoteFileInfoApi(WebApiResource): @marshal_with(remote_file_info_fields) def get(self, app_model, end_user, url): decoded_url = urllib.parse.unquote(url) - try: - response = ssrf_proxy.head(decoded_url) - return { - "file_type": response.headers.get("Content-Type", "application/octet-stream"), - "file_length": int(response.headers.get("Content-Length", -1)), - } - except Exception as e: - return {"error": str(e)}, 400 + resp = ssrf_proxy.head(decoded_url) + if resp.status_code != httpx.codes.OK: + # failed back to get method + resp = ssrf_proxy.get(decoded_url, timeout=3) + resp.raise_for_status() + return { + "file_type": resp.headers.get("Content-Type", "application/octet-stream"), + "file_length": int(resp.headers.get("Content-Length", -1)), + } class RemoteFileUploadApi(WebApiResource): @@ -33,28 +38,30 @@ class RemoteFileUploadApi(WebApiResource): url = args["url"] - response = ssrf_proxy.head(url) - response.raise_for_status() + resp = ssrf_proxy.head(url=url) + if resp.status_code != httpx.codes.OK: + resp = ssrf_proxy.get(url=url, timeout=3) + resp.raise_for_status() - file_info = helpers.guess_file_info_from_response(response) + file_info = helpers.guess_file_info_from_response(resp) if not FileService.is_file_size_within_limit(extension=file_info.extension, file_size=file_info.size): - return {"error": "File size exceeded"}, 400 + raise FileTooLargeError - response = ssrf_proxy.get(url) - response.raise_for_status() - content = response.content + content = resp.content if resp.request.method == "GET" else ssrf_proxy.get(url).content try: upload_file = FileService.upload_file( filename=file_info.filename, content=content, mimetype=file_info.mimetype, - user=end_user, # Use end_user instead of current_user + user=end_user, source_url=url, ) - except Exception as e: - return {"error": str(e)}, 400 + except services.errors.file.FileTooLargeError as file_too_large_error: + raise FileTooLargeError(file_too_large_error.description) + except services.errors.file.UnsupportedFileTypeError: + raise UnsupportedFileTypeError return { "id": upload_file.id, From 196684ca7e92050317e17cedfdd52e856ab28953 Mon Sep 17 00:00:00 2001 From: -LAN- Date: Sat, 2 Nov 2024 21:47:01 +0800 Subject: [PATCH 8/9] refactor(core): Remove extra_config from File. 
--- api/core/agent/base_agent_runner.py | 46 ++--- api/core/agent/cot_chat_agent_runner.py | 21 ++- api/core/agent/fc_agent_runner.py | 21 ++- api/core/app/app_config/entities.py | 4 +- .../features/file_upload/manager.py | 7 +- .../apps/advanced_chat/app_config_manager.py | 4 +- .../app/apps/advanced_chat/app_generator.py | 6 +- api/core/app/apps/agent_chat/app_generator.py | 8 +- api/core/app/apps/base_app_generator.py | 13 +- api/core/app/apps/chat/app_generator.py | 7 +- api/core/app/apps/completion/app_generator.py | 11 +- .../app/apps/workflow/app_config_manager.py | 4 +- api/core/app/apps/workflow/app_generator.py | 7 +- api/core/app/entities/app_invoke_entities.py | 3 + api/core/file/__init__.py | 4 +- api/core/file/file_manager.py | 34 ++-- api/core/file/models.py | 33 +--- api/core/memory/token_buffer_memory.py | 15 +- api/core/prompt/advanced_prompt_transform.py | 8 +- .../provider/builtin/vectorizer/vectorizer.py | 16 +- api/core/workflow/nodes/http_request/node.py | 20 +-- api/core/workflow/nodes/tool/tool_node.py | 62 +++---- api/core/workflow/workflow_entry.py | 27 ++- api/factories/file_factory.py | 170 ++++++++---------- api/models/model.py | 11 +- api/services/workflow/workflow_converter.py | 4 +- .../prompt/test_advanced_prompt_transform.py | 3 +- 27 files changed, 260 insertions(+), 309 deletions(-) diff --git a/api/core/agent/base_agent_runner.py b/api/core/agent/base_agent_runner.py index 507455c176..860ec5de0c 100644 --- a/api/core/agent/base_agent_runner.py +++ b/api/core/agent/base_agent_runner.py @@ -30,6 +30,7 @@ from core.model_runtime.entities import ( ToolPromptMessage, UserPromptMessage, ) +from core.model_runtime.entities.message_entities import ImagePromptMessageContent from core.model_runtime.entities.model_entities import ModelFeature from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.model_runtime.utils.encoders import jsonable_encoder @@ -65,7 +66,7 @@ class BaseAgentRunner(AppRunner): prompt_messages: Optional[list[PromptMessage]] = None, variables_pool: Optional[ToolRuntimeVariablePool] = None, db_variables: Optional[ToolConversationVariables] = None, - model_instance: ModelInstance = None, + model_instance: ModelInstance | None = None, ) -> None: self.tenant_id = tenant_id self.application_generate_entity = application_generate_entity @@ -508,24 +509,27 @@ class BaseAgentRunner(AppRunner): def organize_agent_user_prompt(self, message: Message) -> UserPromptMessage: files = db.session.query(MessageFile).filter(MessageFile.message_id == message.id).all() - if files: - file_extra_config = FileUploadConfigManager.convert(message.app_model_config.to_dict()) - - if file_extra_config: - file_objs = file_factory.build_from_message_files( - message_files=files, tenant_id=self.tenant_id, config=file_extra_config - ) - else: - file_objs = [] - - if not file_objs: - return UserPromptMessage(content=message.query) - else: - prompt_message_contents: list[PromptMessageContent] = [] - prompt_message_contents.append(TextPromptMessageContent(data=message.query)) - for file_obj in file_objs: - prompt_message_contents.append(file_manager.to_prompt_message_content(file_obj)) - - return UserPromptMessage(content=prompt_message_contents) - else: + if not files: return UserPromptMessage(content=message.query) + file_extra_config = FileUploadConfigManager.convert(message.app_model_config.to_dict()) + if not file_extra_config: + return UserPromptMessage(content=message.query) + + image_detail_config = 
file_extra_config.image_config.detail if file_extra_config.image_config else None + image_detail_config = image_detail_config or ImagePromptMessageContent.DETAIL.LOW + + file_objs = file_factory.build_from_message_files( + message_files=files, tenant_id=self.tenant_id, config=file_extra_config + ) + if not file_objs: + return UserPromptMessage(content=message.query) + prompt_message_contents: list[PromptMessageContent] = [] + prompt_message_contents.append(TextPromptMessageContent(data=message.query)) + for file in file_objs: + prompt_message_contents.append( + file_manager.to_prompt_message_content( + file, + image_detail_config=image_detail_config, + ) + ) + return UserPromptMessage(content=prompt_message_contents) diff --git a/api/core/agent/cot_chat_agent_runner.py b/api/core/agent/cot_chat_agent_runner.py index 6261a9b12c..d8d047fe91 100644 --- a/api/core/agent/cot_chat_agent_runner.py +++ b/api/core/agent/cot_chat_agent_runner.py @@ -10,6 +10,7 @@ from core.model_runtime.entities import ( TextPromptMessageContent, UserPromptMessage, ) +from core.model_runtime.entities.message_entities import ImagePromptMessageContent from core.model_runtime.utils.encoders import jsonable_encoder @@ -36,8 +37,24 @@ class CotChatAgentRunner(CotAgentRunner): if self.files: prompt_message_contents: list[PromptMessageContent] = [] prompt_message_contents.append(TextPromptMessageContent(data=query)) - for file_obj in self.files: - prompt_message_contents.append(file_manager.to_prompt_message_content(file_obj)) + + # get image detail config + image_detail_config = ( + self.application_generate_entity.file_upload_config.image_config.detail + if ( + self.application_generate_entity.file_upload_config + and self.application_generate_entity.file_upload_config.image_config + ) + else None + ) + image_detail_config = image_detail_config or ImagePromptMessageContent.DETAIL.LOW + for file in self.files: + prompt_message_contents.append( + file_manager.to_prompt_message_content( + file, + image_detail_config=image_detail_config, + ) + ) prompt_messages.append(UserPromptMessage(content=prompt_message_contents)) else: diff --git a/api/core/agent/fc_agent_runner.py b/api/core/agent/fc_agent_runner.py index 9083b4e85f..cd546dee12 100644 --- a/api/core/agent/fc_agent_runner.py +++ b/api/core/agent/fc_agent_runner.py @@ -22,6 +22,7 @@ from core.model_runtime.entities import ( ToolPromptMessage, UserPromptMessage, ) +from core.model_runtime.entities.message_entities import ImagePromptMessageContent from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform from core.tools.entities.tool_entities import ToolInvokeMeta from core.tools.tool_engine import ToolEngine @@ -397,8 +398,24 @@ class FunctionCallAgentRunner(BaseAgentRunner): if self.files: prompt_message_contents: list[PromptMessageContent] = [] prompt_message_contents.append(TextPromptMessageContent(data=query)) - for file_obj in self.files: - prompt_message_contents.append(file_manager.to_prompt_message_content(file_obj)) + + # get image detail config + image_detail_config = ( + self.application_generate_entity.file_upload_config.image_config.detail + if ( + self.application_generate_entity.file_upload_config + and self.application_generate_entity.file_upload_config.image_config + ) + else None + ) + image_detail_config = image_detail_config or ImagePromptMessageContent.DETAIL.LOW + for file in self.files: + prompt_message_contents.append( + file_manager.to_prompt_message_content( + file, + image_detail_config=image_detail_config, + ) + ) 
prompt_messages.append(UserPromptMessage(content=prompt_message_contents)) else: diff --git a/api/core/app/app_config/entities.py b/api/core/app/app_config/entities.py index 6c6e342a07..9b72452d7a 100644 --- a/api/core/app/app_config/entities.py +++ b/api/core/app/app_config/entities.py @@ -4,7 +4,7 @@ from typing import Any, Optional from pydantic import BaseModel, Field, field_validator -from core.file import FileExtraConfig, FileTransferMethod, FileType +from core.file import FileTransferMethod, FileType, FileUploadConfig from core.model_runtime.entities.message_entities import PromptMessageRole from models.model import AppMode @@ -211,7 +211,7 @@ class TracingConfigEntity(BaseModel): class AppAdditionalFeatures(BaseModel): - file_upload: Optional[FileExtraConfig] = None + file_upload: Optional[FileUploadConfig] = None opening_statement: Optional[str] = None suggested_questions: list[str] = [] suggested_questions_after_answer: bool = False diff --git a/api/core/app/app_config/features/file_upload/manager.py b/api/core/app/app_config/features/file_upload/manager.py index d0f75d0b75..a79ddf3ddf 100644 --- a/api/core/app/app_config/features/file_upload/manager.py +++ b/api/core/app/app_config/features/file_upload/manager.py @@ -1,7 +1,7 @@ from collections.abc import Mapping from typing import Any -from core.file import FileExtraConfig +from core.file import FileUploadConfig class FileUploadConfigManager: @@ -29,15 +29,14 @@ class FileUploadConfigManager: if is_vision: data["image_config"]["detail"] = file_upload_dict.get("image", {}).get("detail", "low") - return FileExtraConfig.model_validate(data) + return FileUploadConfig.model_validate(data) @classmethod - def validate_and_set_defaults(cls, config: dict, is_vision: bool = True) -> tuple[dict, list[str]]: + def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]: """ Validate and set defaults for file upload feature :param config: app model config args - :param is_vision: if True, the feature is vision feature """ if not config.get("file_upload"): config["file_upload"] = {} diff --git a/api/core/app/apps/advanced_chat/app_config_manager.py b/api/core/app/apps/advanced_chat/app_config_manager.py index b52f235849..cb606953cd 100644 --- a/api/core/app/apps/advanced_chat/app_config_manager.py +++ b/api/core/app/apps/advanced_chat/app_config_manager.py @@ -52,9 +52,7 @@ class AdvancedChatAppConfigManager(BaseAppConfigManager): related_config_keys = [] # file upload validation - config, current_related_config_keys = FileUploadConfigManager.validate_and_set_defaults( - config=config, is_vision=False - ) + config, current_related_config_keys = FileUploadConfigManager.validate_and_set_defaults(config=config) related_config_keys.extend(current_related_config_keys) # opening_statement diff --git a/api/core/app/apps/advanced_chat/app_generator.py b/api/core/app/apps/advanced_chat/app_generator.py index 39ab87c914..5323b7953a 100644 --- a/api/core/app/apps/advanced_chat/app_generator.py +++ b/api/core/app/apps/advanced_chat/app_generator.py @@ -26,7 +26,6 @@ from core.ops.ops_trace_manager import TraceQueueManager from extensions.ext_database import db from factories import file_factory from models.account import Account -from models.enums import CreatedByRole from models.model import App, Conversation, EndUser, Message from models.workflow import Workflow @@ -98,13 +97,10 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): # parse files files = args["files"] if args.get("files") else [] file_extra_config = 
FileUploadConfigManager.convert(workflow.features_dict, is_vision=False) - role = CreatedByRole.ACCOUNT if isinstance(user, Account) else CreatedByRole.END_USER if file_extra_config: file_objs = file_factory.build_from_mappings( mappings=files, tenant_id=app_model.tenant_id, - user_id=user.id, - role=role, config=file_extra_config, ) else: @@ -130,7 +126,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): conversation_id=conversation.id if conversation else None, inputs=conversation.inputs if conversation - else self._prepare_user_inputs(user_inputs=inputs, app_config=app_config, user_id=user.id, role=role), + else self._prepare_user_inputs(user_inputs=inputs, app_config=app_config), query=query, files=file_objs, parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL, diff --git a/api/core/app/apps/agent_chat/app_generator.py b/api/core/app/apps/agent_chat/app_generator.py index de12f5a441..5faaf04fbf 100644 --- a/api/core/app/apps/agent_chat/app_generator.py +++ b/api/core/app/apps/agent_chat/app_generator.py @@ -23,7 +23,6 @@ from core.ops.ops_trace_manager import TraceQueueManager from extensions.ext_database import db from factories import file_factory from models import Account, App, EndUser -from models.enums import CreatedByRole logger = logging.getLogger(__name__) @@ -103,8 +102,6 @@ class AgentChatAppGenerator(MessageBasedAppGenerator): # always enable retriever resource in debugger mode override_model_config_dict["retriever_resource"] = {"enabled": True} - role = CreatedByRole.ACCOUNT if isinstance(user, Account) else CreatedByRole.END_USER - # parse files files = args.get("files") or [] file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict()) @@ -112,8 +109,6 @@ class AgentChatAppGenerator(MessageBasedAppGenerator): file_objs = file_factory.build_from_mappings( mappings=files, tenant_id=app_model.tenant_id, - user_id=user.id, - role=role, config=file_extra_config, ) else: @@ -135,10 +130,11 @@ class AgentChatAppGenerator(MessageBasedAppGenerator): task_id=str(uuid.uuid4()), app_config=app_config, model_conf=ModelConfigConverter.convert(app_config), + file_upload_config=file_extra_config, conversation_id=conversation.id if conversation else None, inputs=conversation.inputs if conversation - else self._prepare_user_inputs(user_inputs=inputs, app_config=app_config, user_id=user.id, role=role), + else self._prepare_user_inputs(user_inputs=inputs, app_config=app_config), query=query, files=file_objs, parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL, diff --git a/api/core/app/apps/base_app_generator.py b/api/core/app/apps/base_app_generator.py index d8e38476c7..6e6da95401 100644 --- a/api/core/app/apps/base_app_generator.py +++ b/api/core/app/apps/base_app_generator.py @@ -2,12 +2,11 @@ from collections.abc import Mapping from typing import TYPE_CHECKING, Any, Optional from core.app.app_config.entities import VariableEntityType -from core.file import File, FileExtraConfig +from core.file import File, FileUploadConfig from factories import file_factory if TYPE_CHECKING: from core.app.app_config.entities import AppConfig, VariableEntity - from models.enums import CreatedByRole class BaseAppGenerator: @@ -16,8 +15,6 @@ class BaseAppGenerator: *, user_inputs: Optional[Mapping[str, Any]], app_config: "AppConfig", - user_id: str, - role: "CreatedByRole", ) -> Mapping[str, Any]: user_inputs = user_inputs or {} # Filter input 
variables from form configuration, handle required fields, default values, and option values @@ -34,9 +31,7 @@ class BaseAppGenerator: k: file_factory.build_from_mapping( mapping=v, tenant_id=app_config.tenant_id, - user_id=user_id, - role=role, - config=FileExtraConfig( + config=FileUploadConfig( allowed_file_types=entity_dictionary[k].allowed_file_types, allowed_extensions=entity_dictionary[k].allowed_file_extensions, allowed_upload_methods=entity_dictionary[k].allowed_file_upload_methods, @@ -50,9 +45,7 @@ class BaseAppGenerator: k: file_factory.build_from_mappings( mappings=v, tenant_id=app_config.tenant_id, - user_id=user_id, - role=role, - config=FileExtraConfig( + config=FileUploadConfig( allowed_file_types=entity_dictionary[k].allowed_file_types, allowed_extensions=entity_dictionary[k].allowed_file_extensions, allowed_upload_methods=entity_dictionary[k].allowed_file_upload_methods, diff --git a/api/core/app/apps/chat/app_generator.py b/api/core/app/apps/chat/app_generator.py index 5c074f5306..844bbfb447 100644 --- a/api/core/app/apps/chat/app_generator.py +++ b/api/core/app/apps/chat/app_generator.py @@ -23,7 +23,6 @@ from core.ops.ops_trace_manager import TraceQueueManager from extensions.ext_database import db from factories import file_factory from models.account import Account -from models.enums import CreatedByRole from models.model import App, EndUser logger = logging.getLogger(__name__) @@ -101,8 +100,6 @@ class ChatAppGenerator(MessageBasedAppGenerator): # always enable retriever resource in debugger mode override_model_config_dict["retriever_resource"] = {"enabled": True} - role = CreatedByRole.ACCOUNT if isinstance(user, Account) else CreatedByRole.END_USER - # parse files files = args["files"] if args.get("files") else [] file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict()) @@ -110,8 +107,6 @@ class ChatAppGenerator(MessageBasedAppGenerator): file_objs = file_factory.build_from_mappings( mappings=files, tenant_id=app_model.tenant_id, - user_id=user.id, - role=role, config=file_extra_config, ) else: @@ -136,7 +131,7 @@ class ChatAppGenerator(MessageBasedAppGenerator): conversation_id=conversation.id if conversation else None, inputs=conversation.inputs if conversation - else self._prepare_user_inputs(user_inputs=inputs, app_config=app_config, user_id=user.id, role=role), + else self._prepare_user_inputs(user_inputs=inputs, app_config=app_config), query=query, files=file_objs, parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL, diff --git a/api/core/app/apps/completion/app_generator.py b/api/core/app/apps/completion/app_generator.py index 46450d39c0..9c7f3e38fa 100644 --- a/api/core/app/apps/completion/app_generator.py +++ b/api/core/app/apps/completion/app_generator.py @@ -22,7 +22,6 @@ from core.ops.ops_trace_manager import TraceQueueManager from extensions.ext_database import db from factories import file_factory from models import Account, App, EndUser, Message -from models.enums import CreatedByRole from services.errors.app import MoreLikeThisDisabledError from services.errors.message import MessageNotExistsError @@ -88,8 +87,6 @@ class CompletionAppGenerator(MessageBasedAppGenerator): tenant_id=app_model.tenant_id, config=args.get("model_config") ) - role = CreatedByRole.ACCOUNT if isinstance(user, Account) else CreatedByRole.END_USER - # parse files files = args["files"] if args.get("files") else [] file_extra_config = 
FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict()) @@ -97,8 +94,6 @@ class CompletionAppGenerator(MessageBasedAppGenerator): file_objs = file_factory.build_from_mappings( mappings=files, tenant_id=app_model.tenant_id, - user_id=user.id, - role=role, config=file_extra_config, ) else: @@ -110,7 +105,6 @@ class CompletionAppGenerator(MessageBasedAppGenerator): ) # get tracing instance - user_id = user.id if isinstance(user, Account) else user.session_id trace_manager = TraceQueueManager(app_model.id) # init application generate entity @@ -118,7 +112,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator): task_id=str(uuid.uuid4()), app_config=app_config, model_conf=ModelConfigConverter.convert(app_config), - inputs=self._prepare_user_inputs(user_inputs=inputs, app_config=app_config, user_id=user.id, role=role), + inputs=self._prepare_user_inputs(user_inputs=inputs, app_config=app_config), query=query, files=file_objs, user_id=user.id, @@ -259,14 +253,11 @@ class CompletionAppGenerator(MessageBasedAppGenerator): override_model_config_dict["model"] = model_dict # parse files - role = CreatedByRole.ACCOUNT if isinstance(user, Account) else CreatedByRole.END_USER file_extra_config = FileUploadConfigManager.convert(override_model_config_dict) if file_extra_config: file_objs = file_factory.build_from_mappings( mappings=message.message_files, tenant_id=app_model.tenant_id, - user_id=user.id, - role=role, config=file_extra_config, ) else: diff --git a/api/core/app/apps/workflow/app_config_manager.py b/api/core/app/apps/workflow/app_config_manager.py index 8b98e74b85..b0aa21c731 100644 --- a/api/core/app/apps/workflow/app_config_manager.py +++ b/api/core/app/apps/workflow/app_config_manager.py @@ -46,9 +46,7 @@ class WorkflowAppConfigManager(BaseAppConfigManager): related_config_keys = [] # file upload validation - config, current_related_config_keys = FileUploadConfigManager.validate_and_set_defaults( - config=config, is_vision=False - ) + config, current_related_config_keys = FileUploadConfigManager.validate_and_set_defaults(config=config) related_config_keys.extend(current_related_config_keys) # text_to_speech diff --git a/api/core/app/apps/workflow/app_generator.py b/api/core/app/apps/workflow/app_generator.py index a865c8a68b..e08e62c3c8 100644 --- a/api/core/app/apps/workflow/app_generator.py +++ b/api/core/app/apps/workflow/app_generator.py @@ -25,7 +25,6 @@ from core.ops.ops_trace_manager import TraceQueueManager from extensions.ext_database import db from factories import file_factory from models import Account, App, EndUser, Workflow -from models.enums import CreatedByRole logger = logging.getLogger(__name__) @@ -70,15 +69,11 @@ class WorkflowAppGenerator(BaseAppGenerator): ): files: Sequence[Mapping[str, Any]] = args.get("files") or [] - role = CreatedByRole.ACCOUNT if isinstance(user, Account) else CreatedByRole.END_USER - # parse files file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False) system_files = file_factory.build_from_mappings( mappings=files, tenant_id=app_model.tenant_id, - user_id=user.id, - role=role, config=file_extra_config, ) @@ -100,7 +95,7 @@ class WorkflowAppGenerator(BaseAppGenerator): application_generate_entity = WorkflowAppGenerateEntity( task_id=str(uuid.uuid4()), app_config=app_config, - inputs=self._prepare_user_inputs(user_inputs=inputs, app_config=app_config, user_id=user.id, role=role), + inputs=self._prepare_user_inputs(user_inputs=inputs, app_config=app_config), 
files=system_files, user_id=user.id, stream=stream, diff --git a/api/core/app/entities/app_invoke_entities.py b/api/core/app/entities/app_invoke_entities.py index f2eba29323..84b51583f8 100644 --- a/api/core/app/entities/app_invoke_entities.py +++ b/api/core/app/entities/app_invoke_entities.py @@ -7,6 +7,7 @@ from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_validat from constants import UUID_NIL from core.app.app_config.entities import AppConfig, EasyUIBasedAppConfig, WorkflowUIBasedAppConfig from core.entities.provider_configuration import ProviderModelBundle +from core.file import FileUploadConfig from core.file.models import File from core.model_runtime.entities.model_entities import AIModelEntity from core.ops.ops_trace_manager import TraceQueueManager @@ -111,6 +112,8 @@ class EasyUIBasedAppGenerateEntity(AppGenerateEntity): app_config: EasyUIBasedAppConfig model_conf: ModelConfigWithCredentialsEntity + file_upload_config: Optional[FileUploadConfig] = None + query: Optional[str] = None # pydantic configs diff --git a/api/core/file/__init__.py b/api/core/file/__init__.py index bdaf8793fa..fe9e52258a 100644 --- a/api/core/file/__init__.py +++ b/api/core/file/__init__.py @@ -2,13 +2,13 @@ from .constants import FILE_MODEL_IDENTITY from .enums import ArrayFileAttribute, FileAttribute, FileBelongsTo, FileTransferMethod, FileType from .models import ( File, - FileExtraConfig, + FileUploadConfig, ImageConfig, ) __all__ = [ "FileType", - "FileExtraConfig", + "FileUploadConfig", "FileTransferMethod", "FileBelongsTo", "File", diff --git a/api/core/file/file_manager.py b/api/core/file/file_manager.py index b69d7a74c0..f0aae6fa5d 100644 --- a/api/core/file/file_manager.py +++ b/api/core/file/file_manager.py @@ -33,25 +33,28 @@ def get_attr(*, file: File, attr: FileAttribute): raise ValueError(f"Invalid file attribute: {attr}") -def to_prompt_message_content(f: File, /): +def to_prompt_message_content( + f: File, + /, + *, + image_detail_config: ImagePromptMessageContent.DETAIL = ImagePromptMessageContent.DETAIL.LOW, +): """ - Convert a File object to an ImagePromptMessageContent object. + Convert a File object to an ImagePromptMessageContent or AudioPromptMessageContent object. - This function takes a File object and converts it to an ImagePromptMessageContent - object, which can be used as a prompt for image-based AI models. + This function takes a File object and converts it to an appropriate PromptMessageContent + object, which can be used as a prompt for image or audio-based AI models. Args: - file (File): The File object to convert. Must be of type FileType.IMAGE. + f (File): The File object to convert. + detail (Optional[ImagePromptMessageContent.DETAIL]): The detail level for image prompts. + If not provided, defaults to ImagePromptMessageContent.DETAIL.LOW. Returns: - ImagePromptMessageContent: An object containing the image data and detail level. + Union[ImagePromptMessageContent, AudioPromptMessageContent]: An object containing the file data and detail level Raises: - ValueError: If the file is not an image or if the file data is missing. - - Note: - The detail level of the image prompt is determined by the file's extra_config. - If not specified, it defaults to ImagePromptMessageContent.DETAIL.LOW. + ValueError: If the file type is not supported or if required data is missing. 
""" match f.type: case FileType.IMAGE: @@ -60,19 +63,14 @@ def to_prompt_message_content(f: File, /): else: data = _to_base64_data_string(f) - if f._extra_config and f._extra_config.image_config and f._extra_config.image_config.detail: - detail = f._extra_config.image_config.detail - else: - detail = ImagePromptMessageContent.DETAIL.LOW - - return ImagePromptMessageContent(data=data, detail=detail) + return ImagePromptMessageContent(data=data, detail=image_detail_config) case FileType.AUDIO: encoded_string = _file_to_encoded_string(f) if f.extension is None: raise ValueError("Missing file extension") return AudioPromptMessageContent(data=encoded_string, format=f.extension.lstrip(".")) case _: - raise ValueError(f"file type {f.type} is not supported") + raise ValueError("file type f.type is not supported") def download(f: File, /): diff --git a/api/core/file/models.py b/api/core/file/models.py index 866ff3155b..0142893787 100644 --- a/api/core/file/models.py +++ b/api/core/file/models.py @@ -21,7 +21,7 @@ class ImageConfig(BaseModel): detail: ImagePromptMessageContent.DETAIL | None = None -class FileExtraConfig(BaseModel): +class FileUploadConfig(BaseModel): """ File Upload Entity. """ @@ -46,7 +46,6 @@ class File(BaseModel): extension: Optional[str] = Field(default=None, description="File extension, should contains dot") mime_type: Optional[str] = None size: int = -1 - _extra_config: FileExtraConfig | None = None def to_dict(self) -> Mapping[str, str | int | None]: data = self.model_dump(mode="json") @@ -107,34 +106,4 @@ class File(BaseModel): case FileTransferMethod.TOOL_FILE: if not self.related_id: raise ValueError("Missing file related_id") - - # Validate the extra config. - if not self._extra_config: - return self - - if self._extra_config.allowed_file_types: - if self.type not in self._extra_config.allowed_file_types and self.type != FileType.CUSTOM: - raise ValueError(f"Invalid file type: {self.type}") - - if self._extra_config.allowed_extensions and self.extension not in self._extra_config.allowed_extensions: - raise ValueError(f"Invalid file extension: {self.extension}") - - if ( - self._extra_config.allowed_upload_methods - and self.transfer_method not in self._extra_config.allowed_upload_methods - ): - raise ValueError(f"Invalid transfer method: {self.transfer_method}") - - match self.type: - case FileType.IMAGE: - # NOTE: This part of validation is deprecated, but still used in app features "Image Upload". 
- if not self._extra_config.image_config: - return self - # TODO: skip check if transfer_methods is empty, because many test cases are not setting this field - if ( - self._extra_config.image_config.transfer_methods - and self.transfer_method not in self._extra_config.image_config.transfer_methods - ): - raise ValueError(f"Invalid transfer method: {self.transfer_method}") - return self diff --git a/api/core/memory/token_buffer_memory.py b/api/core/memory/token_buffer_memory.py index d92c36a2df..688fb4776a 100644 --- a/api/core/memory/token_buffer_memory.py +++ b/api/core/memory/token_buffer_memory.py @@ -81,15 +81,18 @@ class TokenBufferMemory: db.session.query(WorkflowRun).filter(WorkflowRun.id == message.workflow_run_id).first() ) - if workflow_run: + if workflow_run and workflow_run.workflow: file_extra_config = FileUploadConfigManager.convert( workflow_run.workflow.features_dict, is_vision=False ) + detail = ImagePromptMessageContent.DETAIL.LOW if file_extra_config and app_record: file_objs = file_factory.build_from_message_files( message_files=files, tenant_id=app_record.tenant_id, config=file_extra_config ) + if file_extra_config.image_config and file_extra_config.image_config.detail: + detail = file_extra_config.image_config.detail else: file_objs = [] @@ -98,12 +101,16 @@ class TokenBufferMemory: else: prompt_message_contents: list[PromptMessageContent] = [] prompt_message_contents.append(TextPromptMessageContent(data=message.query)) - for file_obj in file_objs: - if file_obj.type in {FileType.IMAGE, FileType.AUDIO}: - prompt_message = file_manager.to_prompt_message_content(file_obj) + for file in file_objs: + if file.type in {FileType.IMAGE, FileType.AUDIO}: + prompt_message = file_manager.to_prompt_message_content( + file, + image_detail_config=detail, + ) prompt_message_contents.append(prompt_message) prompt_messages.append(UserPromptMessage(content=prompt_message_contents)) + else: prompt_messages.append(UserPromptMessage(content=message.query)) diff --git a/api/core/prompt/advanced_prompt_transform.py b/api/core/prompt/advanced_prompt_transform.py index bbd9531b19..0f3f824966 100644 --- a/api/core/prompt/advanced_prompt_transform.py +++ b/api/core/prompt/advanced_prompt_transform.py @@ -15,6 +15,7 @@ from core.model_runtime.entities import ( TextPromptMessageContent, UserPromptMessage, ) +from core.model_runtime.entities.message_entities import ImagePromptMessageContent from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig from core.prompt.prompt_transform import PromptTransform from core.prompt.utils.prompt_template_parser import PromptTemplateParser @@ -26,8 +27,13 @@ class AdvancedPromptTransform(PromptTransform): Advanced Prompt Transform for Workflow LLM Node. 
""" - def __init__(self, with_variable_tmpl: bool = False) -> None: + def __init__( + self, + with_variable_tmpl: bool = False, + image_detail_config: ImagePromptMessageContent.DETAIL = ImagePromptMessageContent.DETAIL.LOW, + ) -> None: self.with_variable_tmpl = with_variable_tmpl + self.image_detail_config = image_detail_config def get_prompt( self, diff --git a/api/core/tools/provider/builtin/vectorizer/vectorizer.py b/api/core/tools/provider/builtin/vectorizer/vectorizer.py index 8140348723..211ec78f4d 100644 --- a/api/core/tools/provider/builtin/vectorizer/vectorizer.py +++ b/api/core/tools/provider/builtin/vectorizer/vectorizer.py @@ -1,19 +1,23 @@ from typing import Any -from core.file import File -from core.file.enums import FileTransferMethod, FileType +from core.file import FileTransferMethod, FileType from core.tools.errors import ToolProviderCredentialValidationError from core.tools.provider.builtin.vectorizer.tools.vectorizer import VectorizerTool from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController +from factories import file_factory class VectorizerProvider(BuiltinToolProviderController): def _validate_credentials(self, credentials: dict[str, Any]) -> None: - test_img = File( + mapping = { + "transfer_method": FileTransferMethod.TOOL_FILE, + "type": FileType.IMAGE, + "id": "test_id", + "url": "https://cloud.dify.ai/logo/logo-site.png", + } + test_img = file_factory.build_from_mapping( + mapping=mapping, tenant_id="__test_123", - remote_url="https://cloud.dify.ai/logo/logo-site.png", - type=FileType.IMAGE, - transfer_method=FileTransferMethod.REMOTE_URL, ) try: VectorizerTool().fork_tool_runtime( diff --git a/api/core/workflow/nodes/http_request/node.py b/api/core/workflow/nodes/http_request/node.py index 61c661e587..5b399bed63 100644 --- a/api/core/workflow/nodes/http_request/node.py +++ b/api/core/workflow/nodes/http_request/node.py @@ -13,6 +13,7 @@ from core.workflow.nodes.base import BaseNode from core.workflow.nodes.enums import NodeType from core.workflow.nodes.http_request.executor import Executor from core.workflow.utils import variable_template_parser +from factories import file_factory from models.workflow import WorkflowNodeExecutionStatus from .entities import ( @@ -161,16 +162,15 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]): mimetype=content_type, ) - files.append( - File( - tenant_id=self.tenant_id, - type=FileType.IMAGE, - transfer_method=FileTransferMethod.TOOL_FILE, - related_id=tool_file.id, - filename=filename, - extension=extension, - mime_type=content_type, - ) + mapping = { + "tool_file_id": tool_file.id, + "type": FileType.IMAGE.value, + "transfer_method": FileTransferMethod.TOOL_FILE.value, + } + file = file_factory.build_from_mapping( + mapping=mapping, + tenant_id=self.tenant_id, ) + files.append(file) return files diff --git a/api/core/workflow/nodes/tool/tool_node.py b/api/core/workflow/nodes/tool/tool_node.py index 42e870c46c..6870b7467d 100644 --- a/api/core/workflow/nodes/tool/tool_node.py +++ b/api/core/workflow/nodes/tool/tool_node.py @@ -17,6 +17,7 @@ from core.workflow.nodes.base import BaseNode from core.workflow.nodes.enums import NodeType from core.workflow.utils.variable_template_parser import VariableTemplateParser from extensions.ext_database import db +from factories import file_factory from models import ToolFile from models.workflow import WorkflowNodeExecutionStatus @@ -189,19 +190,17 @@ class ToolNode(BaseNode[ToolNodeData]): if tool_file is None: raise ToolFileError(f"Tool file 
{tool_file_id} does not exist") - result.append( - File( - tenant_id=self.tenant_id, - type=FileType.IMAGE, - transfer_method=transfer_method, - remote_url=url, - related_id=tool_file.id, - filename=tool_file.name, - extension=ext, - mime_type=tool_file.mimetype, - size=tool_file.size, - ) + mapping = { + "tool_file_id": tool_file_id, + "type": FileType.IMAGE, + "transfer_method": transfer_method, + "url": url, + } + file = file_factory.build_from_mapping( + mapping=mapping, + tenant_id=self.tenant_id, ) + result.append(file) elif response.type == ToolInvokeMessage.MessageType.BLOB: # get tool file id tool_file_id = str(response.message).split("/")[-1].split(".")[0] @@ -209,19 +208,17 @@ class ToolNode(BaseNode[ToolNodeData]): stmt = select(ToolFile).where(ToolFile.id == tool_file_id) tool_file = session.scalar(stmt) if tool_file is None: - raise ToolFileError(f"Tool file {tool_file_id} does not exist") - result.append( - File( - tenant_id=self.tenant_id, - type=FileType.IMAGE, - transfer_method=FileTransferMethod.TOOL_FILE, - related_id=tool_file.id, - filename=tool_file.name, - extension=path.splitext(response.save_as)[1], - mime_type=tool_file.mimetype, - size=tool_file.size, - ) + raise ValueError(f"tool file {tool_file_id} not exists") + mapping = { + "tool_file_id": tool_file_id, + "type": FileType.IMAGE, + "transfer_method": FileTransferMethod.TOOL_FILE, + } + file = file_factory.build_from_mapping( + mapping=mapping, + tenant_id=self.tenant_id, ) + result.append(file) elif response.type == ToolInvokeMessage.MessageType.LINK: url = str(response.message) transfer_method = FileTransferMethod.TOOL_FILE @@ -235,16 +232,15 @@ class ToolNode(BaseNode[ToolNodeData]): extension = "." + url.split("/")[-1].split(".")[1] else: extension = ".bin" - file = File( + mapping = { + "tool_file_id": tool_file_id, + "type": FileType.IMAGE, + "transfer_method": transfer_method, + "url": url, + } + file = file_factory.build_from_mapping( + mapping=mapping, tenant_id=self.tenant_id, - type=FileType(response.save_as), - transfer_method=transfer_method, - remote_url=url, - filename=tool_file.name, - related_id=tool_file.id, - extension=extension, - mime_type=tool_file.mimetype, - size=tool_file.size, ) result.append(file) diff --git a/api/core/workflow/workflow_entry.py b/api/core/workflow/workflow_entry.py index eb812bad21..84b251223f 100644 --- a/api/core/workflow/workflow_entry.py +++ b/api/core/workflow/workflow_entry.py @@ -5,10 +5,10 @@ from collections.abc import Generator, Mapping, Sequence from typing import Any, Optional, cast from configs import dify_config -from core.app.app_config.entities import FileExtraConfig +from core.app.app_config.entities import FileUploadConfig from core.app.apps.base_app_queue_manager import GenerateTaskStoppedError from core.app.entities.app_invoke_entities import InvokeFrom -from core.file.models import File, FileTransferMethod, FileType, ImageConfig +from core.file.models import File, FileTransferMethod, ImageConfig from core.workflow.callbacks import WorkflowCallback from core.workflow.entities.variable_pool import VariablePool from core.workflow.errors import WorkflowNodeRunFailedError @@ -22,6 +22,7 @@ from core.workflow.nodes.base import BaseNode, BaseNodeData from core.workflow.nodes.event import NodeEvent from core.workflow.nodes.llm import LLMNodeData from core.workflow.nodes.node_mapping import node_type_classes_mapping +from factories import file_factory from models.enums import UserFrom from models.workflow import ( Workflow, @@ -271,19 +272,17 @@ class 
WorkflowEntry: for item in input_value: if isinstance(item, dict) and "type" in item and item["type"] == "image": transfer_method = FileTransferMethod.value_of(item.get("transfer_method")) - file = File( + mapping = { + "id": item.get("id"), + "transfer_method": transfer_method, + "upload_file_id": item.get("upload_file_id"), + "url": item.get("url"), + } + config = FileUploadConfig(image_config=ImageConfig(detail=detail) if detail else None) + file = file_factory.build_from_mapping( + mapping=mapping, tenant_id=tenant_id, - type=FileType.IMAGE, - transfer_method=transfer_method, - remote_url=item.get("url") - if transfer_method == FileTransferMethod.REMOTE_URL - else None, - related_id=item.get("upload_file_id") - if transfer_method == FileTransferMethod.LOCAL_FILE - else None, - _extra_config=FileExtraConfig( - image_config=ImageConfig(detail=detail) if detail else None - ), + config=config, ) new_value.append(file) diff --git a/api/factories/file_factory.py b/api/factories/file_factory.py index 1066dc8862..738b2b3478 100644 --- a/api/factories/file_factory.py +++ b/api/factories/file_factory.py @@ -1,23 +1,21 @@ import mimetypes -from collections.abc import Mapping, Sequence +from collections.abc import Callable, Mapping, Sequence from typing import Any import httpx from sqlalchemy import select -from constants import AUDIO_EXTENSIONS, DOCUMENT_EXTENSIONS, IMAGE_EXTENSIONS, VIDEO_EXTENSIONS -from core.file import File, FileBelongsTo, FileExtraConfig, FileTransferMethod, FileType +from core.file import File, FileBelongsTo, FileTransferMethod, FileType, FileUploadConfig from core.helper import ssrf_proxy from extensions.ext_database import db from models import MessageFile, ToolFile, UploadFile -from models.enums import CreatedByRole def build_from_message_files( *, message_files: Sequence["MessageFile"], tenant_id: str, - config: FileExtraConfig, + config: FileUploadConfig, ) -> Sequence[File]: results = [ build_from_message_file(message_file=file, tenant_id=tenant_id, config=config) @@ -31,7 +29,7 @@ def build_from_message_file( *, message_file: "MessageFile", tenant_id: str, - config: FileExtraConfig, + config: FileUploadConfig, ): mapping = { "transfer_method": message_file.transfer_method, @@ -43,8 +41,6 @@ def build_from_message_file( return build_from_mapping( mapping=mapping, tenant_id=tenant_id, - user_id=message_file.created_by, - role=CreatedByRole(message_file.created_by_role), config=config, ) @@ -53,38 +49,30 @@ def build_from_mapping( *, mapping: Mapping[str, Any], tenant_id: str, - user_id: str, - role: "CreatedByRole", - config: FileExtraConfig, -): + config: FileUploadConfig | None = None, +) -> File: + config = config or FileUploadConfig() + transfer_method = FileTransferMethod.value_of(mapping.get("transfer_method")) - match transfer_method: - case FileTransferMethod.REMOTE_URL: - file = _build_from_remote_url( - mapping=mapping, - tenant_id=tenant_id, - config=config, - transfer_method=transfer_method, - ) - case FileTransferMethod.LOCAL_FILE: - file = _build_from_local_file( - mapping=mapping, - tenant_id=tenant_id, - user_id=user_id, - role=role, - config=config, - transfer_method=transfer_method, - ) - case FileTransferMethod.TOOL_FILE: - file = _build_from_tool_file( - mapping=mapping, - tenant_id=tenant_id, - user_id=user_id, - config=config, - transfer_method=transfer_method, - ) - case _: - raise ValueError(f"Invalid file transfer method: {transfer_method}") + + build_functions: dict[FileTransferMethod, Callable] = { + FileTransferMethod.LOCAL_FILE: 
_build_from_local_file, + FileTransferMethod.REMOTE_URL: _build_from_remote_url, + FileTransferMethod.TOOL_FILE: _build_from_tool_file, + } + + build_func = build_functions.get(transfer_method) + if not build_func: + raise ValueError(f"Invalid file transfer method: {transfer_method}") + + file = build_func( + mapping=mapping, + tenant_id=tenant_id, + transfer_method=transfer_method, + ) + + if not _is_file_valid_with_config(file=file, config=config): + raise ValueError(f"File validation failed for file: {file.filename}") return file @@ -92,10 +80,8 @@ def build_from_mapping( def build_from_mappings( *, mappings: Sequence[Mapping[str, Any]], - config: FileExtraConfig | None, + config: FileUploadConfig | None, tenant_id: str, - user_id: str, - role: "CreatedByRole", ) -> Sequence[File]: if not config: return [] @@ -104,8 +90,6 @@ def build_from_mappings( build_from_mapping( mapping=mapping, tenant_id=tenant_id, - user_id=user_id, - role=role, config=config, ) for mapping in mappings @@ -128,31 +112,20 @@ def _build_from_local_file( *, mapping: Mapping[str, Any], tenant_id: str, - user_id: str, - role: "CreatedByRole", - config: FileExtraConfig, transfer_method: FileTransferMethod, -): - # check if the upload file exists. +) -> File: file_type = FileType.value_of(mapping.get("type")) stmt = select(UploadFile).where( UploadFile.id == mapping.get("upload_file_id"), UploadFile.tenant_id == tenant_id, - UploadFile.created_by == user_id, - UploadFile.created_by_role == role, ) - if file_type == FileType.IMAGE: - stmt = stmt.where(UploadFile.extension.in_(IMAGE_EXTENSIONS)) - elif file_type == FileType.VIDEO: - stmt = stmt.where(UploadFile.extension.in_(VIDEO_EXTENSIONS)) - elif file_type == FileType.AUDIO: - stmt = stmt.where(UploadFile.extension.in_(AUDIO_EXTENSIONS)) - elif file_type == FileType.DOCUMENT: - stmt = stmt.where(UploadFile.extension.in_(DOCUMENT_EXTENSIONS)) + row = db.session.scalar(stmt) + if row is None: raise ValueError("Invalid upload file") - file = File( + + return File( id=mapping.get("id"), filename=row.name, extension="." + row.extension, @@ -162,23 +135,37 @@ def _build_from_local_file( transfer_method=transfer_method, remote_url=row.source_url, related_id=mapping.get("upload_file_id"), - _extra_config=config, size=row.size, ) - return file def _build_from_remote_url( *, mapping: Mapping[str, Any], tenant_id: str, - config: FileExtraConfig, transfer_method: FileTransferMethod, -): +) -> File: url = mapping.get("url") if not url: raise ValueError("Invalid file url") + mime_type, filename, file_size = _get_remote_file_info(url) + extension = mimetypes.guess_extension(mime_type) or "." + filename.split(".")[-1] if "." 
in filename else ".bin" + + return File( + id=mapping.get("id"), + filename=filename, + tenant_id=tenant_id, + type=FileType.value_of(mapping.get("type")), + transfer_method=transfer_method, + remote_url=url, + mime_type=mime_type, + extension=extension, + size=file_size, + ) + + +def _get_remote_file_info(url: str): mime_type = mimetypes.guess_type(url)[0] or "" file_size = -1 filename = url.split("/")[-1].split("?")[0] or "unknown_file" @@ -186,56 +173,34 @@ def _build_from_remote_url( resp = ssrf_proxy.head(url, follow_redirects=True) if resp.status_code == httpx.codes.OK: if content_disposition := resp.headers.get("Content-Disposition"): - filename = content_disposition.split("filename=")[-1].strip('"') + filename = str(content_disposition.split("filename=")[-1].strip('"')) file_size = int(resp.headers.get("Content-Length", file_size)) mime_type = mime_type or str(resp.headers.get("Content-Type", "")) - # Determine file extension - extension = mimetypes.guess_extension(mime_type) or "." + filename.split(".")[-1] if "." in filename else ".bin" - - if not mime_type: - mime_type, _ = mimetypes.guess_type(url) - file = File( - id=mapping.get("id"), - filename=filename, - tenant_id=tenant_id, - type=FileType.value_of(mapping.get("type")), - transfer_method=transfer_method, - remote_url=url, - _extra_config=config, - mime_type=mime_type, - extension=extension, - size=file_size, - ) - return file + return mime_type, filename, file_size def _build_from_tool_file( *, mapping: Mapping[str, Any], tenant_id: str, - user_id: str, - config: FileExtraConfig, transfer_method: FileTransferMethod, -): +) -> File: tool_file = ( db.session.query(ToolFile) .filter( ToolFile.id == mapping.get("tool_file_id"), ToolFile.tenant_id == tenant_id, - ToolFile.user_id == user_id, ) .first() ) + if tool_file is None: raise ValueError(f"ToolFile {mapping.get('tool_file_id')} not found") - path = tool_file.file_key - if "." in path: - extension = "." + path.split("/")[-1].split(".")[-1] - else: - extension = ".bin" - file = File( + extension = "." + tool_file.file_key.split(".")[-1] if "." 
in tool_file.file_key else ".bin" + + return File( id=mapping.get("id"), tenant_id=tenant_id, filename=tool_file.name, @@ -246,6 +211,21 @@ def _build_from_tool_file( extension=extension, mime_type=tool_file.mimetype, size=tool_file.size, - _extra_config=config, ) - return file + + +def _is_file_valid_with_config(*, file: File, config: FileUploadConfig) -> bool: + if config.allowed_file_types and file.type not in config.allowed_file_types and file.type != FileType.CUSTOM: + return False + + if config.allowed_extensions and file.extension not in config.allowed_extensions: + return False + + if config.allowed_upload_methods and file.transfer_method not in config.allowed_upload_methods: + return False + + if file.type == FileType.IMAGE and config.image_config: + if config.image_config.transfer_methods and file.transfer_method not in config.image_config.transfer_methods: + return False + + return True diff --git a/api/models/model.py b/api/models/model.py index d049cd373d..e909d53e3e 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -13,7 +13,7 @@ from sqlalchemy import Float, func, text from sqlalchemy.orm import Mapped, mapped_column from configs import dify_config -from core.file import FILE_MODEL_IDENTITY, File, FileExtraConfig, FileTransferMethod, FileType +from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod, FileType from core.file import helpers as file_helpers from core.file.tool_file_parser import ToolFileParser from extensions.ext_database import db @@ -949,9 +949,6 @@ class Message(db.Model): "type": message_file.type, }, tenant_id=current_app.tenant_id, - user_id=self.from_account_id or self.from_end_user_id or "", - role=CreatedByRole(message_file.created_by_role), - config=FileExtraConfig(), ) elif message_file.transfer_method == "remote_url": if message_file.url is None: @@ -964,9 +961,6 @@ class Message(db.Model): "url": message_file.url, }, tenant_id=current_app.tenant_id, - user_id=self.from_account_id or self.from_end_user_id or "", - role=CreatedByRole(message_file.created_by_role), - config=FileExtraConfig(), ) elif message_file.transfer_method == "tool_file": if message_file.upload_file_id is None: @@ -981,9 +975,6 @@ class Message(db.Model): file = file_factory.build_from_mapping( mapping=mapping, tenant_id=current_app.tenant_id, - user_id=self.from_account_id or self.from_end_user_id or "", - role=CreatedByRole(message_file.created_by_role), - config=FileExtraConfig(), ) else: raise ValueError( diff --git a/api/services/workflow/workflow_converter.py b/api/services/workflow/workflow_converter.py index 75c11afa94..90b5cc4836 100644 --- a/api/services/workflow/workflow_converter.py +++ b/api/services/workflow/workflow_converter.py @@ -13,7 +13,7 @@ from core.app.app_config.entities import ( from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfigManager from core.app.apps.chat.app_config_manager import ChatAppConfigManager from core.app.apps.completion.app_config_manager import CompletionAppConfigManager -from core.file.models import FileExtraConfig +from core.file.models import FileUploadConfig from core.helper import encrypter from core.model_runtime.entities.llm_entities import LLMMode from core.model_runtime.utils.encoders import jsonable_encoder @@ -381,7 +381,7 @@ class WorkflowConverter: graph: dict, model_config: ModelConfigEntity, prompt_template: PromptTemplateEntity, - file_upload: Optional[FileExtraConfig] = None, + file_upload: Optional[FileUploadConfig] = None, external_data_variable_node_mapping: dict[str, str] 
| None = None, ) -> dict: """ diff --git a/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py b/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py index ece2173090..7d19cff3e8 100644 --- a/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py +++ b/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py @@ -3,7 +3,7 @@ from unittest.mock import MagicMock, patch import pytest from core.app.app_config.entities import ModelConfigEntity -from core.file import File, FileExtraConfig, FileTransferMethod, FileType, ImageConfig +from core.file import File, FileTransferMethod, FileType, FileUploadConfig, ImageConfig from core.memory.token_buffer_memory import TokenBufferMemory from core.model_runtime.entities.message_entities import ( AssistantPromptMessage, @@ -134,7 +134,6 @@ def test__get_chat_model_prompt_messages_with_files_no_memory(get_chat_model_arg type=FileType.IMAGE, transfer_method=FileTransferMethod.REMOTE_URL, remote_url="https://example.com/image1.jpg", - _extra_config=FileExtraConfig(image_config=ImageConfig(detail=ImagePromptMessageContent.DETAIL.HIGH)), ) ] From 7988b3bb9d4f9ba6804472ef97ead5e100e28241 Mon Sep 17 00:00:00 2001 From: -LAN- Date: Sat, 2 Nov 2024 22:02:02 +0800 Subject: [PATCH 9/9] feat(app generator): add file_upload_config to app generation entities --- .../app/apps/advanced_chat/app_generator.py | 1 + api/core/app/apps/chat/app_generator.py | 1 + api/core/app/apps/completion/app_generator.py | 1 + api/core/app/apps/workflow/app_generator.py | 1 + api/core/app/entities/app_invoke_entities.py | 6 ++-- .../workflow/nodes/test_http.py | 34 ------------------- 6 files changed, 6 insertions(+), 38 deletions(-) diff --git a/api/core/app/apps/advanced_chat/app_generator.py b/api/core/app/apps/advanced_chat/app_generator.py index 5323b7953a..0dd0ad1fd8 100644 --- a/api/core/app/apps/advanced_chat/app_generator.py +++ b/api/core/app/apps/advanced_chat/app_generator.py @@ -123,6 +123,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): application_generate_entity = AdvancedChatAppGenerateEntity( task_id=str(uuid.uuid4()), app_config=app_config, + file_upload_config=file_extra_config, conversation_id=conversation.id if conversation else None, inputs=conversation.inputs if conversation diff --git a/api/core/app/apps/chat/app_generator.py b/api/core/app/apps/chat/app_generator.py index 844bbfb447..0e71f380f7 100644 --- a/api/core/app/apps/chat/app_generator.py +++ b/api/core/app/apps/chat/app_generator.py @@ -128,6 +128,7 @@ class ChatAppGenerator(MessageBasedAppGenerator): task_id=str(uuid.uuid4()), app_config=app_config, model_conf=ModelConfigConverter.convert(app_config), + file_upload_config=file_extra_config, conversation_id=conversation.id if conversation else None, inputs=conversation.inputs if conversation diff --git a/api/core/app/apps/completion/app_generator.py b/api/core/app/apps/completion/app_generator.py index 9c7f3e38fa..9b4db3902c 100644 --- a/api/core/app/apps/completion/app_generator.py +++ b/api/core/app/apps/completion/app_generator.py @@ -112,6 +112,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator): task_id=str(uuid.uuid4()), app_config=app_config, model_conf=ModelConfigConverter.convert(app_config), + file_upload_config=file_extra_config, inputs=self._prepare_user_inputs(user_inputs=inputs, app_config=app_config), query=query, files=file_objs, diff --git a/api/core/app/apps/workflow/app_generator.py b/api/core/app/apps/workflow/app_generator.py index e08e62c3c8..b68afdf212 100644 
--- a/api/core/app/apps/workflow/app_generator.py +++ b/api/core/app/apps/workflow/app_generator.py @@ -95,6 +95,7 @@ class WorkflowAppGenerator(BaseAppGenerator): application_generate_entity = WorkflowAppGenerateEntity( task_id=str(uuid.uuid4()), app_config=app_config, + file_upload_config=file_extra_config, inputs=self._prepare_user_inputs(user_inputs=inputs, app_config=app_config), files=system_files, user_id=user.id, diff --git a/api/core/app/entities/app_invoke_entities.py b/api/core/app/entities/app_invoke_entities.py index 84b51583f8..31c3a996e1 100644 --- a/api/core/app/entities/app_invoke_entities.py +++ b/api/core/app/entities/app_invoke_entities.py @@ -7,8 +7,7 @@ from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_validat from constants import UUID_NIL from core.app.app_config.entities import AppConfig, EasyUIBasedAppConfig, WorkflowUIBasedAppConfig from core.entities.provider_configuration import ProviderModelBundle -from core.file import FileUploadConfig -from core.file.models import File +from core.file import File, FileUploadConfig from core.model_runtime.entities.model_entities import AIModelEntity from core.ops.ops_trace_manager import TraceQueueManager @@ -81,6 +80,7 @@ class AppGenerateEntity(BaseModel): # app config app_config: AppConfig + file_upload_config: Optional[FileUploadConfig] = None inputs: Mapping[str, Any] files: Sequence[File] @@ -112,8 +112,6 @@ class EasyUIBasedAppGenerateEntity(AppGenerateEntity): app_config: EasyUIBasedAppConfig model_conf: ModelConfigWithCredentialsEntity - file_upload_config: Optional[FileUploadConfig] = None - query: Optional[str] = None # pydantic configs diff --git a/api/tests/integration_tests/workflow/nodes/test_http.py b/api/tests/integration_tests/workflow/nodes/test_http.py index 0da6622658..9eea63f722 100644 --- a/api/tests/integration_tests/workflow/nodes/test_http.py +++ b/api/tests/integration_tests/workflow/nodes/test_http.py @@ -430,37 +430,3 @@ def test_multi_colons_parse(setup_http_mock): assert urlencode({"Redirect": "http://example2.com"}) in result.process_data.get("request", "") assert 'form-data; name="Redirect"\r\n\r\nhttp://example6.com' in result.process_data.get("request", "") # assert "http://example3.com" == resp.get("headers", {}).get("referer") - - -def test_image_file(monkeypatch): - from types import SimpleNamespace - - monkeypatch.setattr( - "core.tools.tool_file_manager.ToolFileManager.create_file_by_raw", - lambda *args, **kwargs: SimpleNamespace(id="1"), - ) - - node = init_http_node( - config={ - "id": "1", - "data": { - "title": "http", - "desc": "", - "method": "get", - "url": "https://cloud.dify.ai/logo/logo-site.png", - "authorization": { - "type": "no-auth", - "config": None, - }, - "params": "", - "headers": "", - "body": None, - }, - } - ) - - result = node._run() - assert result.process_data is not None - assert result.outputs is not None - resp = result.outputs - assert len(resp.get("files", [])) == 1
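Taken together, the patches above route every File construction through factories.file_factory.build_from_mapping, which now resolves records by tenant_id alone instead of also requiring user_id and a CreatedByRole, and takes an optional FileUploadConfig. Below is a minimal sketch of the three mapping shapes used by the callers in these hunks; the tenant and record ids are placeholders, and the calls assume matching UploadFile/ToolFile rows exist and that remote URLs are reachable through the SSRF proxy.

from core.file import FileTransferMethod, FileType
from factories import file_factory

TENANT_ID = "tenant-id-placeholder"  # hypothetical tenant

# Local upload: looked up in UploadFile by upload_file_id and tenant_id only.
local_file = file_factory.build_from_mapping(
    mapping={
        "type": FileType.IMAGE.value,
        "transfer_method": FileTransferMethod.LOCAL_FILE.value,
        "upload_file_id": "upload-file-id-placeholder",  # hypothetical row id
    },
    tenant_id=TENANT_ID,
)

# Remote URL: filename, size and mime type are probed via an SSRF-proxied HEAD request.
remote_file = file_factory.build_from_mapping(
    mapping={
        "type": FileType.IMAGE.value,
        "transfer_method": FileTransferMethod.REMOTE_URL.value,
        "url": "https://cloud.dify.ai/logo/logo-site.png",
    },
    tenant_id=TENANT_ID,
)

# Tool output: resolved from the ToolFile table, as the tool and HTTP request nodes now do.
tool_output_file = file_factory.build_from_mapping(
    mapping={
        "type": FileType.IMAGE.value,
        "transfer_method": FileTransferMethod.TOOL_FILE.value,
        "tool_file_id": "tool-file-id-placeholder",  # hypothetical row id
    },
    tenant_id=TENANT_ID,
)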
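Validation likewise moves out of File's pydantic validator (the removed _extra_config checks) and into _is_file_valid_with_config inside the factory, so the allowed-type, extension and transfer-method checks now run at build time against whatever FileUploadConfig the caller passes. A hedged sketch follows; the tenant id is a placeholder, the config values are illustrative, and the fields are assumed to accept the same enum members the removed validator compared against.

from core.file import FileTransferMethod, FileType, FileUploadConfig
from factories import file_factory

restrictive_config = FileUploadConfig(
    allowed_file_types=[FileType.IMAGE],
    allowed_upload_methods=[FileTransferMethod.LOCAL_FILE],
)

try:
    file_factory.build_from_mapping(
        mapping={
            "type": FileType.IMAGE.value,
            "transfer_method": FileTransferMethod.REMOTE_URL.value,
            "url": "https://cloud.dify.ai/logo/logo-site.png",
        },
        tenant_id="tenant-id-placeholder",  # hypothetical tenant
        config=restrictive_config,
    )
except ValueError as e:
    # REMOTE_URL is not in allowed_upload_methods, so the factory rejects the file
    # with the "File validation failed for file: ..." error added in this refactor.
    print(e)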
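Since File no longer carries _extra_config, the image detail level has to travel explicitly: TokenBufferMemory resolves it from the app's file upload settings and passes it to file_manager.to_prompt_message_content, AdvancedPromptTransform accepts it in its constructor, and PATCH 9/9 exposes the originating FileUploadConfig on the generate entity as file_upload_config. The sketch below traces that flow; upload_config and file_objs are stand-ins for values a real app would take from the generate entity.

from core.file import FileType, FileUploadConfig, ImageConfig, file_manager
from core.model_runtime.entities.message_entities import ImagePromptMessageContent
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform

# Stand-in for application_generate_entity.file_upload_config (added in PATCH 9/9).
upload_config = FileUploadConfig(
    image_config=ImageConfig(detail=ImagePromptMessageContent.DETAIL.HIGH)
)

# Default to LOW, as TokenBufferMemory does, and upgrade only when the config says so.
detail = ImagePromptMessageContent.DETAIL.LOW
if upload_config.image_config and upload_config.image_config.detail:
    detail = upload_config.image_config.detail

file_objs = []  # image/audio File objects built by file_factory
prompt_contents = [
    file_manager.to_prompt_message_content(f, image_detail_config=detail)
    for f in file_objs
    if f.type in {FileType.IMAGE, FileType.AUDIO}
]

# The workflow LLM node can feed the same detail level into the prompt transform.
prompt_transform = AdvancedPromptTransform(image_detail_config=detail)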