mirror of
https://github.com/langgenius/dify.git
synced 2024-11-15 19:22:36 +08:00
refactor(core): Remove extra_config from File. (#10203)
This commit is contained in:
parent
78a380bcc4
commit
25ca0278dd
|
@ -30,6 +30,7 @@ from core.model_runtime.entities import (
|
|||
ToolPromptMessage,
|
||||
UserPromptMessage,
|
||||
)
|
||||
from core.model_runtime.entities.message_entities import ImagePromptMessageContent
|
||||
from core.model_runtime.entities.model_entities import ModelFeature
|
||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
|
@ -65,7 +66,7 @@ class BaseAgentRunner(AppRunner):
|
|||
prompt_messages: Optional[list[PromptMessage]] = None,
|
||||
variables_pool: Optional[ToolRuntimeVariablePool] = None,
|
||||
db_variables: Optional[ToolConversationVariables] = None,
|
||||
model_instance: ModelInstance = None,
|
||||
model_instance: ModelInstance | None = None,
|
||||
) -> None:
|
||||
self.tenant_id = tenant_id
|
||||
self.application_generate_entity = application_generate_entity
|
||||
|
@ -508,24 +509,27 @@ class BaseAgentRunner(AppRunner):
|
|||
|
||||
def organize_agent_user_prompt(self, message: Message) -> UserPromptMessage:
|
||||
files = db.session.query(MessageFile).filter(MessageFile.message_id == message.id).all()
|
||||
if files:
|
||||
file_extra_config = FileUploadConfigManager.convert(message.app_model_config.to_dict())
|
||||
|
||||
if file_extra_config:
|
||||
file_objs = file_factory.build_from_message_files(
|
||||
message_files=files, tenant_id=self.tenant_id, config=file_extra_config
|
||||
)
|
||||
else:
|
||||
file_objs = []
|
||||
|
||||
if not file_objs:
|
||||
return UserPromptMessage(content=message.query)
|
||||
else:
|
||||
prompt_message_contents: list[PromptMessageContent] = []
|
||||
prompt_message_contents.append(TextPromptMessageContent(data=message.query))
|
||||
for file_obj in file_objs:
|
||||
prompt_message_contents.append(file_manager.to_prompt_message_content(file_obj))
|
||||
|
||||
return UserPromptMessage(content=prompt_message_contents)
|
||||
else:
|
||||
if not files:
|
||||
return UserPromptMessage(content=message.query)
|
||||
file_extra_config = FileUploadConfigManager.convert(message.app_model_config.to_dict())
|
||||
if not file_extra_config:
|
||||
return UserPromptMessage(content=message.query)
|
||||
|
||||
image_detail_config = file_extra_config.image_config.detail if file_extra_config.image_config else None
|
||||
image_detail_config = image_detail_config or ImagePromptMessageContent.DETAIL.LOW
|
||||
|
||||
file_objs = file_factory.build_from_message_files(
|
||||
message_files=files, tenant_id=self.tenant_id, config=file_extra_config
|
||||
)
|
||||
if not file_objs:
|
||||
return UserPromptMessage(content=message.query)
|
||||
prompt_message_contents: list[PromptMessageContent] = []
|
||||
prompt_message_contents.append(TextPromptMessageContent(data=message.query))
|
||||
for file in file_objs:
|
||||
prompt_message_contents.append(
|
||||
file_manager.to_prompt_message_content(
|
||||
file,
|
||||
image_detail_config=image_detail_config,
|
||||
)
|
||||
)
|
||||
return UserPromptMessage(content=prompt_message_contents)
|
||||
|
|
|
@ -10,6 +10,7 @@ from core.model_runtime.entities import (
|
|||
TextPromptMessageContent,
|
||||
UserPromptMessage,
|
||||
)
|
||||
from core.model_runtime.entities.message_entities import ImagePromptMessageContent
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
|
||||
|
||||
|
@ -36,8 +37,24 @@ class CotChatAgentRunner(CotAgentRunner):
|
|||
if self.files:
|
||||
prompt_message_contents: list[PromptMessageContent] = []
|
||||
prompt_message_contents.append(TextPromptMessageContent(data=query))
|
||||
for file_obj in self.files:
|
||||
prompt_message_contents.append(file_manager.to_prompt_message_content(file_obj))
|
||||
|
||||
# get image detail config
|
||||
image_detail_config = (
|
||||
self.application_generate_entity.file_upload_config.image_config.detail
|
||||
if (
|
||||
self.application_generate_entity.file_upload_config
|
||||
and self.application_generate_entity.file_upload_config.image_config
|
||||
)
|
||||
else None
|
||||
)
|
||||
image_detail_config = image_detail_config or ImagePromptMessageContent.DETAIL.LOW
|
||||
for file in self.files:
|
||||
prompt_message_contents.append(
|
||||
file_manager.to_prompt_message_content(
|
||||
file,
|
||||
image_detail_config=image_detail_config,
|
||||
)
|
||||
)
|
||||
|
||||
prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
|
||||
else:
|
||||
|
|
|
@ -22,6 +22,7 @@ from core.model_runtime.entities import (
|
|||
ToolPromptMessage,
|
||||
UserPromptMessage,
|
||||
)
|
||||
from core.model_runtime.entities.message_entities import ImagePromptMessageContent
|
||||
from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform
|
||||
from core.tools.entities.tool_entities import ToolInvokeMeta
|
||||
from core.tools.tool_engine import ToolEngine
|
||||
|
@ -397,8 +398,24 @@ class FunctionCallAgentRunner(BaseAgentRunner):
|
|||
if self.files:
|
||||
prompt_message_contents: list[PromptMessageContent] = []
|
||||
prompt_message_contents.append(TextPromptMessageContent(data=query))
|
||||
for file_obj in self.files:
|
||||
prompt_message_contents.append(file_manager.to_prompt_message_content(file_obj))
|
||||
|
||||
# get image detail config
|
||||
image_detail_config = (
|
||||
self.application_generate_entity.file_upload_config.image_config.detail
|
||||
if (
|
||||
self.application_generate_entity.file_upload_config
|
||||
and self.application_generate_entity.file_upload_config.image_config
|
||||
)
|
||||
else None
|
||||
)
|
||||
image_detail_config = image_detail_config or ImagePromptMessageContent.DETAIL.LOW
|
||||
for file in self.files:
|
||||
prompt_message_contents.append(
|
||||
file_manager.to_prompt_message_content(
|
||||
file,
|
||||
image_detail_config=image_detail_config,
|
||||
)
|
||||
)
|
||||
|
||||
prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
|
||||
else:
|
||||
|
|
|
@ -4,7 +4,7 @@ from typing import Any, Optional
|
|||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from core.file import FileExtraConfig, FileTransferMethod, FileType
|
||||
from core.file import FileTransferMethod, FileType, FileUploadConfig
|
||||
from core.model_runtime.entities.message_entities import PromptMessageRole
|
||||
from models.model import AppMode
|
||||
|
||||
|
@ -211,7 +211,7 @@ class TracingConfigEntity(BaseModel):
|
|||
|
||||
|
||||
class AppAdditionalFeatures(BaseModel):
|
||||
file_upload: Optional[FileExtraConfig] = None
|
||||
file_upload: Optional[FileUploadConfig] = None
|
||||
opening_statement: Optional[str] = None
|
||||
suggested_questions: list[str] = []
|
||||
suggested_questions_after_answer: bool = False
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
|
||||
from core.file import FileExtraConfig
|
||||
from core.file import FileUploadConfig
|
||||
|
||||
|
||||
class FileUploadConfigManager:
|
||||
|
@ -29,15 +29,14 @@ class FileUploadConfigManager:
|
|||
if is_vision:
|
||||
data["image_config"]["detail"] = file_upload_dict.get("image", {}).get("detail", "low")
|
||||
|
||||
return FileExtraConfig.model_validate(data)
|
||||
return FileUploadConfig.model_validate(data)
|
||||
|
||||
@classmethod
|
||||
def validate_and_set_defaults(cls, config: dict, is_vision: bool = True) -> tuple[dict, list[str]]:
|
||||
def validate_and_set_defaults(cls, config: dict) -> tuple[dict, list[str]]:
|
||||
"""
|
||||
Validate and set defaults for file upload feature
|
||||
|
||||
:param config: app model config args
|
||||
:param is_vision: if True, the feature is vision feature
|
||||
"""
|
||||
if not config.get("file_upload"):
|
||||
config["file_upload"] = {}
|
||||
|
|
|
@ -52,9 +52,7 @@ class AdvancedChatAppConfigManager(BaseAppConfigManager):
|
|||
related_config_keys = []
|
||||
|
||||
# file upload validation
|
||||
config, current_related_config_keys = FileUploadConfigManager.validate_and_set_defaults(
|
||||
config=config, is_vision=False
|
||||
)
|
||||
config, current_related_config_keys = FileUploadConfigManager.validate_and_set_defaults(config=config)
|
||||
related_config_keys.extend(current_related_config_keys)
|
||||
|
||||
# opening_statement
|
||||
|
|
|
@ -26,7 +26,6 @@ from core.ops.ops_trace_manager import TraceQueueManager
|
|||
from extensions.ext_database import db
|
||||
from factories import file_factory
|
||||
from models.account import Account
|
||||
from models.enums import CreatedByRole
|
||||
from models.model import App, Conversation, EndUser, Message
|
||||
from models.workflow import Workflow
|
||||
|
||||
|
@ -98,13 +97,10 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
|||
# parse files
|
||||
files = args["files"] if args.get("files") else []
|
||||
file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False)
|
||||
role = CreatedByRole.ACCOUNT if isinstance(user, Account) else CreatedByRole.END_USER
|
||||
if file_extra_config:
|
||||
file_objs = file_factory.build_from_mappings(
|
||||
mappings=files,
|
||||
tenant_id=app_model.tenant_id,
|
||||
user_id=user.id,
|
||||
role=role,
|
||||
config=file_extra_config,
|
||||
)
|
||||
else:
|
||||
|
@ -127,10 +123,11 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
|||
application_generate_entity = AdvancedChatAppGenerateEntity(
|
||||
task_id=str(uuid.uuid4()),
|
||||
app_config=app_config,
|
||||
file_upload_config=file_extra_config,
|
||||
conversation_id=conversation.id if conversation else None,
|
||||
inputs=conversation.inputs
|
||||
if conversation
|
||||
else self._prepare_user_inputs(user_inputs=inputs, app_config=app_config, user_id=user.id, role=role),
|
||||
else self._prepare_user_inputs(user_inputs=inputs, app_config=app_config),
|
||||
query=query,
|
||||
files=file_objs,
|
||||
parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL,
|
||||
|
|
|
@ -23,7 +23,6 @@ from core.ops.ops_trace_manager import TraceQueueManager
|
|||
from extensions.ext_database import db
|
||||
from factories import file_factory
|
||||
from models import Account, App, EndUser
|
||||
from models.enums import CreatedByRole
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -103,8 +102,6 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
|
|||
# always enable retriever resource in debugger mode
|
||||
override_model_config_dict["retriever_resource"] = {"enabled": True}
|
||||
|
||||
role = CreatedByRole.ACCOUNT if isinstance(user, Account) else CreatedByRole.END_USER
|
||||
|
||||
# parse files
|
||||
files = args.get("files") or []
|
||||
file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict())
|
||||
|
@ -112,8 +109,6 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
|
|||
file_objs = file_factory.build_from_mappings(
|
||||
mappings=files,
|
||||
tenant_id=app_model.tenant_id,
|
||||
user_id=user.id,
|
||||
role=role,
|
||||
config=file_extra_config,
|
||||
)
|
||||
else:
|
||||
|
@ -135,10 +130,11 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
|
|||
task_id=str(uuid.uuid4()),
|
||||
app_config=app_config,
|
||||
model_conf=ModelConfigConverter.convert(app_config),
|
||||
file_upload_config=file_extra_config,
|
||||
conversation_id=conversation.id if conversation else None,
|
||||
inputs=conversation.inputs
|
||||
if conversation
|
||||
else self._prepare_user_inputs(user_inputs=inputs, app_config=app_config, user_id=user.id, role=role),
|
||||
else self._prepare_user_inputs(user_inputs=inputs, app_config=app_config),
|
||||
query=query,
|
||||
files=file_objs,
|
||||
parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL,
|
||||
|
|
|
@ -2,12 +2,11 @@ from collections.abc import Mapping
|
|||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
from core.app.app_config.entities import VariableEntityType
|
||||
from core.file import File, FileExtraConfig
|
||||
from core.file import File, FileUploadConfig
|
||||
from factories import file_factory
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.app.app_config.entities import AppConfig, VariableEntity
|
||||
from models.enums import CreatedByRole
|
||||
|
||||
|
||||
class BaseAppGenerator:
|
||||
|
@ -16,8 +15,6 @@ class BaseAppGenerator:
|
|||
*,
|
||||
user_inputs: Optional[Mapping[str, Any]],
|
||||
app_config: "AppConfig",
|
||||
user_id: str,
|
||||
role: "CreatedByRole",
|
||||
) -> Mapping[str, Any]:
|
||||
user_inputs = user_inputs or {}
|
||||
# Filter input variables from form configuration, handle required fields, default values, and option values
|
||||
|
@ -34,9 +31,7 @@ class BaseAppGenerator:
|
|||
k: file_factory.build_from_mapping(
|
||||
mapping=v,
|
||||
tenant_id=app_config.tenant_id,
|
||||
user_id=user_id,
|
||||
role=role,
|
||||
config=FileExtraConfig(
|
||||
config=FileUploadConfig(
|
||||
allowed_file_types=entity_dictionary[k].allowed_file_types,
|
||||
allowed_extensions=entity_dictionary[k].allowed_file_extensions,
|
||||
allowed_upload_methods=entity_dictionary[k].allowed_file_upload_methods,
|
||||
|
@ -50,9 +45,7 @@ class BaseAppGenerator:
|
|||
k: file_factory.build_from_mappings(
|
||||
mappings=v,
|
||||
tenant_id=app_config.tenant_id,
|
||||
user_id=user_id,
|
||||
role=role,
|
||||
config=FileExtraConfig(
|
||||
config=FileUploadConfig(
|
||||
allowed_file_types=entity_dictionary[k].allowed_file_types,
|
||||
allowed_extensions=entity_dictionary[k].allowed_file_extensions,
|
||||
allowed_upload_methods=entity_dictionary[k].allowed_file_upload_methods,
|
||||
|
|
|
@ -23,7 +23,6 @@ from core.ops.ops_trace_manager import TraceQueueManager
|
|||
from extensions.ext_database import db
|
||||
from factories import file_factory
|
||||
from models.account import Account
|
||||
from models.enums import CreatedByRole
|
||||
from models.model import App, EndUser
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
@ -101,8 +100,6 @@ class ChatAppGenerator(MessageBasedAppGenerator):
|
|||
# always enable retriever resource in debugger mode
|
||||
override_model_config_dict["retriever_resource"] = {"enabled": True}
|
||||
|
||||
role = CreatedByRole.ACCOUNT if isinstance(user, Account) else CreatedByRole.END_USER
|
||||
|
||||
# parse files
|
||||
files = args["files"] if args.get("files") else []
|
||||
file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict())
|
||||
|
@ -110,8 +107,6 @@ class ChatAppGenerator(MessageBasedAppGenerator):
|
|||
file_objs = file_factory.build_from_mappings(
|
||||
mappings=files,
|
||||
tenant_id=app_model.tenant_id,
|
||||
user_id=user.id,
|
||||
role=role,
|
||||
config=file_extra_config,
|
||||
)
|
||||
else:
|
||||
|
@ -133,10 +128,11 @@ class ChatAppGenerator(MessageBasedAppGenerator):
|
|||
task_id=str(uuid.uuid4()),
|
||||
app_config=app_config,
|
||||
model_conf=ModelConfigConverter.convert(app_config),
|
||||
file_upload_config=file_extra_config,
|
||||
conversation_id=conversation.id if conversation else None,
|
||||
inputs=conversation.inputs
|
||||
if conversation
|
||||
else self._prepare_user_inputs(user_inputs=inputs, app_config=app_config, user_id=user.id, role=role),
|
||||
else self._prepare_user_inputs(user_inputs=inputs, app_config=app_config),
|
||||
query=query,
|
||||
files=file_objs,
|
||||
parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL,
|
||||
|
|
|
@ -22,7 +22,6 @@ from core.ops.ops_trace_manager import TraceQueueManager
|
|||
from extensions.ext_database import db
|
||||
from factories import file_factory
|
||||
from models import Account, App, EndUser, Message
|
||||
from models.enums import CreatedByRole
|
||||
from services.errors.app import MoreLikeThisDisabledError
|
||||
from services.errors.message import MessageNotExistsError
|
||||
|
||||
|
@ -88,8 +87,6 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
|
|||
tenant_id=app_model.tenant_id, config=args.get("model_config")
|
||||
)
|
||||
|
||||
role = CreatedByRole.ACCOUNT if isinstance(user, Account) else CreatedByRole.END_USER
|
||||
|
||||
# parse files
|
||||
files = args["files"] if args.get("files") else []
|
||||
file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict())
|
||||
|
@ -97,8 +94,6 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
|
|||
file_objs = file_factory.build_from_mappings(
|
||||
mappings=files,
|
||||
tenant_id=app_model.tenant_id,
|
||||
user_id=user.id,
|
||||
role=role,
|
||||
config=file_extra_config,
|
||||
)
|
||||
else:
|
||||
|
@ -110,7 +105,6 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
|
|||
)
|
||||
|
||||
# get tracing instance
|
||||
user_id = user.id if isinstance(user, Account) else user.session_id
|
||||
trace_manager = TraceQueueManager(app_model.id)
|
||||
|
||||
# init application generate entity
|
||||
|
@ -118,7 +112,8 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
|
|||
task_id=str(uuid.uuid4()),
|
||||
app_config=app_config,
|
||||
model_conf=ModelConfigConverter.convert(app_config),
|
||||
inputs=self._prepare_user_inputs(user_inputs=inputs, app_config=app_config, user_id=user.id, role=role),
|
||||
file_upload_config=file_extra_config,
|
||||
inputs=self._prepare_user_inputs(user_inputs=inputs, app_config=app_config),
|
||||
query=query,
|
||||
files=file_objs,
|
||||
user_id=user.id,
|
||||
|
@ -259,14 +254,11 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
|
|||
override_model_config_dict["model"] = model_dict
|
||||
|
||||
# parse files
|
||||
role = CreatedByRole.ACCOUNT if isinstance(user, Account) else CreatedByRole.END_USER
|
||||
file_extra_config = FileUploadConfigManager.convert(override_model_config_dict)
|
||||
if file_extra_config:
|
||||
file_objs = file_factory.build_from_mappings(
|
||||
mappings=message.message_files,
|
||||
tenant_id=app_model.tenant_id,
|
||||
user_id=user.id,
|
||||
role=role,
|
||||
config=file_extra_config,
|
||||
)
|
||||
else:
|
||||
|
|
|
@ -46,9 +46,7 @@ class WorkflowAppConfigManager(BaseAppConfigManager):
|
|||
related_config_keys = []
|
||||
|
||||
# file upload validation
|
||||
config, current_related_config_keys = FileUploadConfigManager.validate_and_set_defaults(
|
||||
config=config, is_vision=False
|
||||
)
|
||||
config, current_related_config_keys = FileUploadConfigManager.validate_and_set_defaults(config=config)
|
||||
related_config_keys.extend(current_related_config_keys)
|
||||
|
||||
# text_to_speech
|
||||
|
|
|
@ -25,7 +25,6 @@ from core.ops.ops_trace_manager import TraceQueueManager
|
|||
from extensions.ext_database import db
|
||||
from factories import file_factory
|
||||
from models import Account, App, EndUser, Workflow
|
||||
from models.enums import CreatedByRole
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -70,15 +69,11 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
|||
):
|
||||
files: Sequence[Mapping[str, Any]] = args.get("files") or []
|
||||
|
||||
role = CreatedByRole.ACCOUNT if isinstance(user, Account) else CreatedByRole.END_USER
|
||||
|
||||
# parse files
|
||||
file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False)
|
||||
system_files = file_factory.build_from_mappings(
|
||||
mappings=files,
|
||||
tenant_id=app_model.tenant_id,
|
||||
user_id=user.id,
|
||||
role=role,
|
||||
config=file_extra_config,
|
||||
)
|
||||
|
||||
|
@ -100,7 +95,8 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
|||
application_generate_entity = WorkflowAppGenerateEntity(
|
||||
task_id=str(uuid.uuid4()),
|
||||
app_config=app_config,
|
||||
inputs=self._prepare_user_inputs(user_inputs=inputs, app_config=app_config, user_id=user.id, role=role),
|
||||
file_upload_config=file_extra_config,
|
||||
inputs=self._prepare_user_inputs(user_inputs=inputs, app_config=app_config),
|
||||
files=system_files,
|
||||
user_id=user.id,
|
||||
stream=stream,
|
||||
|
|
|
@ -7,7 +7,7 @@ from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_validat
|
|||
from constants import UUID_NIL
|
||||
from core.app.app_config.entities import AppConfig, EasyUIBasedAppConfig, WorkflowUIBasedAppConfig
|
||||
from core.entities.provider_configuration import ProviderModelBundle
|
||||
from core.file.models import File
|
||||
from core.file import File, FileUploadConfig
|
||||
from core.model_runtime.entities.model_entities import AIModelEntity
|
||||
from core.ops.ops_trace_manager import TraceQueueManager
|
||||
|
||||
|
@ -80,6 +80,7 @@ class AppGenerateEntity(BaseModel):
|
|||
|
||||
# app config
|
||||
app_config: AppConfig
|
||||
file_upload_config: Optional[FileUploadConfig] = None
|
||||
|
||||
inputs: Mapping[str, Any]
|
||||
files: Sequence[File]
|
||||
|
|
|
@ -2,13 +2,13 @@ from .constants import FILE_MODEL_IDENTITY
|
|||
from .enums import ArrayFileAttribute, FileAttribute, FileBelongsTo, FileTransferMethod, FileType
|
||||
from .models import (
|
||||
File,
|
||||
FileExtraConfig,
|
||||
FileUploadConfig,
|
||||
ImageConfig,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"FileType",
|
||||
"FileExtraConfig",
|
||||
"FileUploadConfig",
|
||||
"FileTransferMethod",
|
||||
"FileBelongsTo",
|
||||
"File",
|
||||
|
|
|
@ -33,25 +33,28 @@ def get_attr(*, file: File, attr: FileAttribute):
|
|||
raise ValueError(f"Invalid file attribute: {attr}")
|
||||
|
||||
|
||||
def to_prompt_message_content(f: File, /):
|
||||
def to_prompt_message_content(
|
||||
f: File,
|
||||
/,
|
||||
*,
|
||||
image_detail_config: ImagePromptMessageContent.DETAIL = ImagePromptMessageContent.DETAIL.LOW,
|
||||
):
|
||||
"""
|
||||
Convert a File object to an ImagePromptMessageContent object.
|
||||
Convert a File object to an ImagePromptMessageContent or AudioPromptMessageContent object.
|
||||
|
||||
This function takes a File object and converts it to an ImagePromptMessageContent
|
||||
object, which can be used as a prompt for image-based AI models.
|
||||
This function takes a File object and converts it to an appropriate PromptMessageContent
|
||||
object, which can be used as a prompt for image or audio-based AI models.
|
||||
|
||||
Args:
|
||||
file (File): The File object to convert. Must be of type FileType.IMAGE.
|
||||
f (File): The File object to convert.
|
||||
detail (Optional[ImagePromptMessageContent.DETAIL]): The detail level for image prompts.
|
||||
If not provided, defaults to ImagePromptMessageContent.DETAIL.LOW.
|
||||
|
||||
Returns:
|
||||
ImagePromptMessageContent: An object containing the image data and detail level.
|
||||
Union[ImagePromptMessageContent, AudioPromptMessageContent]: An object containing the file data and detail level
|
||||
|
||||
Raises:
|
||||
ValueError: If the file is not an image or if the file data is missing.
|
||||
|
||||
Note:
|
||||
The detail level of the image prompt is determined by the file's extra_config.
|
||||
If not specified, it defaults to ImagePromptMessageContent.DETAIL.LOW.
|
||||
ValueError: If the file type is not supported or if required data is missing.
|
||||
"""
|
||||
match f.type:
|
||||
case FileType.IMAGE:
|
||||
|
@ -60,12 +63,7 @@ def to_prompt_message_content(f: File, /):
|
|||
else:
|
||||
data = _to_base64_data_string(f)
|
||||
|
||||
if f._extra_config and f._extra_config.image_config and f._extra_config.image_config.detail:
|
||||
detail = f._extra_config.image_config.detail
|
||||
else:
|
||||
detail = ImagePromptMessageContent.DETAIL.LOW
|
||||
|
||||
return ImagePromptMessageContent(data=data, detail=detail)
|
||||
return ImagePromptMessageContent(data=data, detail=image_detail_config)
|
||||
case FileType.AUDIO:
|
||||
encoded_string = _file_to_encoded_string(f)
|
||||
if f.extension is None:
|
||||
|
@ -78,7 +76,7 @@ def to_prompt_message_content(f: File, /):
|
|||
data = _to_base64_data_string(f)
|
||||
return VideoPromptMessageContent(data=data, format=f.extension.lstrip("."))
|
||||
case _:
|
||||
raise ValueError(f"file type {f.type} is not supported")
|
||||
raise ValueError("file type f.type is not supported")
|
||||
|
||||
|
||||
def download(f: File, /):
|
||||
|
|
|
@ -21,7 +21,7 @@ class ImageConfig(BaseModel):
|
|||
detail: ImagePromptMessageContent.DETAIL | None = None
|
||||
|
||||
|
||||
class FileExtraConfig(BaseModel):
|
||||
class FileUploadConfig(BaseModel):
|
||||
"""
|
||||
File Upload Entity.
|
||||
"""
|
||||
|
@ -46,7 +46,6 @@ class File(BaseModel):
|
|||
extension: Optional[str] = Field(default=None, description="File extension, should contains dot")
|
||||
mime_type: Optional[str] = None
|
||||
size: int = -1
|
||||
_extra_config: FileExtraConfig | None = None
|
||||
|
||||
def to_dict(self) -> Mapping[str, str | int | None]:
|
||||
data = self.model_dump(mode="json")
|
||||
|
@ -107,34 +106,4 @@ class File(BaseModel):
|
|||
case FileTransferMethod.TOOL_FILE:
|
||||
if not self.related_id:
|
||||
raise ValueError("Missing file related_id")
|
||||
|
||||
# Validate the extra config.
|
||||
if not self._extra_config:
|
||||
return self
|
||||
|
||||
if self._extra_config.allowed_file_types:
|
||||
if self.type not in self._extra_config.allowed_file_types and self.type != FileType.CUSTOM:
|
||||
raise ValueError(f"Invalid file type: {self.type}")
|
||||
|
||||
if self._extra_config.allowed_extensions and self.extension not in self._extra_config.allowed_extensions:
|
||||
raise ValueError(f"Invalid file extension: {self.extension}")
|
||||
|
||||
if (
|
||||
self._extra_config.allowed_upload_methods
|
||||
and self.transfer_method not in self._extra_config.allowed_upload_methods
|
||||
):
|
||||
raise ValueError(f"Invalid transfer method: {self.transfer_method}")
|
||||
|
||||
match self.type:
|
||||
case FileType.IMAGE:
|
||||
# NOTE: This part of validation is deprecated, but still used in app features "Image Upload".
|
||||
if not self._extra_config.image_config:
|
||||
return self
|
||||
# TODO: skip check if transfer_methods is empty, because many test cases are not setting this field
|
||||
if (
|
||||
self._extra_config.image_config.transfer_methods
|
||||
and self.transfer_method not in self._extra_config.image_config.transfer_methods
|
||||
):
|
||||
raise ValueError(f"Invalid transfer method: {self.transfer_method}")
|
||||
|
||||
return self
|
||||
|
|
|
@ -81,15 +81,18 @@ class TokenBufferMemory:
|
|||
db.session.query(WorkflowRun).filter(WorkflowRun.id == message.workflow_run_id).first()
|
||||
)
|
||||
|
||||
if workflow_run:
|
||||
if workflow_run and workflow_run.workflow:
|
||||
file_extra_config = FileUploadConfigManager.convert(
|
||||
workflow_run.workflow.features_dict, is_vision=False
|
||||
)
|
||||
|
||||
detail = ImagePromptMessageContent.DETAIL.LOW
|
||||
if file_extra_config and app_record:
|
||||
file_objs = file_factory.build_from_message_files(
|
||||
message_files=files, tenant_id=app_record.tenant_id, config=file_extra_config
|
||||
)
|
||||
if file_extra_config.image_config and file_extra_config.image_config.detail:
|
||||
detail = file_extra_config.image_config.detail
|
||||
else:
|
||||
file_objs = []
|
||||
|
||||
|
@ -98,12 +101,16 @@ class TokenBufferMemory:
|
|||
else:
|
||||
prompt_message_contents: list[PromptMessageContent] = []
|
||||
prompt_message_contents.append(TextPromptMessageContent(data=message.query))
|
||||
for file_obj in file_objs:
|
||||
if file_obj.type in {FileType.IMAGE, FileType.AUDIO}:
|
||||
prompt_message = file_manager.to_prompt_message_content(file_obj)
|
||||
for file in file_objs:
|
||||
if file.type in {FileType.IMAGE, FileType.AUDIO}:
|
||||
prompt_message = file_manager.to_prompt_message_content(
|
||||
file,
|
||||
image_detail_config=detail,
|
||||
)
|
||||
prompt_message_contents.append(prompt_message)
|
||||
|
||||
prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
|
||||
|
||||
else:
|
||||
prompt_messages.append(UserPromptMessage(content=message.query))
|
||||
|
||||
|
|
|
@ -15,6 +15,7 @@ from core.model_runtime.entities import (
|
|||
TextPromptMessageContent,
|
||||
UserPromptMessage,
|
||||
)
|
||||
from core.model_runtime.entities.message_entities import ImagePromptMessageContent
|
||||
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig
|
||||
from core.prompt.prompt_transform import PromptTransform
|
||||
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
|
||||
|
@ -26,8 +27,13 @@ class AdvancedPromptTransform(PromptTransform):
|
|||
Advanced Prompt Transform for Workflow LLM Node.
|
||||
"""
|
||||
|
||||
def __init__(self, with_variable_tmpl: bool = False) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
with_variable_tmpl: bool = False,
|
||||
image_detail_config: ImagePromptMessageContent.DETAIL = ImagePromptMessageContent.DETAIL.LOW,
|
||||
) -> None:
|
||||
self.with_variable_tmpl = with_variable_tmpl
|
||||
self.image_detail_config = image_detail_config
|
||||
|
||||
def get_prompt(
|
||||
self,
|
||||
|
|
|
@ -1,19 +1,23 @@
|
|||
from typing import Any
|
||||
|
||||
from core.file import File
|
||||
from core.file.enums import FileTransferMethod, FileType
|
||||
from core.file import FileTransferMethod, FileType
|
||||
from core.tools.errors import ToolProviderCredentialValidationError
|
||||
from core.tools.provider.builtin.vectorizer.tools.vectorizer import VectorizerTool
|
||||
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
|
||||
from factories import file_factory
|
||||
|
||||
|
||||
class VectorizerProvider(BuiltinToolProviderController):
|
||||
def _validate_credentials(self, credentials: dict[str, Any]) -> None:
|
||||
test_img = File(
|
||||
mapping = {
|
||||
"transfer_method": FileTransferMethod.TOOL_FILE,
|
||||
"type": FileType.IMAGE,
|
||||
"id": "test_id",
|
||||
"url": "https://cloud.dify.ai/logo/logo-site.png",
|
||||
}
|
||||
test_img = file_factory.build_from_mapping(
|
||||
mapping=mapping,
|
||||
tenant_id="__test_123",
|
||||
remote_url="https://cloud.dify.ai/logo/logo-site.png",
|
||||
type=FileType.IMAGE,
|
||||
transfer_method=FileTransferMethod.REMOTE_URL,
|
||||
)
|
||||
try:
|
||||
VectorizerTool().fork_tool_runtime(
|
||||
|
|
|
@ -13,6 +13,7 @@ from core.workflow.nodes.base import BaseNode
|
|||
from core.workflow.nodes.enums import NodeType
|
||||
from core.workflow.nodes.http_request.executor import Executor
|
||||
from core.workflow.utils import variable_template_parser
|
||||
from factories import file_factory
|
||||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
|
||||
from .entities import (
|
||||
|
@ -161,16 +162,15 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]):
|
|||
mimetype=content_type,
|
||||
)
|
||||
|
||||
files.append(
|
||||
File(
|
||||
tenant_id=self.tenant_id,
|
||||
type=FileType.IMAGE,
|
||||
transfer_method=FileTransferMethod.TOOL_FILE,
|
||||
related_id=tool_file.id,
|
||||
filename=filename,
|
||||
extension=extension,
|
||||
mime_type=content_type,
|
||||
)
|
||||
mapping = {
|
||||
"tool_file_id": tool_file.id,
|
||||
"type": FileType.IMAGE.value,
|
||||
"transfer_method": FileTransferMethod.TOOL_FILE.value,
|
||||
}
|
||||
file = file_factory.build_from_mapping(
|
||||
mapping=mapping,
|
||||
tenant_id=self.tenant_id,
|
||||
)
|
||||
files.append(file)
|
||||
|
||||
return files
|
||||
|
|
|
@ -17,6 +17,7 @@ from core.workflow.nodes.base import BaseNode
|
|||
from core.workflow.nodes.enums import NodeType
|
||||
from core.workflow.utils.variable_template_parser import VariableTemplateParser
|
||||
from extensions.ext_database import db
|
||||
from factories import file_factory
|
||||
from models import ToolFile
|
||||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
|
||||
|
@ -189,19 +190,17 @@ class ToolNode(BaseNode[ToolNodeData]):
|
|||
if tool_file is None:
|
||||
raise ToolFileError(f"Tool file {tool_file_id} does not exist")
|
||||
|
||||
result.append(
|
||||
File(
|
||||
tenant_id=self.tenant_id,
|
||||
type=FileType.IMAGE,
|
||||
transfer_method=transfer_method,
|
||||
remote_url=url,
|
||||
related_id=tool_file.id,
|
||||
filename=tool_file.name,
|
||||
extension=ext,
|
||||
mime_type=tool_file.mimetype,
|
||||
size=tool_file.size,
|
||||
)
|
||||
mapping = {
|
||||
"tool_file_id": tool_file_id,
|
||||
"type": FileType.IMAGE,
|
||||
"transfer_method": transfer_method,
|
||||
"url": url,
|
||||
}
|
||||
file = file_factory.build_from_mapping(
|
||||
mapping=mapping,
|
||||
tenant_id=self.tenant_id,
|
||||
)
|
||||
result.append(file)
|
||||
elif response.type == ToolInvokeMessage.MessageType.BLOB:
|
||||
# get tool file id
|
||||
tool_file_id = str(response.message).split("/")[-1].split(".")[0]
|
||||
|
@ -209,19 +208,17 @@ class ToolNode(BaseNode[ToolNodeData]):
|
|||
stmt = select(ToolFile).where(ToolFile.id == tool_file_id)
|
||||
tool_file = session.scalar(stmt)
|
||||
if tool_file is None:
|
||||
raise ToolFileError(f"Tool file {tool_file_id} does not exist")
|
||||
result.append(
|
||||
File(
|
||||
tenant_id=self.tenant_id,
|
||||
type=FileType.IMAGE,
|
||||
transfer_method=FileTransferMethod.TOOL_FILE,
|
||||
related_id=tool_file.id,
|
||||
filename=tool_file.name,
|
||||
extension=path.splitext(response.save_as)[1],
|
||||
mime_type=tool_file.mimetype,
|
||||
size=tool_file.size,
|
||||
)
|
||||
raise ValueError(f"tool file {tool_file_id} not exists")
|
||||
mapping = {
|
||||
"tool_file_id": tool_file_id,
|
||||
"type": FileType.IMAGE,
|
||||
"transfer_method": FileTransferMethod.TOOL_FILE,
|
||||
}
|
||||
file = file_factory.build_from_mapping(
|
||||
mapping=mapping,
|
||||
tenant_id=self.tenant_id,
|
||||
)
|
||||
result.append(file)
|
||||
elif response.type == ToolInvokeMessage.MessageType.LINK:
|
||||
url = str(response.message)
|
||||
transfer_method = FileTransferMethod.TOOL_FILE
|
||||
|
@ -235,16 +232,15 @@ class ToolNode(BaseNode[ToolNodeData]):
|
|||
extension = "." + url.split("/")[-1].split(".")[1]
|
||||
else:
|
||||
extension = ".bin"
|
||||
file = File(
|
||||
mapping = {
|
||||
"tool_file_id": tool_file_id,
|
||||
"type": FileType.IMAGE,
|
||||
"transfer_method": transfer_method,
|
||||
"url": url,
|
||||
}
|
||||
file = file_factory.build_from_mapping(
|
||||
mapping=mapping,
|
||||
tenant_id=self.tenant_id,
|
||||
type=FileType(response.save_as),
|
||||
transfer_method=transfer_method,
|
||||
remote_url=url,
|
||||
filename=tool_file.name,
|
||||
related_id=tool_file.id,
|
||||
extension=extension,
|
||||
mime_type=tool_file.mimetype,
|
||||
size=tool_file.size,
|
||||
)
|
||||
result.append(file)
|
||||
|
||||
|
|
|
@ -5,10 +5,10 @@ from collections.abc import Generator, Mapping, Sequence
|
|||
from typing import Any, Optional, cast
|
||||
|
||||
from configs import dify_config
|
||||
from core.app.app_config.entities import FileExtraConfig
|
||||
from core.app.app_config.entities import FileUploadConfig
|
||||
from core.app.apps.base_app_queue_manager import GenerateTaskStoppedError
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.file.models import File, FileTransferMethod, FileType, ImageConfig
|
||||
from core.file.models import File, FileTransferMethod, ImageConfig
|
||||
from core.workflow.callbacks import WorkflowCallback
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.errors import WorkflowNodeRunFailedError
|
||||
|
@ -22,6 +22,7 @@ from core.workflow.nodes.base import BaseNode, BaseNodeData
|
|||
from core.workflow.nodes.event import NodeEvent
|
||||
from core.workflow.nodes.llm import LLMNodeData
|
||||
from core.workflow.nodes.node_mapping import node_type_classes_mapping
|
||||
from factories import file_factory
|
||||
from models.enums import UserFrom
|
||||
from models.workflow import (
|
||||
Workflow,
|
||||
|
@ -271,19 +272,17 @@ class WorkflowEntry:
|
|||
for item in input_value:
|
||||
if isinstance(item, dict) and "type" in item and item["type"] == "image":
|
||||
transfer_method = FileTransferMethod.value_of(item.get("transfer_method"))
|
||||
file = File(
|
||||
mapping = {
|
||||
"id": item.get("id"),
|
||||
"transfer_method": transfer_method,
|
||||
"upload_file_id": item.get("upload_file_id"),
|
||||
"url": item.get("url"),
|
||||
}
|
||||
config = FileUploadConfig(image_config=ImageConfig(detail=detail) if detail else None)
|
||||
file = file_factory.build_from_mapping(
|
||||
mapping=mapping,
|
||||
tenant_id=tenant_id,
|
||||
type=FileType.IMAGE,
|
||||
transfer_method=transfer_method,
|
||||
remote_url=item.get("url")
|
||||
if transfer_method == FileTransferMethod.REMOTE_URL
|
||||
else None,
|
||||
related_id=item.get("upload_file_id")
|
||||
if transfer_method == FileTransferMethod.LOCAL_FILE
|
||||
else None,
|
||||
_extra_config=FileExtraConfig(
|
||||
image_config=ImageConfig(detail=detail) if detail else None
|
||||
),
|
||||
config=config,
|
||||
)
|
||||
new_value.append(file)
|
||||
|
||||
|
|
|
@ -1,23 +1,21 @@
|
|||
import mimetypes
|
||||
from collections.abc import Mapping, Sequence
|
||||
from collections.abc import Callable, Mapping, Sequence
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
from sqlalchemy import select
|
||||
|
||||
from constants import AUDIO_EXTENSIONS, DOCUMENT_EXTENSIONS, IMAGE_EXTENSIONS, VIDEO_EXTENSIONS
|
||||
from core.file import File, FileBelongsTo, FileExtraConfig, FileTransferMethod, FileType
|
||||
from core.file import File, FileBelongsTo, FileTransferMethod, FileType, FileUploadConfig
|
||||
from core.helper import ssrf_proxy
|
||||
from extensions.ext_database import db
|
||||
from models import MessageFile, ToolFile, UploadFile
|
||||
from models.enums import CreatedByRole
|
||||
|
||||
|
||||
def build_from_message_files(
|
||||
*,
|
||||
message_files: Sequence["MessageFile"],
|
||||
tenant_id: str,
|
||||
config: FileExtraConfig,
|
||||
config: FileUploadConfig,
|
||||
) -> Sequence[File]:
|
||||
results = [
|
||||
build_from_message_file(message_file=file, tenant_id=tenant_id, config=config)
|
||||
|
@ -31,7 +29,7 @@ def build_from_message_file(
|
|||
*,
|
||||
message_file: "MessageFile",
|
||||
tenant_id: str,
|
||||
config: FileExtraConfig,
|
||||
config: FileUploadConfig,
|
||||
):
|
||||
mapping = {
|
||||
"transfer_method": message_file.transfer_method,
|
||||
|
@ -43,8 +41,6 @@ def build_from_message_file(
|
|||
return build_from_mapping(
|
||||
mapping=mapping,
|
||||
tenant_id=tenant_id,
|
||||
user_id=message_file.created_by,
|
||||
role=CreatedByRole(message_file.created_by_role),
|
||||
config=config,
|
||||
)
|
||||
|
||||
|
@ -53,38 +49,30 @@ def build_from_mapping(
|
|||
*,
|
||||
mapping: Mapping[str, Any],
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
role: "CreatedByRole",
|
||||
config: FileExtraConfig,
|
||||
):
|
||||
config: FileUploadConfig | None = None,
|
||||
) -> File:
|
||||
config = config or FileUploadConfig()
|
||||
|
||||
transfer_method = FileTransferMethod.value_of(mapping.get("transfer_method"))
|
||||
match transfer_method:
|
||||
case FileTransferMethod.REMOTE_URL:
|
||||
file = _build_from_remote_url(
|
||||
mapping=mapping,
|
||||
tenant_id=tenant_id,
|
||||
config=config,
|
||||
transfer_method=transfer_method,
|
||||
)
|
||||
case FileTransferMethod.LOCAL_FILE:
|
||||
file = _build_from_local_file(
|
||||
mapping=mapping,
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
role=role,
|
||||
config=config,
|
||||
transfer_method=transfer_method,
|
||||
)
|
||||
case FileTransferMethod.TOOL_FILE:
|
||||
file = _build_from_tool_file(
|
||||
mapping=mapping,
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
config=config,
|
||||
transfer_method=transfer_method,
|
||||
)
|
||||
case _:
|
||||
raise ValueError(f"Invalid file transfer method: {transfer_method}")
|
||||
|
||||
build_functions: dict[FileTransferMethod, Callable] = {
|
||||
FileTransferMethod.LOCAL_FILE: _build_from_local_file,
|
||||
FileTransferMethod.REMOTE_URL: _build_from_remote_url,
|
||||
FileTransferMethod.TOOL_FILE: _build_from_tool_file,
|
||||
}
|
||||
|
||||
build_func = build_functions.get(transfer_method)
|
||||
if not build_func:
|
||||
raise ValueError(f"Invalid file transfer method: {transfer_method}")
|
||||
|
||||
file = build_func(
|
||||
mapping=mapping,
|
||||
tenant_id=tenant_id,
|
||||
transfer_method=transfer_method,
|
||||
)
|
||||
|
||||
if not _is_file_valid_with_config(file=file, config=config):
|
||||
raise ValueError(f"File validation failed for file: {file.filename}")
|
||||
|
||||
return file
|
||||
|
||||
|
@ -92,10 +80,8 @@ def build_from_mapping(
|
|||
def build_from_mappings(
|
||||
*,
|
||||
mappings: Sequence[Mapping[str, Any]],
|
||||
config: FileExtraConfig | None,
|
||||
config: FileUploadConfig | None,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
role: "CreatedByRole",
|
||||
) -> Sequence[File]:
|
||||
if not config:
|
||||
return []
|
||||
|
@ -104,8 +90,6 @@ def build_from_mappings(
|
|||
build_from_mapping(
|
||||
mapping=mapping,
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
role=role,
|
||||
config=config,
|
||||
)
|
||||
for mapping in mappings
|
||||
|
@ -128,31 +112,20 @@ def _build_from_local_file(
|
|||
*,
|
||||
mapping: Mapping[str, Any],
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
role: "CreatedByRole",
|
||||
config: FileExtraConfig,
|
||||
transfer_method: FileTransferMethod,
|
||||
):
|
||||
# check if the upload file exists.
|
||||
) -> File:
|
||||
file_type = FileType.value_of(mapping.get("type"))
|
||||
stmt = select(UploadFile).where(
|
||||
UploadFile.id == mapping.get("upload_file_id"),
|
||||
UploadFile.tenant_id == tenant_id,
|
||||
UploadFile.created_by == user_id,
|
||||
UploadFile.created_by_role == role,
|
||||
)
|
||||
if file_type == FileType.IMAGE:
|
||||
stmt = stmt.where(UploadFile.extension.in_(IMAGE_EXTENSIONS))
|
||||
elif file_type == FileType.VIDEO:
|
||||
stmt = stmt.where(UploadFile.extension.in_(VIDEO_EXTENSIONS))
|
||||
elif file_type == FileType.AUDIO:
|
||||
stmt = stmt.where(UploadFile.extension.in_(AUDIO_EXTENSIONS))
|
||||
elif file_type == FileType.DOCUMENT:
|
||||
stmt = stmt.where(UploadFile.extension.in_(DOCUMENT_EXTENSIONS))
|
||||
|
||||
row = db.session.scalar(stmt)
|
||||
|
||||
if row is None:
|
||||
raise ValueError("Invalid upload file")
|
||||
file = File(
|
||||
|
||||
return File(
|
||||
id=mapping.get("id"),
|
||||
filename=row.name,
|
||||
extension="." + row.extension,
|
||||
|
@ -162,23 +135,37 @@ def _build_from_local_file(
|
|||
transfer_method=transfer_method,
|
||||
remote_url=row.source_url,
|
||||
related_id=mapping.get("upload_file_id"),
|
||||
_extra_config=config,
|
||||
size=row.size,
|
||||
)
|
||||
return file
|
||||
|
||||
|
||||
def _build_from_remote_url(
|
||||
*,
|
||||
mapping: Mapping[str, Any],
|
||||
tenant_id: str,
|
||||
config: FileExtraConfig,
|
||||
transfer_method: FileTransferMethod,
|
||||
):
|
||||
) -> File:
|
||||
url = mapping.get("url")
|
||||
if not url:
|
||||
raise ValueError("Invalid file url")
|
||||
|
||||
mime_type, filename, file_size = _get_remote_file_info(url)
|
||||
extension = mimetypes.guess_extension(mime_type) or "." + filename.split(".")[-1] if "." in filename else ".bin"
|
||||
|
||||
return File(
|
||||
id=mapping.get("id"),
|
||||
filename=filename,
|
||||
tenant_id=tenant_id,
|
||||
type=FileType.value_of(mapping.get("type")),
|
||||
transfer_method=transfer_method,
|
||||
remote_url=url,
|
||||
mime_type=mime_type,
|
||||
extension=extension,
|
||||
size=file_size,
|
||||
)
|
||||
|
||||
|
||||
def _get_remote_file_info(url: str):
|
||||
mime_type = mimetypes.guess_type(url)[0] or ""
|
||||
file_size = -1
|
||||
filename = url.split("/")[-1].split("?")[0] or "unknown_file"
|
||||
|
@ -186,56 +173,34 @@ def _build_from_remote_url(
|
|||
resp = ssrf_proxy.head(url, follow_redirects=True)
|
||||
if resp.status_code == httpx.codes.OK:
|
||||
if content_disposition := resp.headers.get("Content-Disposition"):
|
||||
filename = content_disposition.split("filename=")[-1].strip('"')
|
||||
filename = str(content_disposition.split("filename=")[-1].strip('"'))
|
||||
file_size = int(resp.headers.get("Content-Length", file_size))
|
||||
mime_type = mime_type or str(resp.headers.get("Content-Type", ""))
|
||||
|
||||
# Determine file extension
|
||||
extension = mimetypes.guess_extension(mime_type) or "." + filename.split(".")[-1] if "." in filename else ".bin"
|
||||
|
||||
if not mime_type:
|
||||
mime_type, _ = mimetypes.guess_type(url)
|
||||
file = File(
|
||||
id=mapping.get("id"),
|
||||
filename=filename,
|
||||
tenant_id=tenant_id,
|
||||
type=FileType.value_of(mapping.get("type")),
|
||||
transfer_method=transfer_method,
|
||||
remote_url=url,
|
||||
_extra_config=config,
|
||||
mime_type=mime_type,
|
||||
extension=extension,
|
||||
size=file_size,
|
||||
)
|
||||
return file
|
||||
return mime_type, filename, file_size
|
||||
|
||||
|
||||
def _build_from_tool_file(
|
||||
*,
|
||||
mapping: Mapping[str, Any],
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
config: FileExtraConfig,
|
||||
transfer_method: FileTransferMethod,
|
||||
):
|
||||
) -> File:
|
||||
tool_file = (
|
||||
db.session.query(ToolFile)
|
||||
.filter(
|
||||
ToolFile.id == mapping.get("tool_file_id"),
|
||||
ToolFile.tenant_id == tenant_id,
|
||||
ToolFile.user_id == user_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if tool_file is None:
|
||||
raise ValueError(f"ToolFile {mapping.get('tool_file_id')} not found")
|
||||
|
||||
path = tool_file.file_key
|
||||
if "." in path:
|
||||
extension = "." + path.split("/")[-1].split(".")[-1]
|
||||
else:
|
||||
extension = ".bin"
|
||||
file = File(
|
||||
extension = "." + tool_file.file_key.split(".")[-1] if "." in tool_file.file_key else ".bin"
|
||||
|
||||
return File(
|
||||
id=mapping.get("id"),
|
||||
tenant_id=tenant_id,
|
||||
filename=tool_file.name,
|
||||
|
@ -246,6 +211,21 @@ def _build_from_tool_file(
|
|||
extension=extension,
|
||||
mime_type=tool_file.mimetype,
|
||||
size=tool_file.size,
|
||||
_extra_config=config,
|
||||
)
|
||||
return file
|
||||
|
||||
|
||||
def _is_file_valid_with_config(*, file: File, config: FileUploadConfig) -> bool:
|
||||
if config.allowed_file_types and file.type not in config.allowed_file_types and file.type != FileType.CUSTOM:
|
||||
return False
|
||||
|
||||
if config.allowed_extensions and file.extension not in config.allowed_extensions:
|
||||
return False
|
||||
|
||||
if config.allowed_upload_methods and file.transfer_method not in config.allowed_upload_methods:
|
||||
return False
|
||||
|
||||
if file.type == FileType.IMAGE and config.image_config:
|
||||
if config.image_config.transfer_methods and file.transfer_method not in config.image_config.transfer_methods:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
|
|
@ -13,7 +13,7 @@ from sqlalchemy import Float, func, text
|
|||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from configs import dify_config
|
||||
from core.file import FILE_MODEL_IDENTITY, File, FileExtraConfig, FileTransferMethod, FileType
|
||||
from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod, FileType
|
||||
from core.file import helpers as file_helpers
|
||||
from core.file.tool_file_parser import ToolFileParser
|
||||
from extensions.ext_database import db
|
||||
|
@ -949,9 +949,6 @@ class Message(db.Model):
|
|||
"type": message_file.type,
|
||||
},
|
||||
tenant_id=current_app.tenant_id,
|
||||
user_id=self.from_account_id or self.from_end_user_id or "",
|
||||
role=CreatedByRole(message_file.created_by_role),
|
||||
config=FileExtraConfig(),
|
||||
)
|
||||
elif message_file.transfer_method == "remote_url":
|
||||
if message_file.url is None:
|
||||
|
@ -964,9 +961,6 @@ class Message(db.Model):
|
|||
"url": message_file.url,
|
||||
},
|
||||
tenant_id=current_app.tenant_id,
|
||||
user_id=self.from_account_id or self.from_end_user_id or "",
|
||||
role=CreatedByRole(message_file.created_by_role),
|
||||
config=FileExtraConfig(),
|
||||
)
|
||||
elif message_file.transfer_method == "tool_file":
|
||||
if message_file.upload_file_id is None:
|
||||
|
@ -981,9 +975,6 @@ class Message(db.Model):
|
|||
file = file_factory.build_from_mapping(
|
||||
mapping=mapping,
|
||||
tenant_id=current_app.tenant_id,
|
||||
user_id=self.from_account_id or self.from_end_user_id or "",
|
||||
role=CreatedByRole(message_file.created_by_role),
|
||||
config=FileExtraConfig(),
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
|
|
|
@ -13,7 +13,7 @@ from core.app.app_config.entities import (
|
|||
from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfigManager
|
||||
from core.app.apps.chat.app_config_manager import ChatAppConfigManager
|
||||
from core.app.apps.completion.app_config_manager import CompletionAppConfigManager
|
||||
from core.file.models import FileExtraConfig
|
||||
from core.file.models import FileUploadConfig
|
||||
from core.helper import encrypter
|
||||
from core.model_runtime.entities.llm_entities import LLMMode
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
|
@ -381,7 +381,7 @@ class WorkflowConverter:
|
|||
graph: dict,
|
||||
model_config: ModelConfigEntity,
|
||||
prompt_template: PromptTemplateEntity,
|
||||
file_upload: Optional[FileExtraConfig] = None,
|
||||
file_upload: Optional[FileUploadConfig] = None,
|
||||
external_data_variable_node_mapping: dict[str, str] | None = None,
|
||||
) -> dict:
|
||||
"""
|
||||
|
|
|
@ -430,37 +430,3 @@ def test_multi_colons_parse(setup_http_mock):
|
|||
assert urlencode({"Redirect": "http://example2.com"}) in result.process_data.get("request", "")
|
||||
assert 'form-data; name="Redirect"\r\n\r\nhttp://example6.com' in result.process_data.get("request", "")
|
||||
# assert "http://example3.com" == resp.get("headers", {}).get("referer")
|
||||
|
||||
|
||||
def test_image_file(monkeypatch):
|
||||
from types import SimpleNamespace
|
||||
|
||||
monkeypatch.setattr(
|
||||
"core.tools.tool_file_manager.ToolFileManager.create_file_by_raw",
|
||||
lambda *args, **kwargs: SimpleNamespace(id="1"),
|
||||
)
|
||||
|
||||
node = init_http_node(
|
||||
config={
|
||||
"id": "1",
|
||||
"data": {
|
||||
"title": "http",
|
||||
"desc": "",
|
||||
"method": "get",
|
||||
"url": "https://cloud.dify.ai/logo/logo-site.png",
|
||||
"authorization": {
|
||||
"type": "no-auth",
|
||||
"config": None,
|
||||
},
|
||||
"params": "",
|
||||
"headers": "",
|
||||
"body": None,
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
result = node._run()
|
||||
assert result.process_data is not None
|
||||
assert result.outputs is not None
|
||||
resp = result.outputs
|
||||
assert len(resp.get("files", [])) == 1
|
||||
|
|
|
@ -3,7 +3,7 @@ from unittest.mock import MagicMock, patch
|
|||
import pytest
|
||||
|
||||
from core.app.app_config.entities import ModelConfigEntity
|
||||
from core.file import File, FileExtraConfig, FileTransferMethod, FileType, ImageConfig
|
||||
from core.file import File, FileTransferMethod, FileType, FileUploadConfig, ImageConfig
|
||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||
from core.model_runtime.entities.message_entities import (
|
||||
AssistantPromptMessage,
|
||||
|
@ -134,7 +134,6 @@ def test__get_chat_model_prompt_messages_with_files_no_memory(get_chat_model_arg
|
|||
type=FileType.IMAGE,
|
||||
transfer_method=FileTransferMethod.REMOTE_URL,
|
||||
remote_url="https://example.com/image1.jpg",
|
||||
_extra_config=FileExtraConfig(image_config=ImageConfig(detail=ImagePromptMessageContent.DETAIL.HIGH)),
|
||||
)
|
||||
]
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user