Mirror of https://github.com/langgenius/dify.git (synced 2024-11-16 11:42:29 +08:00)
refactor(api): update file preview handling and support for audio files
- Deprecated the image-preview functionality in favor of file-preview for more uniform file handling.
- Introduced support for handling audio files alongside images.
- Improved methods to handle different file types and transfer methods in a modular way.
- Simplified code by removing unnecessary checks for file type and related IDs.
- Updated the OpenAI package to version 1.52.0 for improved compatibility.
- Added the `jiter` dependency for JSON parsing.
This commit is contained in: parent af888c1b57, commit c03407616a.
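For orientation, a minimal caller-side sketch of what the refactor enables: a single File object can now be turned into prompt message content for images and audio alike via file_manager.to_prompt_message_content. The helper name build_prompt_contents below is made up for illustration; it is not part of the diff.

# Hypothetical illustration of the new call shape; not part of the diff.
from core.file import file_manager


def build_prompt_contents(files):
    contents = []
    for f in files:
        # Images yield ImagePromptMessageContent, audio yields
        # AudioPromptMessageContent; unsupported types raise ValueError.
        contents.append(file_manager.to_prompt_message_content(f))
    return contents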
@@ -10,6 +10,10 @@ from services.file_service import FileService
 class ImagePreviewApi(Resource):
+    """
+    Deprecated
+    """
+
     def get(self, file_id):
         file_id = str(file_id)
@@ -1,10 +1,11 @@
 import base64

 from configs import dify_config
-from core.model_runtime.entities.message_entities import ImagePromptMessageContent
+from core.file import file_repository
+from core.helper import ssrf_proxy
+from core.model_runtime.entities import AudioPromptMessageContent, ImagePromptMessageContent
 from extensions.ext_database import db
 from extensions.ext_storage import storage
-from models import UploadFile

 from . import helpers
 from .enums import FileAttribute
@@ -12,7 +13,7 @@ from .models import File, FileTransferMethod, FileType
 from .tool_file_parser import ToolFileParser


-def get_attr(*, file: "File", attr: "FileAttribute"):
+def get_attr(*, file: File, attr: FileAttribute):
     match attr:
         case FileAttribute.TYPE:
             return file.type.value
@@ -32,7 +33,7 @@ def get_attr(*, file: "File", attr: "FileAttribute"):
             raise ValueError(f"Invalid file attribute: {attr}")


-def to_prompt_message_content(file: "File", /):
+def to_prompt_message_content(f: File, /):
     """
     Convert a File object to an ImagePromptMessageContent object.
@@ -52,34 +53,34 @@ def to_prompt_message_content(file: "File", /):
     The detail level of the image prompt is determined by the file's extra_config.
     If not specified, it defaults to ImagePromptMessageContent.DETAIL.LOW.
     """
-    if file.type != FileType.IMAGE:
-        raise ValueError("Only image file can convert to prompt message content")
-
-    url_or_b64_data = _get_url_or_b64_data(file=file)
-    if url_or_b64_data is None:
-        raise ValueError("Missing file data")
-
-    # decide the detail of image prompt message content
-    if file._extra_config and file._extra_config.image_config and file._extra_config.image_config.detail:
-        detail = file._extra_config.image_config.detail
-    else:
-        detail = ImagePromptMessageContent.DETAIL.LOW
-
-    return ImagePromptMessageContent(data=url_or_b64_data, detail=detail)
+    match f.type:
+        case FileType.IMAGE:
+            if dify_config.MULTIMODAL_SEND_IMAGE_FORMAT == "url":
+                data = _to_url(f)
+            else:
+                data = _to_base64_data_string(f)
+
+            if f._extra_config and f._extra_config.image_config and f._extra_config.image_config.detail:
+                detail = f._extra_config.image_config.detail
+            else:
+                detail = ImagePromptMessageContent.DETAIL.LOW
+
+            return ImagePromptMessageContent(data=data, detail=detail)
+        case FileType.AUDIO:
+            encoded_string = _file_to_encoded_string(f)
+            if f.extension is None:
+                raise ValueError("Missing file extension")
+            return AudioPromptMessageContent(data=encoded_string, format=f.extension.lstrip("."))
+        case _:
+            raise ValueError(f"file type {f.type} is not supported")


-def download(*, upload_file_id: str, tenant_id: str):
-    upload_file = (
-        db.session.query(UploadFile).filter(UploadFile.id == upload_file_id, UploadFile.tenant_id == tenant_id).first()
-    )
-
-    if not upload_file:
-        raise ValueError("upload file not found")
-
-    return _download(upload_file.key)
+def download(f: File, /):
+    upload_file = file_repository.get_upload_file(session=db.session(), file=f)
+    return _download_file_content(upload_file.key)


-def _download(path: str, /):
+def _download_file_content(path: str, /):
     """
     Download and return the contents of a file as bytes.
@@ -100,37 +101,56 @@ def _download(path: str, /):
     return data


-def _get_base64(*, upload_file_id: str, tenant_id: str) -> str | None:
-    upload_file = (
-        db.session.query(UploadFile).filter(UploadFile.id == upload_file_id, UploadFile.tenant_id == tenant_id).first()
-    )
-
-    if not upload_file:
-        return None
-
-    data = _download(upload_file.key)
-    if data is None:
-        return None
-
-    encoded_string = base64.b64encode(data).decode("utf-8")
-    return f"data:{upload_file.mime_type};base64,{encoded_string}"
+def _get_encoded_string(f: File, /):
+    match f.transfer_method:
+        case FileTransferMethod.REMOTE_URL:
+            response = ssrf_proxy.get(f.remote_url)
+            response.raise_for_status()
+            content = response.content
+            encoded_string = base64.b64encode(content).decode("utf-8")
+            return encoded_string
+        case FileTransferMethod.LOCAL_FILE:
+            upload_file = file_repository.get_upload_file(session=db.session(), file=f)
+            data = _download_file_content(upload_file.key)
+            encoded_string = base64.b64encode(data).decode("utf-8")
+            return encoded_string
+        case FileTransferMethod.TOOL_FILE:
+            tool_file = file_repository.get_tool_file(session=db.session(), file=f)
+            data = _download_file_content(tool_file.file_key)
+            encoded_string = base64.b64encode(data).decode("utf-8")
+            return encoded_string
+        case _:
+            raise ValueError(f"Unsupported transfer method: {f.transfer_method}")


-def _get_url_or_b64_data(file: "File"):
-    if file.type == FileType.IMAGE:
-        if file.transfer_method == FileTransferMethod.REMOTE_URL:
-            return file.remote_url
-        elif file.transfer_method == FileTransferMethod.LOCAL_FILE:
-            if file.related_id is None:
-                raise ValueError("Missing file related_id")
-
-            if dify_config.MULTIMODAL_SEND_IMAGE_FORMAT == "url":
-                return helpers.get_signed_image_url(upload_file_id=file.related_id)
-            return _get_base64(upload_file_id=file.related_id, tenant_id=file.tenant_id)
-        elif file.transfer_method == FileTransferMethod.TOOL_FILE:
-            # add sign url
-            if file.related_id is None or file.extension is None:
-                raise ValueError("Missing file related_id or extension")
-            return ToolFileParser.get_tool_file_manager().sign_file(
-                tool_file_id=file.related_id, extension=file.extension
-            )
+def _to_base64_data_string(f: File, /):
+    encoded_string = _get_encoded_string(f)
+    return f"data:{f.mime_type};base64,{encoded_string}"
+
+
+def _file_to_encoded_string(f: File, /):
+    match f.type:
+        case FileType.IMAGE:
+            return _to_base64_data_string(f)
+        case FileType.AUDIO:
+            return _get_encoded_string(f)
+        case _:
+            raise ValueError(f"file type {f.type} is not supported")
+
+
+def _to_url(f: File, /):
+    if f.transfer_method == FileTransferMethod.REMOTE_URL:
+        if f.remote_url is None:
+            raise ValueError("Missing file remote_url")
+        return f.remote_url
+    elif f.transfer_method == FileTransferMethod.LOCAL_FILE:
+        if f.related_id is None:
+            raise ValueError("Missing file related_id")
+
+        return helpers.get_signed_file_url(upload_file_id=f.related_id)
+    elif f.transfer_method == FileTransferMethod.TOOL_FILE:
+        # add sign url
+        if f.related_id is None or f.extension is None:
+            raise ValueError("Missing file related_id or extension")
+        return ToolFileParser.get_tool_file_manager().sign_file(tool_file_id=f.related_id, extension=f.extension)
+    else:
+        raise ValueError(f"Unsupported transfer method: {f.transfer_method}")
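To make the split between the new encoding helpers concrete, a rough sketch of what they return (values are placeholders, not taken from the diff): _to_base64_data_string wraps the base64 payload in a data URI carrying the file's MIME type, while _get_encoded_string returns the bare base64 text that AudioPromptMessageContent expects.

# Placeholder values for illustration only.
image_data_uri = "data:image/png;base64,iVBORw0KGgo..."  # shape of _to_base64_data_string(image_file)
audio_b64 = "UklGRiQAAABXQVZF..."                         # shape of _get_encoded_string(audio_file)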
api/core/file/file_repository.py (new file, 32 lines)
@@ -0,0 +1,32 @@
+from sqlalchemy import select
+from sqlalchemy.orm import Session
+
+from models import ToolFile, UploadFile
+
+from .models import File
+
+
+def get_upload_file(*, session: Session, file: File):
+    if file.related_id is None:
+        raise ValueError("Missing file related_id")
+    stmt = select(UploadFile).filter(
+        UploadFile.id == file.related_id,
+        UploadFile.tenant_id == file.tenant_id,
+    )
+    record = session.scalar(stmt)
+    if not record:
+        raise ValueError(f"upload file {file.related_id} not found")
+    return record
+
+
+def get_tool_file(*, session: Session, file: File):
+    if file.related_id is None:
+        raise ValueError("Missing file related_id")
+    stmt = select(ToolFile).filter(
+        ToolFile.id == file.related_id,
+        ToolFile.tenant_id == file.tenant_id,
+    )
+    record = session.scalar(stmt)
+    if not record:
+        raise ValueError(f"tool file {file.related_id} not found")
+    return record
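A brief, hypothetical usage sketch of the new repository helpers (the function name resolve_storage_key is invented for illustration): both helpers take an explicit SQLAlchemy Session plus a File and raise if the backing row is missing, which is what lets file_manager drop its own ad-hoc queries.

# Hypothetical helper, not part of the diff.
from core.file import file_repository
from extensions.ext_database import db


def resolve_storage_key(file):
    # Raises ValueError if related_id is missing or the row does not exist.
    upload_file = file_repository.get_upload_file(session=db.session(), file=file)
    return upload_file.key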
@@ -7,19 +7,6 @@ import time
 from configs import dify_config


-def get_signed_image_url(upload_file_id: str) -> str:
-    url = f"{dify_config.FILES_URL}/files/{upload_file_id}/image-preview"
-
-    timestamp = str(int(time.time()))
-    nonce = os.urandom(16).hex()
-    key = dify_config.SECRET_KEY.encode()
-    msg = f"image-preview|{upload_file_id}|{timestamp}|{nonce}"
-    sign = hmac.new(key, msg.encode(), hashlib.sha256).digest()
-    encoded_sign = base64.urlsafe_b64encode(sign).decode()
-
-    return f"{url}?timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}"
-
-
 def get_signed_file_url(upload_file_id: str) -> str:
     url = f"{dify_config.FILES_URL}/files/{upload_file_id}/file-preview"
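The hunk above only shows the head of get_signed_file_url, so here is a sketch of the signing flow it presumably follows, mirroring the removed image-preview helper and the file-preview signing string used later in DocumentSegment.get_sign_content. Treat the body as an assumption rather than the actual diff.

# Assumed shape, mirroring the removed image-preview helper; not the diff itself.
import base64
import hashlib
import hmac
import os
import time

from configs import dify_config


def get_signed_file_url_sketch(upload_file_id: str) -> str:
    url = f"{dify_config.FILES_URL}/files/{upload_file_id}/file-preview"
    timestamp = str(int(time.time()))
    nonce = os.urandom(16).hex()
    key = dify_config.SECRET_KEY.encode()
    msg = f"file-preview|{upload_file_id}|{timestamp}|{nonce}"
    sign = hmac.new(key, msg.encode(), hashlib.sha256).digest()
    encoded_sign = base64.urlsafe_b64encode(sign).decode()
    return f"{url}?timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}"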
@@ -72,7 +72,7 @@ class File(BaseModel):
         elif self.transfer_method == FileTransferMethod.LOCAL_FILE:
             if self.related_id is None:
                 raise ValueError("Missing file related_id")
-            return helpers.get_signed_image_url(upload_file_id=self.related_id)
+            return helpers.get_signed_file_url(upload_file_id=self.related_id)
         elif self.transfer_method == FileTransferMethod.TOOL_FILE:
             assert self.related_id is not None
             assert self.extension is not None
@@ -1,7 +1,7 @@
 from typing import Optional

 from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
-from core.file import FileType, file_manager
+from core.file import file_manager
 from core.model_manager import ModelInstance
 from core.model_runtime.entities import (
     AssistantPromptMessage,
@@ -98,8 +98,6 @@ class TokenBufferMemory:
                 prompt_message_contents: list[PromptMessageContent] = []
                 prompt_message_contents.append(TextPromptMessageContent(data=message.query))
                 for file_obj in file_objs:
-                    if file_obj.type != FileType.IMAGE:
-                        continue
                     prompt_message = file_manager.to_prompt_message_content(file_obj)
                     prompt_message_contents.append(prompt_message)
@@ -1,6 +1,7 @@
 from .llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
 from .message_entities import (
     AssistantPromptMessage,
+    AudioPromptMessageContent,
     ImagePromptMessageContent,
     PromptMessage,
     PromptMessageContent,
@@ -33,4 +34,5 @@ __all__ = [
     "LLMResult",
     "LLMResultChunk",
     "LLMResultChunkDelta",
+    "AudioPromptMessageContent",
 ]
@@ -2,7 +2,7 @@ from abc import ABC
 from enum import Enum
 from typing import Optional

-from pydantic import BaseModel, field_validator
+from pydantic import BaseModel, Field, field_validator


 class PromptMessageRole(Enum):
@@ -55,6 +55,7 @@ class PromptMessageContentType(Enum):

     TEXT = "text"
     IMAGE = "image"
+    AUDIO = "audio"


 class PromptMessageContent(BaseModel):
@@ -74,6 +75,12 @@ class TextPromptMessageContent(PromptMessageContent):
     type: PromptMessageContentType = PromptMessageContentType.TEXT


+class AudioPromptMessageContent(PromptMessageContent):
+    type: PromptMessageContentType = PromptMessageContentType.AUDIO
+    data: str = Field(..., description="Base64 encoded audio data")
+    format: str = Field(..., description="Audio format")
+
+
 class ImagePromptMessageContent(PromptMessageContent):
     """
     Model class for image prompt message content.
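As a quick illustration of the new entity (the file name and contents below are placeholders), an audio clip is carried as raw base64 plus its format rather than as a data URI:

# Illustration only; "clip.wav" is a placeholder path.
import base64

from core.model_runtime.entities import AudioPromptMessageContent

with open("clip.wav", "rb") as fp:
    b64_audio = base64.b64encode(fp.read()).decode("utf-8")

audio_content = AudioPromptMessageContent(data=b64_audio, format="wav")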
@@ -107,7 +107,16 @@ class LargeLanguageModel(AIModel):
                     callbacks=callbacks,
                 )
             else:
-                result = self._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)
+                result = self._invoke(
+                    model=model,
+                    credentials=credentials,
+                    prompt_messages=prompt_messages,
+                    model_parameters=model_parameters,
+                    tools=tools,
+                    stop=stop,
+                    stream=stream,
+                    user=user,
+                )
         except Exception as e:
             self._trigger_invoke_error_callbacks(
                 model=model,
@@ -1,3 +1,4 @@
+- gpt-4o-audio-preview
 - gpt-4
 - gpt-4o
 - gpt-4o-2024-05-13
@@ -0,0 +1,44 @@
+model: gpt-4o-audio-preview
+label:
+  zh_Hans: gpt-4o-audio-preview
+  en_US: gpt-4o-audio-preview
+model_type: llm
+features:
+  - multi-tool-call
+  - agent-thought
+  - stream-tool-call
+  - vision
+model_properties:
+  mode: chat
+  context_size: 128000
+parameter_rules:
+  - name: temperature
+    use_template: temperature
+  - name: top_p
+    use_template: top_p
+  - name: presence_penalty
+    use_template: presence_penalty
+  - name: frequency_penalty
+    use_template: frequency_penalty
+  - name: max_tokens
+    use_template: max_tokens
+    default: 512
+    min: 1
+    max: 4096
+  - name: response_format
+    label:
+      zh_Hans: 回复格式
+      en_US: Response Format
+    type: string
+    help:
+      zh_Hans: 指定模型必须输出的格式
+      en_US: specifying the format that the model must output
+    required: false
+    options:
+      - text
+      - json_object
+pricing:
+  input: '5.00'
+  output: '15.00'
+  unit: '0.000001'
+  currency: USD
@@ -1,7 +1,7 @@
 import json
 import logging
 from collections.abc import Generator
-from typing import Optional, Union, cast
+from typing import Any, Optional, Union, cast

 import tiktoken
 from openai import OpenAI, Stream
@@ -11,9 +11,9 @@ from openai.types.chat.chat_completion_chunk import ChoiceDeltaFunctionCall, Cho
 from openai.types.chat.chat_completion_message import FunctionCall

 from core.model_runtime.callbacks.base_callback import Callback
-from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta
-from core.model_runtime.entities.message_entities import (
+from core.model_runtime.entities import (
     AssistantPromptMessage,
+    AudioPromptMessageContent,
     ImagePromptMessageContent,
     PromptMessage,
     PromptMessageContentType,
@@ -23,6 +23,7 @@ from core.model_runtime.entities.message_entities import (
     ToolPromptMessage,
     UserPromptMessage,
 )
+from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta
 from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, I18nObject, ModelType, PriceConfig
 from core.model_runtime.errors.validate import CredentialsValidateFailedError
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
@@ -613,6 +614,7 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
         # clear illegal prompt messages
         prompt_messages = self._clear_illegal_prompt_messages(model, prompt_messages)

+        # o1 compatibility
         block_as_stream = False
         if model.startswith("o1"):
             if stream:
@@ -626,8 +628,9 @@
                 del extra_model_kwargs["stop"]

         # chat model
+        messages: Any = [self._convert_prompt_message_to_dict(m) for m in prompt_messages]
         response = client.chat.completions.create(
-            messages=[self._convert_prompt_message_to_dict(m) for m in prompt_messages],
+            messages=messages,
             model=model,
             stream=stream,
             **model_parameters,
@@ -946,23 +949,29 @@
         Convert PromptMessage to dict for OpenAI API
         """
         if isinstance(message, UserPromptMessage):
             message = cast(UserPromptMessage, message)
             if isinstance(message.content, str):
                 message_dict = {"role": "user", "content": message.content}
-            else:
+            elif isinstance(message.content, list):
                 sub_messages = []
                 for message_content in message.content:
-                    if message_content.type == PromptMessageContentType.TEXT:
-                        message_content = cast(TextPromptMessageContent, message_content)
+                    if isinstance(message_content, TextPromptMessageContent):
                         sub_message_dict = {"type": "text", "text": message_content.data}
                         sub_messages.append(sub_message_dict)
-                    elif message_content.type == PromptMessageContentType.IMAGE:
-                        message_content = cast(ImagePromptMessageContent, message_content)
+                    elif isinstance(message_content, ImagePromptMessageContent):
                         sub_message_dict = {
                             "type": "image_url",
                             "image_url": {"url": message_content.data, "detail": message_content.detail.value},
                         }
                         sub_messages.append(sub_message_dict)
+                    elif isinstance(message_content, AudioPromptMessageContent):
+                        sub_message_dict = {
+                            "type": "input_audio",
+                            "input_audio": {
+                                "data": message_content.data,
+                                "format": message_content.format,
+                            },
+                        }
+                        sub_messages.append(sub_message_dict)

                 message_dict = {"role": "user", "content": sub_messages}
         elif isinstance(message, AssistantPromptMessage):
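For reference, a sketch of the user-message dict the converter now produces when a prompt mixes text and audio (the values are placeholders, matching the shapes built above):

# Placeholder payload for illustration; mirrors the dicts assembled in the hunk above.
message_dict = {
    "role": "user",
    "content": [
        {"type": "text", "text": "Please transcribe this clip."},
        {
            "type": "input_audio",
            "input_audio": {"data": "<base64-encoded wav bytes>", "format": "wav"},
        },
    ],
}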
@@ -23,7 +23,7 @@ if TYPE_CHECKING:
     from core.file.models import File


-class ModelMode(enum.Enum):
+class ModelMode(str, enum.Enum):
     COMPLETION = "completion"
     CHAT = "chat"
@@ -1,7 +1,8 @@
 from typing import cast

-from core.model_runtime.entities.message_entities import (
+from core.model_runtime.entities import (
     AssistantPromptMessage,
+    AudioPromptMessageContent,
     ImagePromptMessageContent,
     PromptMessage,
     PromptMessageContentType,
@@ -21,7 +22,7 @@ class PromptMessageUtil:
         :return:
         """
         prompts = []
-        if model_mode == ModelMode.CHAT.value:
+        if model_mode == ModelMode.CHAT:
             tool_calls = []
             for prompt_message in prompt_messages:
                 if prompt_message.role == PromptMessageRole.USER:
@@ -51,11 +52,9 @@
                 files = []
                 if isinstance(prompt_message.content, list):
                     for content in prompt_message.content:
-                        if content.type == PromptMessageContentType.TEXT:
-                            content = cast(TextPromptMessageContent, content)
+                        if isinstance(content, TextPromptMessageContent):
                             text += content.data
-                        else:
-                            content = cast(ImagePromptMessageContent, content)
+                        elif isinstance(content, ImagePromptMessageContent):
                             files.append(
                                 {
                                     "type": "image",
@@ -63,6 +62,14 @@
                                     "detail": content.detail.value,
                                 }
                             )
+                        elif isinstance(content, AudioPromptMessageContent):
+                            files.append(
+                                {
+                                    "type": "audio",
+                                    "data": content.data[:10] + "...[TRUNCATED]..." + content.data[-10:],
+                                    "format": content.format,
+                                }
+                            )
                 else:
                     text = prompt_message.content
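The audio payload is deliberately truncated before being stored with the prompt log. A sketch of what that stored entry looks like (values invented):

# Invented values, showing the truncated form kept in the prompt log.
file_entry = {
    "type": "audio",
    "data": "UklGRiQAAA...[TRUNCATED]...AAAAAAA==",
    "format": "wav",
}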
@@ -121,7 +121,7 @@ class WordExtractor(BaseExtractor):
                     db.session.add(upload_file)
                     db.session.commit()
                     image_map[rel.target_part] = (
-                        f"![image]({dify_config.CONSOLE_API_URL}/files/{upload_file.id}/image-preview)"
+                        f"![image]({dify_config.CONSOLE_API_URL}/files/{upload_file.id}/file-preview)"
                     )

         return image_map
@@ -148,9 +148,7 @@ def _download_file_content(file: File) -> bytes:
             response.raise_for_status()
             return response.content
         elif file.transfer_method == FileTransferMethod.LOCAL_FILE:
-            if file.related_id is None:
-                raise FileDownloadError("Missing file ID for local file")
-            return file_manager.download(upload_file_id=file.related_id, tenant_id=file.tenant_id)
+            return file_manager.download(file)
         else:
             raise ValueError(f"Unsupported transfer method: {file.transfer_method}")
     except Exception as e:
@@ -169,9 +169,7 @@ class HttpExecutor:
                     if file_variable is None:
                         raise ValueError(f"cannot fetch file with selector {file_selector}")
                     file = file_variable.value
-                    if file.related_id is None:
-                        raise ValueError(f"file {file.related_id} not found")
-                    self.content = file_manager.download(upload_file_id=file.related_id, tenant_id=file.tenant_id)
+                    self.content = file_manager.download(file)
                 case "x-www-form-urlencoded":
                     form_data = {
                         self.variable_pool.convert_template(item.key).text: self.variable_pool.convert_template(
@@ -194,11 +192,7 @@
                 files = {k: self.variable_pool.get_file(selector) for k, selector in file_selectors.items()}
                 files = {k: v for k, v in files.items() if v is not None}
                 files = {k: variable.value for k, variable in files.items()}
-                files = {
-                    k: file_manager.download(upload_file_id=v.related_id, tenant_id=v.tenant_id)
-                    for k, v in files.items()
-                    if v.related_id is not None
-                }
+                files = {k: file_manager.download(v) for k, v in files.items() if v.related_id is not None}

                 self.data = form_data
                 self.files = files
@@ -10,12 +10,14 @@ from core.entities.provider_entities import QuotaUnit
 from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
 from core.memory.token_buffer_memory import TokenBufferMemory
 from core.model_manager import ModelInstance, ModelManager
-from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
-from core.model_runtime.entities.message_entities import (
+from core.model_runtime.entities import (
+    AudioPromptMessageContent,
     ImagePromptMessageContent,
     PromptMessage,
     PromptMessageContentType,
     TextPromptMessageContent,
 )
+from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
 from core.model_runtime.entities.model_entities import ModelType
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
 from core.model_runtime.utils.encoders import jsonable_encoder
@@ -547,7 +549,7 @@ class LLMNode(BaseNode[LLMNodeData]):
                         # cuz vision detail is related to the configuration from FileUpload feature.
                         content_item.detail = vision_detail
                         prompt_message_content.append(content_item)
-                    elif content_item.type == PromptMessageContentType.TEXT:
+                    elif isinstance(content_item, TextPromptMessageContent | AudioPromptMessageContent):
                         prompt_message_content.append(content_item)

                 if len(prompt_message_content) > 1:
@@ -1,18 +1,15 @@
 import mimetypes
 from collections.abc import Mapping, Sequence
-from typing import TYPE_CHECKING, Any
+from typing import Any

 from sqlalchemy import select

 from constants import AUDIO_EXTENSIONS, DOCUMENT_EXTENSIONS, IMAGE_EXTENSIONS, VIDEO_EXTENSIONS
 from core.file import File, FileBelongsTo, FileExtraConfig, FileTransferMethod, FileType
 from core.helper import ssrf_proxy
+from enums import CreatedByRole
 from extensions.ext_database import db
-from models import ToolFile, UploadFile
-
-if TYPE_CHECKING:
-    from enums import CreatedByRole
-    from models import MessageFile
+from models import MessageFile, ToolFile, UploadFile


 def build_from_message_files(
@@ -35,14 +32,19 @@ def build_from_message_file(
     tenant_id: str,
     config: FileExtraConfig,
 ):
-    return File(
-        id=message_file.id,
+    mapping = {
+        "transfer_method": message_file.transfer_method,
+        "url": message_file.url,
+        "id": message_file.id,
+        "type": message_file.type,
+        "upload_file_id": message_file.upload_file_id,
+    }
+    return build_from_mapping(
+        mapping=mapping,
         tenant_id=tenant_id,
-        type=FileType.value_of(message_file.type),
-        transfer_method=FileTransferMethod.value_of(message_file.transfer_method),
-        remote_url=message_file.url,
-        related_id=message_file.upload_file_id or None,
-        _extra_config=config,
+        user_id=message_file.created_by,
+        role=CreatedByRole(message_file.created_by_role),
+        config=config,
     )
@@ -33,7 +33,7 @@ class AppIconUrlField(fields.Raw):
         from models.model import IconType

         if obj.icon_type == IconType.IMAGE.value:
-            return file_helpers.get_signed_image_url(obj.icon)
+            return file_helpers.get_signed_file_url(obj.icon)
         return None
@@ -560,7 +560,7 @@ class DocumentSegment(db.Model):
         )

     def get_sign_content(self):
-        pattern = r"/files/([a-f0-9\-]+)/image-preview"
+        pattern = r"/files/([a-f0-9\-]+)/file-preview"
         text = self.content
         matches = re.finditer(pattern, text)
         signed_urls = []
@@ -568,7 +568,7 @@
             upload_file_id = match.group(1)
             nonce = os.urandom(16).hex()
             timestamp = str(int(time.time()))
-            data_to_sign = f"image-preview|{upload_file_id}|{timestamp}|{nonce}"
+            data_to_sign = f"file-preview|{upload_file_id}|{timestamp}|{nonce}"
             secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b""
             sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
             encoded_sign = base64.urlsafe_b64encode(sign).decode()
@@ -829,19 +829,29 @@ class Message(db.Model):
                 sign_url = ToolFileParser.get_tool_file_manager().sign_file(
                     tool_file_id=tool_file_id, extension=extension
                 )
-            else:
+            elif "file-preview" in url:
                 # get upload file id
-                upload_file_id_pattern = r"\/files\/([\w-]+)\/image-preview?\?timestamp="
+                upload_file_id_pattern = r"\/files\/([\w-]+)\/file-preview?\?timestamp="
                 result = re.search(upload_file_id_pattern, url)
                 if not result:
                     continue

                 upload_file_id = result.group(1)

                 if not upload_file_id:
                     continue

-                sign_url = file_helpers.get_signed_image_url(upload_file_id)
+                sign_url = file_helpers.get_signed_file_url(upload_file_id)
+            elif "image-preview" in url:
+                # image-preview is deprecated, use file-preview instead
+                upload_file_id_pattern = r"\/files\/([\w-]+)\/image-preview?\?timestamp="
+                result = re.search(upload_file_id_pattern, url)
+                if not result:
+                    continue
+                upload_file_id = result.group(1)
+                if not upload_file_id:
+                    continue
+                sign_url = file_helpers.get_signed_file_url(upload_file_id)
+            else:
+                continue

             re_sign_file_url_answer = re_sign_file_url_answer.replace(url, sign_url)
@@ -158,7 +158,7 @@ nomic = "~3.1.2"
 novita-client = "~0.5.7"
 numpy = "~1.26.4"
 oci = "~2.135.1"
-openai = "~1.51.2"
+openai = "^1.52.0"
 openpyxl = "~3.1.5"
 pandas = { version = "~2.2.2", extras = ["performance", "excel"] }
 psycopg2-binary = "~2.9.6"