mirror of
https://github.com/langgenius/dify.git
synced 2024-11-16 11:42:29 +08:00
feat(api): Enhance multi modal support.
This commit is contained in:
parent
7838f9f3a3
commit
ea18dd1571
|
@ -233,6 +233,8 @@ VIKINGDB_SOCKET_TIMEOUT=30
|
||||||
UPLOAD_FILE_SIZE_LIMIT=15
|
UPLOAD_FILE_SIZE_LIMIT=15
|
||||||
UPLOAD_FILE_BATCH_LIMIT=5
|
UPLOAD_FILE_BATCH_LIMIT=5
|
||||||
UPLOAD_IMAGE_FILE_SIZE_LIMIT=10
|
UPLOAD_IMAGE_FILE_SIZE_LIMIT=10
|
||||||
|
UPLOAD_VIDEO_FILE_SIZE_LIMIT=100
|
||||||
|
UPLOAD_AUDIO_FILE_SIZE_LIMIT=50
|
||||||
|
|
||||||
# Model Configuration
|
# Model Configuration
|
||||||
MULTIMODAL_SEND_IMAGE_FORMAT=base64
|
MULTIMODAL_SEND_IMAGE_FORMAT=base64
|
||||||
|
@ -310,6 +312,7 @@ INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH=1000
|
||||||
WORKFLOW_MAX_EXECUTION_STEPS=500
|
WORKFLOW_MAX_EXECUTION_STEPS=500
|
||||||
WORKFLOW_MAX_EXECUTION_TIME=1200
|
WORKFLOW_MAX_EXECUTION_TIME=1200
|
||||||
WORKFLOW_CALL_MAX_DEPTH=5
|
WORKFLOW_CALL_MAX_DEPTH=5
|
||||||
|
MAX_VARIABLE_SIZE=204800
|
||||||
|
|
||||||
# App configuration
|
# App configuration
|
||||||
APP_MAX_EXECUTION_TIME=1200
|
APP_MAX_EXECUTION_TIME=1200
|
||||||
|
|
15
api/.vscode/launch.json.example
vendored
15
api/.vscode/launch.json.example
vendored
|
@ -1,8 +1,15 @@
|
||||||
{
|
{
|
||||||
"version": "0.2.0",
|
"version": "0.2.0",
|
||||||
|
"compounds": [
|
||||||
|
{
|
||||||
|
"name": "Launch Flask and Celery",
|
||||||
|
"configurations": ["Python: Flask", "Python: Celery"]
|
||||||
|
}
|
||||||
|
],
|
||||||
"configurations": [
|
"configurations": [
|
||||||
{
|
{
|
||||||
"name": "Python: Flask",
|
"name": "Python: Flask",
|
||||||
|
"consoleName": "Flask",
|
||||||
"type": "debugpy",
|
"type": "debugpy",
|
||||||
"request": "launch",
|
"request": "launch",
|
||||||
"python": "${workspaceFolder}/.venv/bin/python",
|
"python": "${workspaceFolder}/.venv/bin/python",
|
||||||
|
@ -17,12 +24,12 @@
|
||||||
},
|
},
|
||||||
"args": [
|
"args": [
|
||||||
"run",
|
"run",
|
||||||
"--host=0.0.0.0",
|
|
||||||
"--port=5001"
|
"--port=5001"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"name": "Python: Celery",
|
"name": "Python: Celery",
|
||||||
|
"consoleName": "Celery",
|
||||||
"type": "debugpy",
|
"type": "debugpy",
|
||||||
"request": "launch",
|
"request": "launch",
|
||||||
"python": "${workspaceFolder}/.venv/bin/python",
|
"python": "${workspaceFolder}/.venv/bin/python",
|
||||||
|
@ -45,10 +52,10 @@
|
||||||
"-c",
|
"-c",
|
||||||
"1",
|
"1",
|
||||||
"--loglevel",
|
"--loglevel",
|
||||||
"info",
|
"DEBUG",
|
||||||
"-Q",
|
"-Q",
|
||||||
"dataset,generation,mail,ops_trace,app_deletion"
|
"dataset,generation,mail,ops_trace,app_deletion"
|
||||||
]
|
]
|
||||||
},
|
}
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
|
|
|
@ -118,7 +118,7 @@ def create_app() -> Flask:
|
||||||
|
|
||||||
logging.basicConfig(
|
logging.basicConfig(
|
||||||
level=app.config.get("LOG_LEVEL"),
|
level=app.config.get("LOG_LEVEL"),
|
||||||
format=app.config.get("LOG_FORMAT"),
|
format=app.config["LOG_FORMAT"],
|
||||||
datefmt=app.config.get("LOG_DATEFORMAT"),
|
datefmt=app.config.get("LOG_DATEFORMAT"),
|
||||||
handlers=log_handlers,
|
handlers=log_handlers,
|
||||||
force=True,
|
force=True,
|
||||||
|
@ -135,6 +135,7 @@ def create_app() -> Flask:
|
||||||
return datetime.utcfromtimestamp(seconds).astimezone(timezone).timetuple()
|
return datetime.utcfromtimestamp(seconds).astimezone(timezone).timetuple()
|
||||||
|
|
||||||
for handler in logging.root.handlers:
|
for handler in logging.root.handlers:
|
||||||
|
assert handler.formatter
|
||||||
handler.formatter.converter = time_converter
|
handler.formatter.converter = time_converter
|
||||||
initialize_extensions(app)
|
initialize_extensions(app)
|
||||||
register_blueprints(app)
|
register_blueprints(app)
|
||||||
|
|
|
@ -19,7 +19,7 @@ from extensions.ext_redis import redis_client
|
||||||
from libs.helper import email as email_validate
|
from libs.helper import email as email_validate
|
||||||
from libs.password import hash_password, password_pattern, valid_password
|
from libs.password import hash_password, password_pattern, valid_password
|
||||||
from libs.rsa import generate_key_pair
|
from libs.rsa import generate_key_pair
|
||||||
from models.account import Tenant
|
from models import Tenant
|
||||||
from models.dataset import Dataset, DatasetCollectionBinding, DocumentSegment
|
from models.dataset import Dataset, DatasetCollectionBinding, DocumentSegment
|
||||||
from models.dataset import Document as DatasetDocument
|
from models.dataset import Document as DatasetDocument
|
||||||
from models.model import Account, App, AppAnnotationSetting, AppMode, Conversation, MessageAnnotation
|
from models.model import Account, App, AppAnnotationSetting, AppMode, Conversation, MessageAnnotation
|
||||||
|
@ -457,14 +457,14 @@ def convert_to_agent_apps():
|
||||||
# fetch first 1000 apps
|
# fetch first 1000 apps
|
||||||
sql_query = """SELECT a.id AS id FROM apps a
|
sql_query = """SELECT a.id AS id FROM apps a
|
||||||
INNER JOIN app_model_configs am ON a.app_model_config_id=am.id
|
INNER JOIN app_model_configs am ON a.app_model_config_id=am.id
|
||||||
WHERE a.mode = 'chat'
|
WHERE a.mode = 'chat'
|
||||||
AND am.agent_mode is not null
|
AND am.agent_mode is not null
|
||||||
AND (
|
AND (
|
||||||
am.agent_mode like '%"strategy": "function_call"%'
|
am.agent_mode like '%"strategy": "function_call"%'
|
||||||
OR am.agent_mode like '%"strategy": "react"%'
|
OR am.agent_mode like '%"strategy": "react"%'
|
||||||
)
|
)
|
||||||
AND (
|
AND (
|
||||||
am.agent_mode like '{"enabled": true%'
|
am.agent_mode like '{"enabled": true%'
|
||||||
OR am.agent_mode like '{"max_iteration": %'
|
OR am.agent_mode like '{"max_iteration": %'
|
||||||
) ORDER BY a.created_at DESC LIMIT 1000
|
) ORDER BY a.created_at DESC LIMIT 1000
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
from typing import Annotated, Optional
|
from typing import Annotated, Literal, Optional
|
||||||
|
|
||||||
from pydantic import AliasChoices, Field, HttpUrl, NegativeInt, NonNegativeInt, PositiveInt, computed_field
|
from pydantic import AliasChoices, Field, HttpUrl, NegativeInt, NonNegativeInt, PositiveInt, computed_field
|
||||||
from pydantic_settings import BaseSettings
|
from pydantic_settings import BaseSettings
|
||||||
|
@ -11,11 +11,11 @@ class SecurityConfig(BaseSettings):
|
||||||
Security-related configurations for the application
|
Security-related configurations for the application
|
||||||
"""
|
"""
|
||||||
|
|
||||||
SECRET_KEY: Optional[str] = Field(
|
SECRET_KEY: str = Field(
|
||||||
description="Secret key for secure session cookie signing."
|
description="Secret key for secure session cookie signing."
|
||||||
"Make sure you are changing this key for your deployment with a strong key."
|
"Make sure you are changing this key for your deployment with a strong key."
|
||||||
"Generate a strong key using `openssl rand -base64 42` or set via the `SECRET_KEY` environment variable.",
|
"Generate a strong key using `openssl rand -base64 42` or set via the `SECRET_KEY` environment variable.",
|
||||||
default=None,
|
default="",
|
||||||
)
|
)
|
||||||
|
|
||||||
RESET_PASSWORD_TOKEN_EXPIRY_HOURS: PositiveInt = Field(
|
RESET_PASSWORD_TOKEN_EXPIRY_HOURS: PositiveInt = Field(
|
||||||
|
@ -177,6 +177,16 @@ class FileUploadConfig(BaseSettings):
|
||||||
default=10,
|
default=10,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
UPLOAD_VIDEO_FILE_SIZE_LIMIT: NonNegativeInt = Field(
|
||||||
|
description="video file size limit in Megabytes for uploading files",
|
||||||
|
default=100,
|
||||||
|
)
|
||||||
|
|
||||||
|
UPLOAD_AUDIO_FILE_SIZE_LIMIT: NonNegativeInt = Field(
|
||||||
|
description="audio file size limit in Megabytes for uploading files",
|
||||||
|
default=50,
|
||||||
|
)
|
||||||
|
|
||||||
BATCH_UPLOAD_LIMIT: NonNegativeInt = Field(
|
BATCH_UPLOAD_LIMIT: NonNegativeInt = Field(
|
||||||
description="Maximum number of files allowed in a batch upload operation",
|
description="Maximum number of files allowed in a batch upload operation",
|
||||||
default=20,
|
default=20,
|
||||||
|
@ -355,8 +365,8 @@ class WorkflowConfig(BaseSettings):
|
||||||
)
|
)
|
||||||
|
|
||||||
MAX_VARIABLE_SIZE: PositiveInt = Field(
|
MAX_VARIABLE_SIZE: PositiveInt = Field(
|
||||||
description="Maximum size in bytes for a single variable in workflows. Default to 5KB.",
|
description="Maximum size in bytes for a single variable in workflows. Default to 200 KB.",
|
||||||
default=5 * 1024,
|
default=200 * 1024,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -479,6 +489,7 @@ class RagEtlConfig(BaseSettings):
|
||||||
Configuration for RAG ETL processes
|
Configuration for RAG ETL processes
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# TODO: This config is not only for rag etl, it is also for file upload, we should move it to file upload config
|
||||||
ETL_TYPE: str = Field(
|
ETL_TYPE: str = Field(
|
||||||
description="RAG ETL type ('dify' or 'Unstructured'), default to 'dify'",
|
description="RAG ETL type ('dify' or 'Unstructured'), default to 'dify'",
|
||||||
default="dify",
|
default="dify",
|
||||||
|
@ -540,7 +551,7 @@ class IndexingConfig(BaseSettings):
|
||||||
|
|
||||||
|
|
||||||
class ImageFormatConfig(BaseSettings):
|
class ImageFormatConfig(BaseSettings):
|
||||||
MULTIMODAL_SEND_IMAGE_FORMAT: str = Field(
|
MULTIMODAL_SEND_IMAGE_FORMAT: Literal["base64", "url"] = Field(
|
||||||
description="Format for sending images in multimodal contexts ('base64' or 'url'), default is base64",
|
description="Format for sending images in multimodal contexts ('base64' or 'url'), default is base64",
|
||||||
default="base64",
|
default="base64",
|
||||||
)
|
)
|
||||||
|
|
|
@ -9,7 +9,7 @@ class PackagingInfo(BaseSettings):
|
||||||
|
|
||||||
CURRENT_VERSION: str = Field(
|
CURRENT_VERSION: str = Field(
|
||||||
description="Dify version",
|
description="Dify version",
|
||||||
default="0.9.1",
|
default="0.10.0-beta2",
|
||||||
)
|
)
|
||||||
|
|
||||||
COMMIT_SHA: str = Field(
|
COMMIT_SHA: str = Field(
|
||||||
|
|
|
@ -1,2 +1,21 @@
|
||||||
|
from configs import dify_config
|
||||||
|
|
||||||
HIDDEN_VALUE = "[__HIDDEN__]"
|
HIDDEN_VALUE = "[__HIDDEN__]"
|
||||||
UUID_NIL = "00000000-0000-0000-0000-000000000000"
|
UUID_NIL = "00000000-0000-0000-0000-000000000000"
|
||||||
|
|
||||||
|
IMAGE_EXTENSIONS = ["jpg", "jpeg", "png", "webp", "gif", "svg"]
|
||||||
|
IMAGE_EXTENSIONS.extend([ext.upper() for ext in IMAGE_EXTENSIONS])
|
||||||
|
|
||||||
|
VIDEO_EXTENSIONS = ["mp4", "mov", "mpeg", "mpga"]
|
||||||
|
VIDEO_EXTENSIONS.extend([ext.upper() for ext in VIDEO_EXTENSIONS])
|
||||||
|
|
||||||
|
AUDIO_EXTENSIONS = ["mp3", "m4a", "wav", "webm", "amr"]
|
||||||
|
AUDIO_EXTENSIONS.extend([ext.upper() for ext in AUDIO_EXTENSIONS])
|
||||||
|
|
||||||
|
DOCUMENT_EXTENSIONS = ["txt", "markdown", "md", "pdf", "html", "htm", "xlsx", "xls", "docx", "csv"]
|
||||||
|
DOCUMENT_EXTENSIONS.extend([ext.upper() for ext in DOCUMENT_EXTENSIONS])
|
||||||
|
|
||||||
|
if dify_config.ETL_TYPE == "Unstructured":
|
||||||
|
DOCUMENT_EXTENSIONS = ["txt", "markdown", "md", "pdf", "html", "htm", "xlsx", "xls"]
|
||||||
|
DOCUMENT_EXTENSIONS.extend(("docx", "csv", "eml", "msg", "pptx", "ppt", "xml", "epub"))
|
||||||
|
DOCUMENT_EXTENSIONS.extend([ext.upper() for ext in DOCUMENT_EXTENSIONS])
|
||||||
|
|
|
@ -1,7 +1,9 @@
|
||||||
from contextvars import ContextVar
|
from contextvars import ContextVar
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
from core.workflow.entities.variable_pool import VariablePool
|
if TYPE_CHECKING:
|
||||||
|
from core.workflow.entities.variable_pool import VariablePool
|
||||||
|
|
||||||
tenant_id: ContextVar[str] = ContextVar("tenant_id")
|
tenant_id: ContextVar[str] = ContextVar("tenant_id")
|
||||||
|
|
||||||
workflow_variable_pool: ContextVar[VariablePool] = ContextVar("workflow_variable_pool")
|
workflow_variable_pool: ContextVar["VariablePool"] = ContextVar("workflow_variable_pool")
|
||||||
|
|
|
@ -22,7 +22,8 @@ from fields.conversation_fields import (
|
||||||
)
|
)
|
||||||
from libs.helper import DatetimeString
|
from libs.helper import DatetimeString
|
||||||
from libs.login import login_required
|
from libs.login import login_required
|
||||||
from models.model import AppMode, Conversation, EndUser, Message, MessageAnnotation
|
from models import Conversation, EndUser, Message, MessageAnnotation
|
||||||
|
from models.model import AppMode
|
||||||
|
|
||||||
|
|
||||||
class CompletionConversationApi(Resource):
|
class CompletionConversationApi(Resource):
|
||||||
|
|
|
@ -12,7 +12,7 @@ from controllers.console.wraps import account_initialization_required
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from fields.app_fields import app_site_fields
|
from fields.app_fields import app_site_fields
|
||||||
from libs.login import login_required
|
from libs.login import login_required
|
||||||
from models.model import Site
|
from models import Site
|
||||||
|
|
||||||
|
|
||||||
def parse_app_site_args():
|
def parse_app_site_args():
|
||||||
|
|
|
@ -13,14 +13,14 @@ from controllers.console.setup import setup_required
|
||||||
from controllers.console.wraps import account_initialization_required
|
from controllers.console.wraps import account_initialization_required
|
||||||
from core.app.apps.base_app_queue_manager import AppQueueManager
|
from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||||
from core.app.segments import factory
|
from factories import variable_factory
|
||||||
from core.errors.error import AppInvokeQuotaExceededError
|
|
||||||
from fields.workflow_fields import workflow_fields
|
from fields.workflow_fields import workflow_fields
|
||||||
from fields.workflow_run_fields import workflow_run_node_execution_fields
|
from fields.workflow_run_fields import workflow_run_node_execution_fields
|
||||||
from libs import helper
|
from libs import helper
|
||||||
from libs.helper import TimestampField, uuid_value
|
from libs.helper import TimestampField, uuid_value
|
||||||
from libs.login import current_user, login_required
|
from libs.login import current_user, login_required
|
||||||
from models.model import App, AppMode
|
from models import App
|
||||||
|
from models.model import AppMode
|
||||||
from services.app_dsl_service import AppDslService
|
from services.app_dsl_service import AppDslService
|
||||||
from services.app_generate_service import AppGenerateService
|
from services.app_generate_service import AppGenerateService
|
||||||
from services.errors.app import WorkflowHashNotEqualError
|
from services.errors.app import WorkflowHashNotEqualError
|
||||||
|
@ -101,9 +101,13 @@ class DraftWorkflowApi(Resource):
|
||||||
|
|
||||||
try:
|
try:
|
||||||
environment_variables_list = args.get("environment_variables") or []
|
environment_variables_list = args.get("environment_variables") or []
|
||||||
environment_variables = [factory.build_variable_from_mapping(obj) for obj in environment_variables_list]
|
environment_variables = [
|
||||||
|
variable_factory.build_variable_from_mapping(obj) for obj in environment_variables_list
|
||||||
|
]
|
||||||
conversation_variables_list = args.get("conversation_variables") or []
|
conversation_variables_list = args.get("conversation_variables") or []
|
||||||
conversation_variables = [factory.build_variable_from_mapping(obj) for obj in conversation_variables_list]
|
conversation_variables = [
|
||||||
|
variable_factory.build_variable_from_mapping(obj) for obj in conversation_variables_list
|
||||||
|
]
|
||||||
workflow = workflow_service.sync_draft_workflow(
|
workflow = workflow_service.sync_draft_workflow(
|
||||||
app_model=app_model,
|
app_model=app_model,
|
||||||
graph=args["graph"],
|
graph=args["graph"],
|
||||||
|
@ -273,17 +277,15 @@ class DraftWorkflowRunApi(Resource):
|
||||||
parser.add_argument("files", type=list, required=False, location="json")
|
parser.add_argument("files", type=list, required=False, location="json")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
try:
|
response = AppGenerateService.generate(
|
||||||
response = AppGenerateService.generate(
|
app_model=app_model,
|
||||||
app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.DEBUGGER, streaming=True
|
user=current_user,
|
||||||
)
|
args=args,
|
||||||
|
invoke_from=InvokeFrom.DEBUGGER,
|
||||||
|
streaming=True,
|
||||||
|
)
|
||||||
|
|
||||||
return helper.compact_generate_response(response)
|
return helper.compact_generate_response(response)
|
||||||
except (ValueError, AppInvokeQuotaExceededError) as e:
|
|
||||||
raise e
|
|
||||||
except Exception as e:
|
|
||||||
logging.exception("internal server error.")
|
|
||||||
raise InternalServerError()
|
|
||||||
|
|
||||||
|
|
||||||
class WorkflowTaskStopApi(Resource):
|
class WorkflowTaskStopApi(Resource):
|
||||||
|
|
|
@ -7,7 +7,8 @@ from controllers.console.setup import setup_required
|
||||||
from controllers.console.wraps import account_initialization_required
|
from controllers.console.wraps import account_initialization_required
|
||||||
from fields.workflow_app_log_fields import workflow_app_log_pagination_fields
|
from fields.workflow_app_log_fields import workflow_app_log_pagination_fields
|
||||||
from libs.login import login_required
|
from libs.login import login_required
|
||||||
from models.model import App, AppMode
|
from models import App
|
||||||
|
from models.model import AppMode
|
||||||
from services.workflow_app_service import WorkflowAppService
|
from services.workflow_app_service import WorkflowAppService
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -13,7 +13,8 @@ from fields.workflow_run_fields import (
|
||||||
)
|
)
|
||||||
from libs.helper import uuid_value
|
from libs.helper import uuid_value
|
||||||
from libs.login import login_required
|
from libs.login import login_required
|
||||||
from models.model import App, AppMode
|
from models import App
|
||||||
|
from models.model import AppMode
|
||||||
from services.workflow_run_service import WorkflowRunService
|
from services.workflow_run_service import WorkflowRunService
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -10,11 +10,11 @@ from controllers.console import api
|
||||||
from controllers.console.app.wraps import get_app_model
|
from controllers.console.app.wraps import get_app_model
|
||||||
from controllers.console.setup import setup_required
|
from controllers.console.setup import setup_required
|
||||||
from controllers.console.wraps import account_initialization_required
|
from controllers.console.wraps import account_initialization_required
|
||||||
|
from enums import WorkflowRunTriggeredFrom
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from libs.helper import DatetimeString
|
from libs.helper import DatetimeString
|
||||||
from libs.login import login_required
|
from libs.login import login_required
|
||||||
from models.model import AppMode
|
from models.model import AppMode
|
||||||
from models.workflow import WorkflowRunTriggeredFrom
|
|
||||||
|
|
||||||
|
|
||||||
class WorkflowDailyRunsStatistic(Resource):
|
class WorkflowDailyRunsStatistic(Resource):
|
||||||
|
|
|
@ -5,7 +5,8 @@ from typing import Optional, Union
|
||||||
from controllers.console.app.error import AppNotFoundError
|
from controllers.console.app.error import AppNotFoundError
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from libs.login import current_user
|
from libs.login import current_user
|
||||||
from models.model import App, AppMode
|
from models import App
|
||||||
|
from models.model import AppMode
|
||||||
|
|
||||||
|
|
||||||
def get_app_model(view: Optional[Callable] = None, *, mode: Union[AppMode, list[AppMode]] = None):
|
def get_app_model(view: Optional[Callable] = None, *, mode: Union[AppMode, list[AppMode]] = None):
|
||||||
|
|
|
@ -15,7 +15,7 @@ from controllers.console.setup import setup_required
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from libs.helper import email as email_validate
|
from libs.helper import email as email_validate
|
||||||
from libs.password import hash_password, valid_password
|
from libs.password import hash_password, valid_password
|
||||||
from models.account import Account
|
from models import Account
|
||||||
from services.account_service import AccountService
|
from services.account_service import AccountService
|
||||||
from services.errors.account import RateLimitExceededError
|
from services.errors.account import RateLimitExceededError
|
||||||
|
|
||||||
|
|
|
@ -9,7 +9,7 @@ from controllers.console import api
|
||||||
from controllers.console.setup import setup_required
|
from controllers.console.setup import setup_required
|
||||||
from libs.helper import email, extract_remote_ip
|
from libs.helper import email, extract_remote_ip
|
||||||
from libs.password import valid_password
|
from libs.password import valid_password
|
||||||
from models.account import Account
|
from models import Account
|
||||||
from services.account_service import AccountService, TenantService
|
from services.account_service import AccountService, TenantService
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -11,7 +11,8 @@ from constants.languages import languages
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from libs.helper import extract_remote_ip
|
from libs.helper import extract_remote_ip
|
||||||
from libs.oauth import GitHubOAuth, GoogleOAuth, OAuthUserInfo
|
from libs.oauth import GitHubOAuth, GoogleOAuth, OAuthUserInfo
|
||||||
from models.account import Account, AccountStatus
|
from models import Account
|
||||||
|
from models.account import AccountStatus
|
||||||
from services.account_service import AccountService, RegisterService, TenantService
|
from services.account_service import AccountService, RegisterService, TenantService
|
||||||
|
|
||||||
from .. import api
|
from .. import api
|
||||||
|
|
|
@ -15,8 +15,7 @@ from core.rag.extractor.notion_extractor import NotionExtractor
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from fields.data_source_fields import integrate_list_fields, integrate_notion_info_list_fields
|
from fields.data_source_fields import integrate_list_fields, integrate_notion_info_list_fields
|
||||||
from libs.login import login_required
|
from libs.login import login_required
|
||||||
from models.dataset import Document
|
from models import DataSourceOauthBinding, Document
|
||||||
from models.source import DataSourceOauthBinding
|
|
||||||
from services.dataset_service import DatasetService, DocumentService
|
from services.dataset_service import DatasetService, DocumentService
|
||||||
from tasks.document_indexing_sync_task import document_indexing_sync_task
|
from tasks.document_indexing_sync_task import document_indexing_sync_task
|
||||||
|
|
||||||
|
|
|
@ -24,8 +24,8 @@ from fields.app_fields import related_app_list
|
||||||
from fields.dataset_fields import dataset_detail_fields, dataset_query_detail_fields
|
from fields.dataset_fields import dataset_detail_fields, dataset_query_detail_fields
|
||||||
from fields.document_fields import document_status_fields
|
from fields.document_fields import document_status_fields
|
||||||
from libs.login import login_required
|
from libs.login import login_required
|
||||||
from models.dataset import Dataset, DatasetPermissionEnum, Document, DocumentSegment
|
from models import ApiToken, Dataset, Document, DocumentSegment, UploadFile
|
||||||
from models.model import ApiToken, UploadFile
|
from models.dataset import DatasetPermissionEnum
|
||||||
from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService
|
from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -46,8 +46,7 @@ from fields.document_fields import (
|
||||||
document_with_segments_fields,
|
document_with_segments_fields,
|
||||||
)
|
)
|
||||||
from libs.login import login_required
|
from libs.login import login_required
|
||||||
from models.dataset import Dataset, DatasetProcessRule, Document, DocumentSegment
|
from models import Dataset, DatasetProcessRule, Document, DocumentSegment, UploadFile
|
||||||
from models.model import UploadFile
|
|
||||||
from services.dataset_service import DatasetService, DocumentService
|
from services.dataset_service import DatasetService, DocumentService
|
||||||
from tasks.add_document_to_index_task import add_document_to_index_task
|
from tasks.add_document_to_index_task import add_document_to_index_task
|
||||||
from tasks.remove_document_from_index_task import remove_document_from_index_task
|
from tasks.remove_document_from_index_task import remove_document_from_index_task
|
||||||
|
|
|
@ -24,7 +24,7 @@ from extensions.ext_database import db
|
||||||
from extensions.ext_redis import redis_client
|
from extensions.ext_redis import redis_client
|
||||||
from fields.segment_fields import segment_fields
|
from fields.segment_fields import segment_fields
|
||||||
from libs.login import login_required
|
from libs.login import login_required
|
||||||
from models.dataset import DocumentSegment
|
from models import DocumentSegment
|
||||||
from services.dataset_service import DatasetService, DocumentService, SegmentService
|
from services.dataset_service import DatasetService, DocumentService, SegmentService
|
||||||
from tasks.batch_create_segment_to_index_task import batch_create_segment_to_index_task
|
from tasks.batch_create_segment_to_index_task import batch_create_segment_to_index_task
|
||||||
from tasks.disable_segment_from_index_task import disable_segment_from_index_task
|
from tasks.disable_segment_from_index_task import disable_segment_from_index_task
|
||||||
|
|
|
@ -1,9 +1,12 @@
|
||||||
|
import urllib.parse
|
||||||
|
|
||||||
from flask import request
|
from flask import request
|
||||||
from flask_login import current_user
|
from flask_login import current_user
|
||||||
from flask_restful import Resource, marshal_with
|
from flask_restful import Resource, marshal_with
|
||||||
|
|
||||||
import services
|
import services
|
||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
|
from constants import DOCUMENT_EXTENSIONS
|
||||||
from controllers.console import api
|
from controllers.console import api
|
||||||
from controllers.console.datasets.error import (
|
from controllers.console.datasets.error import (
|
||||||
FileTooLargeError,
|
FileTooLargeError,
|
||||||
|
@ -13,9 +16,10 @@ from controllers.console.datasets.error import (
|
||||||
)
|
)
|
||||||
from controllers.console.setup import setup_required
|
from controllers.console.setup import setup_required
|
||||||
from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check
|
from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check
|
||||||
from fields.file_fields import file_fields, upload_config_fields
|
from core.helper import ssrf_proxy
|
||||||
|
from fields.file_fields import file_fields, remote_file_info_fields, upload_config_fields
|
||||||
from libs.login import login_required
|
from libs.login import login_required
|
||||||
from services.file_service import ALLOWED_EXTENSIONS, UNSTRUCTURED_ALLOWED_EXTENSIONS, FileService
|
from services.file_service import FileService
|
||||||
|
|
||||||
PREVIEW_WORDS_LIMIT = 3000
|
PREVIEW_WORDS_LIMIT = 3000
|
||||||
|
|
||||||
|
@ -51,7 +55,7 @@ class FileApi(Resource):
|
||||||
if len(request.files) > 1:
|
if len(request.files) > 1:
|
||||||
raise TooManyFilesError()
|
raise TooManyFilesError()
|
||||||
try:
|
try:
|
||||||
upload_file = FileService.upload_file(file, current_user)
|
upload_file = FileService.upload_file(file=file, user=current_user)
|
||||||
except services.errors.file.FileTooLargeError as file_too_large_error:
|
except services.errors.file.FileTooLargeError as file_too_large_error:
|
||||||
raise FileTooLargeError(file_too_large_error.description)
|
raise FileTooLargeError(file_too_large_error.description)
|
||||||
except services.errors.file.UnsupportedFileTypeError:
|
except services.errors.file.UnsupportedFileTypeError:
|
||||||
|
@ -75,11 +79,24 @@ class FileSupportTypeApi(Resource):
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def get(self):
|
def get(self):
|
||||||
etl_type = dify_config.ETL_TYPE
|
return {"allowed_extensions": DOCUMENT_EXTENSIONS}
|
||||||
allowed_extensions = UNSTRUCTURED_ALLOWED_EXTENSIONS if etl_type == "Unstructured" else ALLOWED_EXTENSIONS
|
|
||||||
return {"allowed_extensions": allowed_extensions}
|
|
||||||
|
class RemoteFileInfoApi(Resource):
|
||||||
|
@marshal_with(remote_file_info_fields)
|
||||||
|
def get(self, url):
|
||||||
|
decoded_url = urllib.parse.unquote(url)
|
||||||
|
try:
|
||||||
|
response = ssrf_proxy.head(decoded_url)
|
||||||
|
return {
|
||||||
|
"file_type": response.headers.get("Content-Type", "application/octet-stream"),
|
||||||
|
"file_length": int(response.headers.get("Content-Length", 0)),
|
||||||
|
}
|
||||||
|
except Exception as e:
|
||||||
|
return {"error": str(e)}, 400
|
||||||
|
|
||||||
|
|
||||||
api.add_resource(FileApi, "/files/upload")
|
api.add_resource(FileApi, "/files/upload")
|
||||||
api.add_resource(FilePreviewApi, "/files/<uuid:file_id>/preview")
|
api.add_resource(FilePreviewApi, "/files/<uuid:file_id>/preview")
|
||||||
api.add_resource(FileSupportTypeApi, "/files/support-type")
|
api.add_resource(FileSupportTypeApi, "/files/support-type")
|
||||||
|
api.add_resource(RemoteFileInfoApi, "/remote-files/<path:url>")
|
||||||
|
|
|
@ -11,7 +11,7 @@ from controllers.console.wraps import account_initialization_required, cloud_edi
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from fields.installed_app_fields import installed_app_list_fields
|
from fields.installed_app_fields import installed_app_list_fields
|
||||||
from libs.login import login_required
|
from libs.login import login_required
|
||||||
from models.model import App, InstalledApp, RecommendedApp
|
from models import App, InstalledApp, RecommendedApp
|
||||||
from services.account_service import TenantService
|
from services.account_service import TenantService
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -18,7 +18,7 @@ message_fields = {
|
||||||
"inputs": fields.Raw,
|
"inputs": fields.Raw,
|
||||||
"query": fields.String,
|
"query": fields.String,
|
||||||
"answer": fields.String,
|
"answer": fields.String,
|
||||||
"message_files": fields.List(fields.Nested(message_file_fields), attribute="files"),
|
"message_files": fields.List(fields.Nested(message_file_fields)),
|
||||||
"feedback": fields.Nested(feedback_fields, attribute="user_feedback", allow_null=True),
|
"feedback": fields.Nested(feedback_fields, attribute="user_feedback", allow_null=True),
|
||||||
"created_at": TimestampField,
|
"created_at": TimestampField,
|
||||||
}
|
}
|
||||||
|
|
|
@ -7,7 +7,7 @@ from werkzeug.exceptions import NotFound
|
||||||
from controllers.console.wraps import account_initialization_required
|
from controllers.console.wraps import account_initialization_required
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from libs.login import login_required
|
from libs.login import login_required
|
||||||
from models.model import InstalledApp
|
from models import InstalledApp
|
||||||
|
|
||||||
|
|
||||||
def installed_app_required(view=None):
|
def installed_app_required(view=None):
|
||||||
|
|
|
@ -20,7 +20,7 @@ from extensions.ext_database import db
|
||||||
from fields.member_fields import account_fields
|
from fields.member_fields import account_fields
|
||||||
from libs.helper import TimestampField, timezone
|
from libs.helper import TimestampField, timezone
|
||||||
from libs.login import login_required
|
from libs.login import login_required
|
||||||
from models.account import AccountIntegrate, InvitationCode
|
from models import AccountIntegrate, InvitationCode
|
||||||
from services.account_service import AccountService
|
from services.account_service import AccountService
|
||||||
from services.errors.account import CurrentPasswordIncorrectError as ServiceCurrentPasswordIncorrectError
|
from services.errors.account import CurrentPasswordIncorrectError as ServiceCurrentPasswordIncorrectError
|
||||||
|
|
||||||
|
|
|
@ -360,16 +360,15 @@ class ToolWorkflowProviderCreateApi(Resource):
|
||||||
args = reqparser.parse_args()
|
args = reqparser.parse_args()
|
||||||
|
|
||||||
return WorkflowToolManageService.create_workflow_tool(
|
return WorkflowToolManageService.create_workflow_tool(
|
||||||
user_id,
|
user_id=user_id,
|
||||||
tenant_id,
|
tenant_id=tenant_id,
|
||||||
args["workflow_app_id"],
|
workflow_app_id=args["workflow_app_id"],
|
||||||
args["name"],
|
name=args["name"],
|
||||||
args["label"],
|
label=args["label"],
|
||||||
args["icon"],
|
icon=args["icon"],
|
||||||
args["description"],
|
description=args["description"],
|
||||||
args["parameters"],
|
parameters=args["parameters"],
|
||||||
args["privacy_policy"],
|
privacy_policy=args["privacy_policy"],
|
||||||
args.get("labels", []),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -198,7 +198,7 @@ class WebappLogoWorkspaceApi(Resource):
|
||||||
raise UnsupportedFileTypeError()
|
raise UnsupportedFileTypeError()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
upload_file = FileService.upload_file(file, current_user, True)
|
upload_file = FileService.upload_file(file=file, user=current_user)
|
||||||
|
|
||||||
except services.errors.file.FileTooLargeError as file_too_large_error:
|
except services.errors.file.FileTooLargeError as file_too_large_error:
|
||||||
raise FileTooLargeError(file_too_large_error.description)
|
raise FileTooLargeError(file_too_large_error.description)
|
||||||
|
|
|
@ -21,7 +21,36 @@ class ImagePreviewApi(Resource):
|
||||||
return {"content": "Invalid request."}, 400
|
return {"content": "Invalid request."}, 400
|
||||||
|
|
||||||
try:
|
try:
|
||||||
generator, mimetype = FileService.get_image_preview(file_id, timestamp, nonce, sign)
|
generator, mimetype = FileService.get_image_preview(
|
||||||
|
file_id=file_id,
|
||||||
|
timestamp=timestamp,
|
||||||
|
nonce=nonce,
|
||||||
|
sign=sign,
|
||||||
|
)
|
||||||
|
except services.errors.file.UnsupportedFileTypeError:
|
||||||
|
raise UnsupportedFileTypeError()
|
||||||
|
|
||||||
|
return Response(generator, mimetype=mimetype)
|
||||||
|
|
||||||
|
|
||||||
|
class FilePreviewApi(Resource):
|
||||||
|
def get(self, file_id):
|
||||||
|
file_id = str(file_id)
|
||||||
|
|
||||||
|
timestamp = request.args.get("timestamp")
|
||||||
|
nonce = request.args.get("nonce")
|
||||||
|
sign = request.args.get("sign")
|
||||||
|
|
||||||
|
if not timestamp or not nonce or not sign:
|
||||||
|
return {"content": "Invalid request."}, 400
|
||||||
|
|
||||||
|
try:
|
||||||
|
generator, mimetype = FileService.get_signed_file_preview(
|
||||||
|
file_id=file_id,
|
||||||
|
timestamp=timestamp,
|
||||||
|
nonce=nonce,
|
||||||
|
sign=sign,
|
||||||
|
)
|
||||||
except services.errors.file.UnsupportedFileTypeError:
|
except services.errors.file.UnsupportedFileTypeError:
|
||||||
raise UnsupportedFileTypeError()
|
raise UnsupportedFileTypeError()
|
||||||
|
|
||||||
|
@ -49,4 +78,5 @@ class WorkspaceWebappLogoApi(Resource):
|
||||||
|
|
||||||
|
|
||||||
api.add_resource(ImagePreviewApi, "/files/<uuid:file_id>/image-preview")
|
api.add_resource(ImagePreviewApi, "/files/<uuid:file_id>/image-preview")
|
||||||
|
api.add_resource(FilePreviewApi, "/files/<uuid:file_id>/file-preview")
|
||||||
api.add_resource(WorkspaceWebappLogoApi, "/files/workspaces/<uuid:workspace_id>/webapp-logo")
|
api.add_resource(WorkspaceWebappLogoApi, "/files/workspaces/<uuid:workspace_id>/webapp-logo")
|
||||||
|
|
|
@ -48,7 +48,7 @@ class MessageListApi(Resource):
|
||||||
"tool_input": fields.String,
|
"tool_input": fields.String,
|
||||||
"created_at": TimestampField,
|
"created_at": TimestampField,
|
||||||
"observation": fields.String,
|
"observation": fields.String,
|
||||||
"message_files": fields.List(fields.String, attribute="files"),
|
"message_files": fields.List(fields.String),
|
||||||
}
|
}
|
||||||
|
|
||||||
message_fields = {
|
message_fields = {
|
||||||
|
@ -58,7 +58,7 @@ class MessageListApi(Resource):
|
||||||
"inputs": fields.Raw,
|
"inputs": fields.Raw,
|
||||||
"query": fields.String,
|
"query": fields.String,
|
||||||
"answer": fields.String(attribute="re_sign_file_url_answer"),
|
"answer": fields.String(attribute="re_sign_file_url_answer"),
|
||||||
"message_files": fields.List(fields.Nested(message_file_fields), attribute="files"),
|
"message_files": fields.List(fields.Nested(message_file_fields)),
|
||||||
"feedback": fields.Nested(feedback_fields, attribute="user_feedback", allow_null=True),
|
"feedback": fields.Nested(feedback_fields, attribute="user_feedback", allow_null=True),
|
||||||
"retriever_resources": fields.List(fields.Nested(retriever_resource_fields)),
|
"retriever_resources": fields.List(fields.Nested(retriever_resource_fields)),
|
||||||
"created_at": TimestampField,
|
"created_at": TimestampField,
|
||||||
|
|
|
@ -1,3 +1,5 @@
|
||||||
|
import urllib.parse
|
||||||
|
|
||||||
from flask import request
|
from flask import request
|
||||||
from flask_restful import marshal_with
|
from flask_restful import marshal_with
|
||||||
|
|
||||||
|
@ -5,7 +7,8 @@ import services
|
||||||
from controllers.web import api
|
from controllers.web import api
|
||||||
from controllers.web.error import FileTooLargeError, NoFileUploadedError, TooManyFilesError, UnsupportedFileTypeError
|
from controllers.web.error import FileTooLargeError, NoFileUploadedError, TooManyFilesError, UnsupportedFileTypeError
|
||||||
from controllers.web.wraps import WebApiResource
|
from controllers.web.wraps import WebApiResource
|
||||||
from fields.file_fields import file_fields
|
from core.helper import ssrf_proxy
|
||||||
|
from fields.file_fields import file_fields, remote_file_info_fields
|
||||||
from services.file_service import FileService
|
from services.file_service import FileService
|
||||||
|
|
||||||
|
|
||||||
|
@ -31,4 +34,19 @@ class FileApi(WebApiResource):
|
||||||
return upload_file, 201
|
return upload_file, 201
|
||||||
|
|
||||||
|
|
||||||
|
class RemoteFileInfoApi(WebApiResource):
|
||||||
|
@marshal_with(remote_file_info_fields)
|
||||||
|
def get(self, url):
|
||||||
|
decoded_url = urllib.parse.unquote(url)
|
||||||
|
try:
|
||||||
|
response = ssrf_proxy.head(decoded_url)
|
||||||
|
return {
|
||||||
|
"file_type": response.headers.get("Content-Type", "application/octet-stream"),
|
||||||
|
"file_length": int(response.headers.get("Content-Length", 0)),
|
||||||
|
}
|
||||||
|
except Exception as e:
|
||||||
|
return {"error": str(e)}, 400
|
||||||
|
|
||||||
|
|
||||||
api.add_resource(FileApi, "/files/upload")
|
api.add_resource(FileApi, "/files/upload")
|
||||||
|
api.add_resource(RemoteFileInfoApi, "/remote-files/<path:url>")
|
||||||
|
|
|
@ -22,6 +22,7 @@ from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotIni
|
||||||
from core.model_runtime.errors.invoke import InvokeError
|
from core.model_runtime.errors.invoke import InvokeError
|
||||||
from fields.conversation_fields import message_file_fields
|
from fields.conversation_fields import message_file_fields
|
||||||
from fields.message_fields import agent_thought_fields
|
from fields.message_fields import agent_thought_fields
|
||||||
|
from fields.raws import FilesContainedField
|
||||||
from libs import helper
|
from libs import helper
|
||||||
from libs.helper import TimestampField, uuid_value
|
from libs.helper import TimestampField, uuid_value
|
||||||
from models.model import AppMode
|
from models.model import AppMode
|
||||||
|
@ -58,10 +59,10 @@ class MessageListApi(WebApiResource):
|
||||||
"id": fields.String,
|
"id": fields.String,
|
||||||
"conversation_id": fields.String,
|
"conversation_id": fields.String,
|
||||||
"parent_message_id": fields.String,
|
"parent_message_id": fields.String,
|
||||||
"inputs": fields.Raw,
|
"inputs": FilesContainedField,
|
||||||
"query": fields.String,
|
"query": fields.String,
|
||||||
"answer": fields.String(attribute="re_sign_file_url_answer"),
|
"answer": fields.String(attribute="re_sign_file_url_answer"),
|
||||||
"message_files": fields.List(fields.Nested(message_file_fields), attribute="files"),
|
"message_files": fields.List(fields.Nested(message_file_fields)),
|
||||||
"feedback": fields.Nested(feedback_fields, attribute="user_feedback", allow_null=True),
|
"feedback": fields.Nested(feedback_fields, attribute="user_feedback", allow_null=True),
|
||||||
"retriever_resources": fields.List(fields.Nested(retriever_resource_fields)),
|
"retriever_resources": fields.List(fields.Nested(retriever_resource_fields)),
|
||||||
"created_at": TimestampField,
|
"created_at": TimestampField,
|
||||||
|
|
|
@ -17,7 +17,7 @@ message_fields = {
|
||||||
"inputs": fields.Raw,
|
"inputs": fields.Raw,
|
||||||
"query": fields.String,
|
"query": fields.String,
|
||||||
"answer": fields.String,
|
"answer": fields.String,
|
||||||
"message_files": fields.List(fields.Nested(message_file_fields), attribute="files"),
|
"message_files": fields.List(fields.Nested(message_file_fields)),
|
||||||
"feedback": fields.Nested(feedback_fields, attribute="user_feedback", allow_null=True),
|
"feedback": fields.Nested(feedback_fields, attribute="user_feedback", allow_null=True),
|
||||||
"created_at": TimestampField,
|
"created_at": TimestampField,
|
||||||
}
|
}
|
||||||
|
|
|
@ -16,13 +16,14 @@ from core.app.entities.app_invoke_entities import (
|
||||||
)
|
)
|
||||||
from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler
|
from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler
|
||||||
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
|
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
|
||||||
from core.file.message_file_parser import MessageFileParser
|
from core.file import file_manager
|
||||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||||
from core.model_manager import ModelInstance
|
from core.model_manager import ModelInstance
|
||||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
from core.model_runtime.entities import (
|
||||||
from core.model_runtime.entities.message_entities import (
|
|
||||||
AssistantPromptMessage,
|
AssistantPromptMessage,
|
||||||
|
LLMUsage,
|
||||||
PromptMessage,
|
PromptMessage,
|
||||||
|
PromptMessageContent,
|
||||||
PromptMessageTool,
|
PromptMessageTool,
|
||||||
SystemPromptMessage,
|
SystemPromptMessage,
|
||||||
TextPromptMessageContent,
|
TextPromptMessageContent,
|
||||||
|
@ -40,9 +41,9 @@ from core.tools.entities.tool_entities import (
|
||||||
from core.tools.tool.dataset_retriever_tool import DatasetRetrieverTool
|
from core.tools.tool.dataset_retriever_tool import DatasetRetrieverTool
|
||||||
from core.tools.tool.tool import Tool
|
from core.tools.tool.tool import Tool
|
||||||
from core.tools.tool_manager import ToolManager
|
from core.tools.tool_manager import ToolManager
|
||||||
from core.tools.utils.tool_parameter_converter import ToolParameterConverter
|
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from models.model import Conversation, Message, MessageAgentThought
|
from factories import file_factory
|
||||||
|
from models.model import Conversation, Message, MessageAgentThought, MessageFile
|
||||||
from models.tools import ToolConversationVariables
|
from models.tools import ToolConversationVariables
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
@ -66,23 +67,6 @@ class BaseAgentRunner(AppRunner):
|
||||||
db_variables: Optional[ToolConversationVariables] = None,
|
db_variables: Optional[ToolConversationVariables] = None,
|
||||||
model_instance: ModelInstance = None,
|
model_instance: ModelInstance = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
|
||||||
Agent runner
|
|
||||||
:param tenant_id: tenant id
|
|
||||||
:param application_generate_entity: application generate entity
|
|
||||||
:param conversation: conversation
|
|
||||||
:param app_config: app generate entity
|
|
||||||
:param model_config: model config
|
|
||||||
:param config: dataset config
|
|
||||||
:param queue_manager: queue manager
|
|
||||||
:param message: message
|
|
||||||
:param user_id: user id
|
|
||||||
:param memory: memory
|
|
||||||
:param prompt_messages: prompt messages
|
|
||||||
:param variables_pool: variables pool
|
|
||||||
:param db_variables: db variables
|
|
||||||
:param model_instance: model instance
|
|
||||||
"""
|
|
||||||
self.tenant_id = tenant_id
|
self.tenant_id = tenant_id
|
||||||
self.application_generate_entity = application_generate_entity
|
self.application_generate_entity = application_generate_entity
|
||||||
self.conversation = conversation
|
self.conversation = conversation
|
||||||
|
@ -180,7 +164,7 @@ class BaseAgentRunner(AppRunner):
|
||||||
if parameter.form != ToolParameter.ToolParameterForm.LLM:
|
if parameter.form != ToolParameter.ToolParameterForm.LLM:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
parameter_type = ToolParameterConverter.get_parameter_type(parameter.type)
|
parameter_type = parameter.type.as_normal_type()
|
||||||
enum = []
|
enum = []
|
||||||
if parameter.type == ToolParameter.ToolParameterType.SELECT:
|
if parameter.type == ToolParameter.ToolParameterType.SELECT:
|
||||||
enum = [option.value for option in parameter.options]
|
enum = [option.value for option in parameter.options]
|
||||||
|
@ -265,7 +249,7 @@ class BaseAgentRunner(AppRunner):
|
||||||
if parameter.form != ToolParameter.ToolParameterForm.LLM:
|
if parameter.form != ToolParameter.ToolParameterForm.LLM:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
parameter_type = ToolParameterConverter.get_parameter_type(parameter.type)
|
parameter_type = parameter.type.as_normal_type()
|
||||||
enum = []
|
enum = []
|
||||||
if parameter.type == ToolParameter.ToolParameterType.SELECT:
|
if parameter.type == ToolParameter.ToolParameterType.SELECT:
|
||||||
enum = [option.value for option in parameter.options]
|
enum = [option.value for option in parameter.options]
|
||||||
|
@ -511,26 +495,24 @@ class BaseAgentRunner(AppRunner):
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def organize_agent_user_prompt(self, message: Message) -> UserPromptMessage:
|
def organize_agent_user_prompt(self, message: Message) -> UserPromptMessage:
|
||||||
message_file_parser = MessageFileParser(
|
files = db.session.query(MessageFile).filter(MessageFile.message_id == message.id).all()
|
||||||
tenant_id=self.tenant_id,
|
|
||||||
app_id=self.app_config.app_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
files = message.message_files
|
|
||||||
if files:
|
if files:
|
||||||
file_extra_config = FileUploadConfigManager.convert(message.app_model_config.to_dict())
|
file_extra_config = FileUploadConfigManager.convert(message.app_model_config.to_dict())
|
||||||
|
|
||||||
if file_extra_config:
|
if file_extra_config:
|
||||||
file_objs = message_file_parser.transform_message_files(files, file_extra_config)
|
file_objs = file_factory.build_from_message_files(
|
||||||
|
message_files=files, tenant_id=self.tenant_id, config=file_extra_config
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
file_objs = []
|
file_objs = []
|
||||||
|
|
||||||
if not file_objs:
|
if not file_objs:
|
||||||
return UserPromptMessage(content=message.query)
|
return UserPromptMessage(content=message.query)
|
||||||
else:
|
else:
|
||||||
prompt_message_contents = [TextPromptMessageContent(data=message.query)]
|
prompt_message_contents: list[PromptMessageContent] = []
|
||||||
|
prompt_message_contents.append(TextPromptMessageContent(data=message.query))
|
||||||
for file_obj in file_objs:
|
for file_obj in file_objs:
|
||||||
prompt_message_contents.append(file_obj.prompt_message_content)
|
prompt_message_contents.append(file_manager.to_prompt_message_content(file_obj))
|
||||||
|
|
||||||
return UserPromptMessage(content=prompt_message_contents)
|
return UserPromptMessage(content=prompt_message_contents)
|
||||||
else:
|
else:
|
||||||
|
|
|
@ -1,9 +1,11 @@
|
||||||
import json
|
import json
|
||||||
|
|
||||||
from core.agent.cot_agent_runner import CotAgentRunner
|
from core.agent.cot_agent_runner import CotAgentRunner
|
||||||
from core.model_runtime.entities.message_entities import (
|
from core.file import file_manager
|
||||||
|
from core.model_runtime.entities import (
|
||||||
AssistantPromptMessage,
|
AssistantPromptMessage,
|
||||||
PromptMessage,
|
PromptMessage,
|
||||||
|
PromptMessageContent,
|
||||||
SystemPromptMessage,
|
SystemPromptMessage,
|
||||||
TextPromptMessageContent,
|
TextPromptMessageContent,
|
||||||
UserPromptMessage,
|
UserPromptMessage,
|
||||||
|
@ -32,9 +34,10 @@ class CotChatAgentRunner(CotAgentRunner):
|
||||||
Organize user query
|
Organize user query
|
||||||
"""
|
"""
|
||||||
if self.files:
|
if self.files:
|
||||||
prompt_message_contents = [TextPromptMessageContent(data=query)]
|
prompt_message_contents: list[PromptMessageContent] = []
|
||||||
|
prompt_message_contents.append(TextPromptMessageContent(data=query))
|
||||||
for file_obj in self.files:
|
for file_obj in self.files:
|
||||||
prompt_message_contents.append(file_obj.prompt_message_content)
|
prompt_message_contents.append(file_manager.to_prompt_message_content(file_obj))
|
||||||
|
|
||||||
prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
|
prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
|
||||||
else:
|
else:
|
||||||
|
|
|
@ -7,10 +7,15 @@ from typing import Any, Optional, Union
|
||||||
from core.agent.base_agent_runner import BaseAgentRunner
|
from core.agent.base_agent_runner import BaseAgentRunner
|
||||||
from core.app.apps.base_app_queue_manager import PublishFrom
|
from core.app.apps.base_app_queue_manager import PublishFrom
|
||||||
from core.app.entities.queue_entities import QueueAgentThoughtEvent, QueueMessageEndEvent, QueueMessageFileEvent
|
from core.app.entities.queue_entities import QueueAgentThoughtEvent, QueueMessageEndEvent, QueueMessageFileEvent
|
||||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
|
from core.file import file_manager
|
||||||
from core.model_runtime.entities.message_entities import (
|
from core.model_runtime.entities import (
|
||||||
AssistantPromptMessage,
|
AssistantPromptMessage,
|
||||||
|
LLMResult,
|
||||||
|
LLMResultChunk,
|
||||||
|
LLMResultChunkDelta,
|
||||||
|
LLMUsage,
|
||||||
PromptMessage,
|
PromptMessage,
|
||||||
|
PromptMessageContent,
|
||||||
PromptMessageContentType,
|
PromptMessageContentType,
|
||||||
SystemPromptMessage,
|
SystemPromptMessage,
|
||||||
TextPromptMessageContent,
|
TextPromptMessageContent,
|
||||||
|
@ -390,9 +395,10 @@ class FunctionCallAgentRunner(BaseAgentRunner):
|
||||||
Organize user query
|
Organize user query
|
||||||
"""
|
"""
|
||||||
if self.files:
|
if self.files:
|
||||||
prompt_message_contents = [TextPromptMessageContent(data=query)]
|
prompt_message_contents: list[PromptMessageContent] = []
|
||||||
|
prompt_message_contents.append(TextPromptMessageContent(data=query))
|
||||||
for file_obj in self.files:
|
for file_obj in self.files:
|
||||||
prompt_message_contents.append(file_obj.prompt_message_content)
|
prompt_message_contents.append(file_manager.to_prompt_message_content(file_obj))
|
||||||
|
|
||||||
prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
|
prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
|
||||||
else:
|
else:
|
||||||
|
|
|
@ -53,12 +53,11 @@ class BasicVariablesConfigManager:
|
||||||
VariableEntity(
|
VariableEntity(
|
||||||
type=variable_type,
|
type=variable_type,
|
||||||
variable=variable.get("variable"),
|
variable=variable.get("variable"),
|
||||||
description=variable.get("description"),
|
description=variable.get("description", ""),
|
||||||
label=variable.get("label"),
|
label=variable.get("label"),
|
||||||
required=variable.get("required", False),
|
required=variable.get("required", False),
|
||||||
max_length=variable.get("max_length"),
|
max_length=variable.get("max_length"),
|
||||||
options=variable.get("options"),
|
options=variable.get("options", []),
|
||||||
default=variable.get("default"),
|
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -1,11 +1,12 @@
|
||||||
|
from collections.abc import Sequence
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Any, Optional
|
from typing import Any, Optional
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from core.file.file_obj import FileExtraConfig
|
from core.file import FileExtraConfig, FileTransferMethod, FileType
|
||||||
from core.model_runtime.entities.message_entities import PromptMessageRole
|
from core.model_runtime.entities.message_entities import PromptMessageRole
|
||||||
from models import AppMode
|
from models.model import AppMode
|
||||||
|
|
||||||
|
|
||||||
class ModelConfigEntity(BaseModel):
|
class ModelConfigEntity(BaseModel):
|
||||||
|
@ -69,7 +70,7 @@ class PromptTemplateEntity(BaseModel):
|
||||||
ADVANCED = "advanced"
|
ADVANCED = "advanced"
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def value_of(cls, value: str) -> "PromptType":
|
def value_of(cls, value: str):
|
||||||
"""
|
"""
|
||||||
Get value of given mode.
|
Get value of given mode.
|
||||||
|
|
||||||
|
@ -93,6 +94,8 @@ class VariableEntityType(str, Enum):
|
||||||
PARAGRAPH = "paragraph"
|
PARAGRAPH = "paragraph"
|
||||||
NUMBER = "number"
|
NUMBER = "number"
|
||||||
EXTERNAL_DATA_TOOL = "external_data_tool"
|
EXTERNAL_DATA_TOOL = "external_data_tool"
|
||||||
|
FILE = "file"
|
||||||
|
FILE_LIST = "file-list"
|
||||||
|
|
||||||
|
|
||||||
class VariableEntity(BaseModel):
|
class VariableEntity(BaseModel):
|
||||||
|
@ -102,13 +105,14 @@ class VariableEntity(BaseModel):
|
||||||
|
|
||||||
variable: str
|
variable: str
|
||||||
label: str
|
label: str
|
||||||
description: Optional[str] = None
|
description: str = ""
|
||||||
type: VariableEntityType
|
type: VariableEntityType
|
||||||
required: bool = False
|
required: bool = False
|
||||||
max_length: Optional[int] = None
|
max_length: Optional[int] = None
|
||||||
options: Optional[list[str]] = None
|
options: Sequence[str] = Field(default_factory=list)
|
||||||
default: Optional[str] = None
|
allowed_file_types: Sequence[FileType] = Field(default_factory=list)
|
||||||
hint: Optional[str] = None
|
allowed_file_extensions: Sequence[str] = Field(default_factory=list)
|
||||||
|
allowed_file_upload_methods: Sequence[FileTransferMethod] = Field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
class ExternalDataVariableEntity(BaseModel):
|
class ExternalDataVariableEntity(BaseModel):
|
||||||
|
@ -136,7 +140,7 @@ class DatasetRetrieveConfigEntity(BaseModel):
|
||||||
MULTIPLE = "multiple"
|
MULTIPLE = "multiple"
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def value_of(cls, value: str) -> "RetrieveStrategy":
|
def value_of(cls, value: str):
|
||||||
"""
|
"""
|
||||||
Get value of given mode.
|
Get value of given mode.
|
||||||
|
|
||||||
|
|
|
@ -1,12 +1,13 @@
|
||||||
from collections.abc import Mapping
|
from collections.abc import Mapping
|
||||||
from typing import Any, Optional
|
from typing import Any
|
||||||
|
|
||||||
from core.file.file_obj import FileExtraConfig
|
from core.file.models import FileExtraConfig
|
||||||
|
from models import FileUploadConfig
|
||||||
|
|
||||||
|
|
||||||
class FileUploadConfigManager:
|
class FileUploadConfigManager:
|
||||||
@classmethod
|
@classmethod
|
||||||
def convert(cls, config: Mapping[str, Any], is_vision: bool = True) -> Optional[FileExtraConfig]:
|
def convert(cls, config: Mapping[str, Any], is_vision: bool = True):
|
||||||
"""
|
"""
|
||||||
Convert model config to model config
|
Convert model config to model config
|
||||||
|
|
||||||
|
@ -15,19 +16,18 @@ class FileUploadConfigManager:
|
||||||
"""
|
"""
|
||||||
file_upload_dict = config.get("file_upload")
|
file_upload_dict = config.get("file_upload")
|
||||||
if file_upload_dict:
|
if file_upload_dict:
|
||||||
if file_upload_dict.get("image"):
|
if file_upload_dict.get("enabled"):
|
||||||
if "enabled" in file_upload_dict["image"] and file_upload_dict["image"]["enabled"]:
|
data = {
|
||||||
image_config = {
|
"image_config": {
|
||||||
"number_limits": file_upload_dict["image"]["number_limits"],
|
"number_limits": file_upload_dict["number_limits"],
|
||||||
"transfer_methods": file_upload_dict["image"]["transfer_methods"],
|
"transfer_methods": file_upload_dict["allowed_file_upload_methods"],
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if is_vision:
|
if is_vision:
|
||||||
image_config["detail"] = file_upload_dict["image"]["detail"]
|
data["image_config"]["detail"] = file_upload_dict.get("image", {}).get("detail", "low")
|
||||||
|
|
||||||
return FileExtraConfig(image_config=image_config)
|
return FileExtraConfig.model_validate(data)
|
||||||
|
|
||||||
return None
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def validate_and_set_defaults(cls, config: dict, is_vision: bool = True) -> tuple[dict, list[str]]:
|
def validate_and_set_defaults(cls, config: dict, is_vision: bool = True) -> tuple[dict, list[str]]:
|
||||||
|
@ -39,29 +39,7 @@ class FileUploadConfigManager:
|
||||||
"""
|
"""
|
||||||
if not config.get("file_upload"):
|
if not config.get("file_upload"):
|
||||||
config["file_upload"] = {}
|
config["file_upload"] = {}
|
||||||
|
else:
|
||||||
if not isinstance(config["file_upload"], dict):
|
FileUploadConfig.model_validate(config["file_upload"])
|
||||||
raise ValueError("file_upload must be of dict type")
|
|
||||||
|
|
||||||
# check image config
|
|
||||||
if not config["file_upload"].get("image"):
|
|
||||||
config["file_upload"]["image"] = {"enabled": False}
|
|
||||||
|
|
||||||
if config["file_upload"]["image"]["enabled"]:
|
|
||||||
number_limits = config["file_upload"]["image"]["number_limits"]
|
|
||||||
if number_limits < 1 or number_limits > 6:
|
|
||||||
raise ValueError("number_limits must be in [1, 6]")
|
|
||||||
|
|
||||||
if is_vision:
|
|
||||||
detail = config["file_upload"]["image"]["detail"]
|
|
||||||
if detail not in {"high", "low"}:
|
|
||||||
raise ValueError("detail must be in ['high', 'low']")
|
|
||||||
|
|
||||||
transfer_methods = config["file_upload"]["image"]["transfer_methods"]
|
|
||||||
if not isinstance(transfer_methods, list):
|
|
||||||
raise ValueError("transfer_methods must be of list type")
|
|
||||||
for method in transfer_methods:
|
|
||||||
if method not in {"remote_url", "local_file"}:
|
|
||||||
raise ValueError("transfer_methods must be in ['remote_url', 'local_file']")
|
|
||||||
|
|
||||||
return config, ["file_upload"]
|
return config, ["file_upload"]
|
||||||
|
|
|
@ -17,6 +17,6 @@ class WorkflowVariablesConfigManager:
|
||||||
|
|
||||||
# variables
|
# variables
|
||||||
for variable in user_input_form:
|
for variable in user_input_form:
|
||||||
variables.append(VariableEntity(**variable))
|
variables.append(VariableEntity.model_validate(variable))
|
||||||
|
|
||||||
return variables
|
return variables
|
||||||
|
|
|
@ -20,10 +20,11 @@ from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
|
||||||
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
|
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
|
||||||
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom
|
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom
|
||||||
from core.app.entities.task_entities import ChatbotAppBlockingResponse, ChatbotAppStreamResponse
|
from core.app.entities.task_entities import ChatbotAppBlockingResponse, ChatbotAppStreamResponse
|
||||||
from core.file.message_file_parser import MessageFileParser
|
|
||||||
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
|
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
|
||||||
from core.ops.ops_trace_manager import TraceQueueManager
|
from core.ops.ops_trace_manager import TraceQueueManager
|
||||||
|
from enums import CreatedByRole
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
|
from factories import file_factory
|
||||||
from models.account import Account
|
from models.account import Account
|
||||||
from models.model import App, Conversation, EndUser, Message
|
from models.model import App, Conversation, EndUser, Message
|
||||||
from models.workflow import Workflow
|
from models.workflow import Workflow
|
||||||
|
@ -95,10 +96,16 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||||
|
|
||||||
# parse files
|
# parse files
|
||||||
files = args["files"] if args.get("files") else []
|
files = args["files"] if args.get("files") else []
|
||||||
message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id)
|
|
||||||
file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False)
|
file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False)
|
||||||
|
role = CreatedByRole.ACCOUNT if isinstance(user, Account) else CreatedByRole.END_USER
|
||||||
if file_extra_config:
|
if file_extra_config:
|
||||||
file_objs = message_file_parser.validate_and_transform_files_arg(files, file_extra_config, user)
|
file_objs = file_factory.build_from_mappings(
|
||||||
|
mappings=files,
|
||||||
|
tenant_id=app_model.tenant_id,
|
||||||
|
user_id=user.id,
|
||||||
|
role=role,
|
||||||
|
config=file_extra_config,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
file_objs = []
|
file_objs = []
|
||||||
|
|
||||||
|
@ -106,8 +113,9 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||||
app_config = AdvancedChatAppConfigManager.get_app_config(app_model=app_model, workflow=workflow)
|
app_config = AdvancedChatAppConfigManager.get_app_config(app_model=app_model, workflow=workflow)
|
||||||
|
|
||||||
# get tracing instance
|
# get tracing instance
|
||||||
user_id = user.id if isinstance(user, Account) else user.session_id
|
trace_manager = TraceQueueManager(
|
||||||
trace_manager = TraceQueueManager(app_model.id, user_id)
|
app_id=app_model.id, user_id=user.id if isinstance(user, Account) else user.session_id
|
||||||
|
)
|
||||||
|
|
||||||
if invoke_from == InvokeFrom.DEBUGGER:
|
if invoke_from == InvokeFrom.DEBUGGER:
|
||||||
# always enable retriever resource in debugger mode
|
# always enable retriever resource in debugger mode
|
||||||
|
@ -119,7 +127,9 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||||
task_id=str(uuid.uuid4()),
|
task_id=str(uuid.uuid4()),
|
||||||
app_config=app_config,
|
app_config=app_config,
|
||||||
conversation_id=conversation.id if conversation else None,
|
conversation_id=conversation.id if conversation else None,
|
||||||
inputs=conversation.inputs if conversation else self._get_cleaned_inputs(inputs, app_config),
|
inputs=conversation.inputs
|
||||||
|
if conversation
|
||||||
|
else self._prepare_user_inputs(user_inputs=inputs, app_config=app_config, user_id=user.id, role=role),
|
||||||
query=query,
|
query=query,
|
||||||
files=file_objs,
|
files=file_objs,
|
||||||
parent_message_id=args.get("parent_message_id"),
|
parent_message_id=args.get("parent_message_id"),
|
||||||
|
|
|
@ -1,30 +1,26 @@
|
||||||
import logging
|
import logging
|
||||||
import os
|
|
||||||
from collections.abc import Mapping
|
from collections.abc import Mapping
|
||||||
from typing import Any, cast
|
from typing import Any, cast
|
||||||
|
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from configs import dify_config
|
||||||
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfig
|
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfig
|
||||||
from core.app.apps.base_app_queue_manager import AppQueueManager
|
from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||||
from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner
|
from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner
|
||||||
from core.app.apps.workflow_logging_callback import WorkflowLoggingCallback
|
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom
|
||||||
from core.app.entities.app_invoke_entities import (
|
|
||||||
AdvancedChatAppGenerateEntity,
|
|
||||||
InvokeFrom,
|
|
||||||
)
|
|
||||||
from core.app.entities.queue_entities import (
|
from core.app.entities.queue_entities import (
|
||||||
QueueAnnotationReplyEvent,
|
QueueAnnotationReplyEvent,
|
||||||
QueueStopEvent,
|
QueueStopEvent,
|
||||||
QueueTextChunkEvent,
|
QueueTextChunkEvent,
|
||||||
)
|
)
|
||||||
from core.moderation.base import ModerationError
|
from core.moderation.base import ModerationError
|
||||||
from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
|
from core.workflow.callbacks import WorkflowCallback, WorkflowLoggingCallback
|
||||||
from core.workflow.entities.node_entities import UserFrom
|
|
||||||
from core.workflow.entities.variable_pool import VariablePool
|
from core.workflow.entities.variable_pool import VariablePool
|
||||||
from core.workflow.enums import SystemVariableKey
|
from core.workflow.enums import SystemVariableKey
|
||||||
from core.workflow.workflow_entry import WorkflowEntry
|
from core.workflow.workflow_entry import WorkflowEntry
|
||||||
|
from enums import UserFrom
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from models.model import App, Conversation, EndUser, Message
|
from models.model import App, Conversation, EndUser, Message
|
||||||
from models.workflow import ConversationVariable, WorkflowType
|
from models.workflow import ConversationVariable, WorkflowType
|
||||||
|
@ -44,12 +40,6 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
|
||||||
conversation: Conversation,
|
conversation: Conversation,
|
||||||
message: Message,
|
message: Message,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
|
||||||
:param application_generate_entity: application generate entity
|
|
||||||
:param queue_manager: application queue manager
|
|
||||||
:param conversation: conversation
|
|
||||||
:param message: message
|
|
||||||
"""
|
|
||||||
super().__init__(queue_manager)
|
super().__init__(queue_manager)
|
||||||
|
|
||||||
self.application_generate_entity = application_generate_entity
|
self.application_generate_entity = application_generate_entity
|
||||||
|
@ -57,10 +47,6 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
|
||||||
self.message = message
|
self.message = message
|
||||||
|
|
||||||
def run(self) -> None:
|
def run(self) -> None:
|
||||||
"""
|
|
||||||
Run application
|
|
||||||
:return:
|
|
||||||
"""
|
|
||||||
app_config = self.application_generate_entity.app_config
|
app_config = self.application_generate_entity.app_config
|
||||||
app_config = cast(AdvancedChatAppConfig, app_config)
|
app_config = cast(AdvancedChatAppConfig, app_config)
|
||||||
|
|
||||||
|
@ -81,7 +67,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
|
||||||
user_id = self.application_generate_entity.user_id
|
user_id = self.application_generate_entity.user_id
|
||||||
|
|
||||||
workflow_callbacks: list[WorkflowCallback] = []
|
workflow_callbacks: list[WorkflowCallback] = []
|
||||||
if bool(os.environ.get("DEBUG", "False").lower() == "true"):
|
if dify_config.DEBUG:
|
||||||
workflow_callbacks.append(WorkflowLoggingCallback())
|
workflow_callbacks.append(WorkflowLoggingCallback())
|
||||||
|
|
||||||
if self.application_generate_entity.single_iteration_run:
|
if self.application_generate_entity.single_iteration_run:
|
||||||
|
@ -201,15 +187,6 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
|
||||||
query: str,
|
query: str,
|
||||||
message_id: str,
|
message_id: str,
|
||||||
) -> bool:
|
) -> bool:
|
||||||
"""
|
|
||||||
Handle input moderation
|
|
||||||
:param app_record: app record
|
|
||||||
:param app_generate_entity: application generate entity
|
|
||||||
:param inputs: inputs
|
|
||||||
:param query: query
|
|
||||||
:param message_id: message id
|
|
||||||
:return:
|
|
||||||
"""
|
|
||||||
try:
|
try:
|
||||||
# process sensitive_word_avoidance
|
# process sensitive_word_avoidance
|
||||||
_, inputs, query = self.moderation_for_inputs(
|
_, inputs, query = self.moderation_for_inputs(
|
||||||
|
@ -229,14 +206,6 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
|
||||||
def handle_annotation_reply(
|
def handle_annotation_reply(
|
||||||
self, app_record: App, message: Message, query: str, app_generate_entity: AdvancedChatAppGenerateEntity
|
self, app_record: App, message: Message, query: str, app_generate_entity: AdvancedChatAppGenerateEntity
|
||||||
) -> bool:
|
) -> bool:
|
||||||
"""
|
|
||||||
Handle annotation reply
|
|
||||||
:param app_record: app record
|
|
||||||
:param message: message
|
|
||||||
:param query: query
|
|
||||||
:param app_generate_entity: application generate entity
|
|
||||||
"""
|
|
||||||
# annotation reply
|
|
||||||
annotation_reply = self.query_app_annotations_to_reply(
|
annotation_reply = self.query_app_annotations_to_reply(
|
||||||
app_record=app_record,
|
app_record=app_record,
|
||||||
message=message,
|
message=message,
|
||||||
|
@ -258,8 +227,6 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
|
||||||
def _complete_with_stream_output(self, text: str, stopped_by: QueueStopEvent.StopBy) -> None:
|
def _complete_with_stream_output(self, text: str, stopped_by: QueueStopEvent.StopBy) -> None:
|
||||||
"""
|
"""
|
||||||
Direct output
|
Direct output
|
||||||
:param text: text
|
|
||||||
:return:
|
|
||||||
"""
|
"""
|
||||||
self._publish_event(QueueTextChunkEvent(text=text))
|
self._publish_event(QueueTextChunkEvent(text=text))
|
||||||
|
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
from collections.abc import Generator
|
from collections.abc import Generator, Mapping
|
||||||
from typing import Any, Optional, Union
|
from typing import Any, Optional, Union
|
||||||
|
|
||||||
from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME
|
from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME
|
||||||
|
@ -9,6 +9,7 @@ from core.app.apps.advanced_chat.app_generator_tts_publisher import AppGenerator
|
||||||
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
||||||
from core.app.entities.app_invoke_entities import (
|
from core.app.entities.app_invoke_entities import (
|
||||||
AdvancedChatAppGenerateEntity,
|
AdvancedChatAppGenerateEntity,
|
||||||
|
InvokeFrom,
|
||||||
)
|
)
|
||||||
from core.app.entities.queue_entities import (
|
from core.app.entities.queue_entities import (
|
||||||
QueueAdvancedChatMessageEndEvent,
|
QueueAdvancedChatMessageEndEvent,
|
||||||
|
@ -50,10 +51,11 @@ from core.model_runtime.utils.encoders import jsonable_encoder
|
||||||
from core.ops.ops_trace_manager import TraceQueueManager
|
from core.ops.ops_trace_manager import TraceQueueManager
|
||||||
from core.workflow.enums import SystemVariableKey
|
from core.workflow.enums import SystemVariableKey
|
||||||
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
||||||
|
from enums.workflow_nodes import NodeType
|
||||||
from events.message_event import message_was_created
|
from events.message_event import message_was_created
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
|
from models import Conversation, EndUser, Message, MessageFile
|
||||||
from models.account import Account
|
from models.account import Account
|
||||||
from models.model import Conversation, EndUser, Message
|
|
||||||
from models.workflow import (
|
from models.workflow import (
|
||||||
Workflow,
|
Workflow,
|
||||||
WorkflowNodeExecution,
|
WorkflowNodeExecution,
|
||||||
|
@ -120,6 +122,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
||||||
self._wip_workflow_node_executions = {}
|
self._wip_workflow_node_executions = {}
|
||||||
|
|
||||||
self._conversation_name_generate_thread = None
|
self._conversation_name_generate_thread = None
|
||||||
|
self._recorded_files: list[Mapping[str, Any]] = []
|
||||||
|
|
||||||
def process(self):
|
def process(self):
|
||||||
"""
|
"""
|
||||||
|
@ -298,6 +301,10 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
||||||
elif isinstance(event, QueueNodeSucceededEvent):
|
elif isinstance(event, QueueNodeSucceededEvent):
|
||||||
workflow_node_execution = self._handle_workflow_node_execution_success(event)
|
workflow_node_execution = self._handle_workflow_node_execution_success(event)
|
||||||
|
|
||||||
|
# Record files if it's an answer node or end node
|
||||||
|
if event.node_type in [NodeType.ANSWER, NodeType.END]:
|
||||||
|
self._recorded_files.extend(self._fetch_files_from_node_outputs(event.outputs or {}))
|
||||||
|
|
||||||
response = self._workflow_node_finish_to_stream_response(
|
response = self._workflow_node_finish_to_stream_response(
|
||||||
event=event,
|
event=event,
|
||||||
task_id=self._application_generate_entity.task_id,
|
task_id=self._application_generate_entity.task_id,
|
||||||
|
@ -364,7 +371,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
||||||
start_at=graph_runtime_state.start_at,
|
start_at=graph_runtime_state.start_at,
|
||||||
total_tokens=graph_runtime_state.total_tokens,
|
total_tokens=graph_runtime_state.total_tokens,
|
||||||
total_steps=graph_runtime_state.node_run_steps,
|
total_steps=graph_runtime_state.node_run_steps,
|
||||||
outputs=json.dumps(event.outputs) if event.outputs else None,
|
outputs=event.outputs,
|
||||||
conversation_id=self._conversation.id,
|
conversation_id=self._conversation.id,
|
||||||
trace_manager=trace_manager,
|
trace_manager=trace_manager,
|
||||||
)
|
)
|
||||||
|
@ -490,10 +497,6 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
||||||
self._conversation_name_generate_thread.join()
|
self._conversation_name_generate_thread.join()
|
||||||
|
|
||||||
def _save_message(self, graph_runtime_state: Optional[GraphRuntimeState] = None) -> None:
|
def _save_message(self, graph_runtime_state: Optional[GraphRuntimeState] = None) -> None:
|
||||||
"""
|
|
||||||
Save message.
|
|
||||||
:return:
|
|
||||||
"""
|
|
||||||
self._refetch_message()
|
self._refetch_message()
|
||||||
|
|
||||||
self._message.answer = self._task_state.answer
|
self._message.answer = self._task_state.answer
|
||||||
|
@ -501,6 +504,22 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
||||||
self._message.message_metadata = (
|
self._message.message_metadata = (
|
||||||
json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None
|
json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None
|
||||||
)
|
)
|
||||||
|
message_files = [
|
||||||
|
MessageFile(
|
||||||
|
message_id=self._message.id,
|
||||||
|
type=file["type"],
|
||||||
|
transfer_method=file["transfer_method"],
|
||||||
|
url=file["remote_url"],
|
||||||
|
belongs_to="assistant",
|
||||||
|
upload_file_id=file["related_id"],
|
||||||
|
created_by_role="account"
|
||||||
|
if self._message.invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER}
|
||||||
|
else "end_user",
|
||||||
|
created_by=self._message.from_account_id or self._message.from_end_user_id or "",
|
||||||
|
)
|
||||||
|
for file in self._recorded_files
|
||||||
|
]
|
||||||
|
db.session.add_all(message_files)
|
||||||
|
|
||||||
if graph_runtime_state and graph_runtime_state.llm_usage:
|
if graph_runtime_state and graph_runtime_state.llm_usage:
|
||||||
usage = graph_runtime_state.llm_usage
|
usage = graph_runtime_state.llm_usage
|
||||||
|
@ -540,7 +559,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
||||||
del extras["metadata"]["annotation_reply"]
|
del extras["metadata"]["annotation_reply"]
|
||||||
|
|
||||||
return MessageEndStreamResponse(
|
return MessageEndStreamResponse(
|
||||||
task_id=self._application_generate_entity.task_id, id=self._message.id, **extras
|
task_id=self._application_generate_entity.task_id, id=self._message.id, files=self._recorded_files, **extras
|
||||||
)
|
)
|
||||||
|
|
||||||
def _handle_output_moderation_chunk(self, text: str) -> bool:
|
def _handle_output_moderation_chunk(self, text: str) -> bool:
|
||||||
|
|
|
@ -17,12 +17,12 @@ from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskSt
|
||||||
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
|
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
|
||||||
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
|
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
|
||||||
from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, InvokeFrom
|
from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, InvokeFrom
|
||||||
from core.file.message_file_parser import MessageFileParser
|
|
||||||
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
|
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
|
||||||
from core.ops.ops_trace_manager import TraceQueueManager
|
from core.ops.ops_trace_manager import TraceQueueManager
|
||||||
|
from enums import CreatedByRole
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from models.account import Account
|
from factories import file_factory
|
||||||
from models.model import App, EndUser
|
from models import Account, App, EndUser
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@ -49,7 +49,12 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
|
||||||
) -> dict: ...
|
) -> dict: ...
|
||||||
|
|
||||||
def generate(
|
def generate(
|
||||||
self, app_model: App, user: Union[Account, EndUser], args: Any, invoke_from: InvokeFrom, stream: bool = True
|
self,
|
||||||
|
app_model: App,
|
||||||
|
user: Union[Account, EndUser],
|
||||||
|
args: Any,
|
||||||
|
invoke_from: InvokeFrom,
|
||||||
|
stream: bool = True,
|
||||||
) -> Union[dict, Generator[dict, None, None]]:
|
) -> Union[dict, Generator[dict, None, None]]:
|
||||||
"""
|
"""
|
||||||
Generate App response.
|
Generate App response.
|
||||||
|
@ -97,12 +102,19 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
|
||||||
# always enable retriever resource in debugger mode
|
# always enable retriever resource in debugger mode
|
||||||
override_model_config_dict["retriever_resource"] = {"enabled": True}
|
override_model_config_dict["retriever_resource"] = {"enabled": True}
|
||||||
|
|
||||||
|
role = CreatedByRole.ACCOUNT if isinstance(user, Account) else CreatedByRole.END_USER
|
||||||
|
|
||||||
# parse files
|
# parse files
|
||||||
files = args["files"] if args.get("files") else []
|
files = args.get("files") or []
|
||||||
message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id)
|
|
||||||
file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict())
|
file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict())
|
||||||
if file_extra_config:
|
if file_extra_config:
|
||||||
file_objs = message_file_parser.validate_and_transform_files_arg(files, file_extra_config, user)
|
file_objs = file_factory.build_from_mappings(
|
||||||
|
mappings=files,
|
||||||
|
tenant_id=app_model.tenant_id,
|
||||||
|
user_id=user.id,
|
||||||
|
role=role,
|
||||||
|
config=file_extra_config,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
file_objs = []
|
file_objs = []
|
||||||
|
|
||||||
|
@ -115,8 +127,7 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
|
||||||
)
|
)
|
||||||
|
|
||||||
# get tracing instance
|
# get tracing instance
|
||||||
user_id = user.id if isinstance(user, Account) else user.session_id
|
trace_manager = TraceQueueManager(app_model.id, user.id if isinstance(user, Account) else user.session_id)
|
||||||
trace_manager = TraceQueueManager(app_model.id, user_id)
|
|
||||||
|
|
||||||
# init application generate entity
|
# init application generate entity
|
||||||
application_generate_entity = AgentChatAppGenerateEntity(
|
application_generate_entity = AgentChatAppGenerateEntity(
|
||||||
|
@ -124,7 +135,9 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
|
||||||
app_config=app_config,
|
app_config=app_config,
|
||||||
model_conf=ModelConfigConverter.convert(app_config),
|
model_conf=ModelConfigConverter.convert(app_config),
|
||||||
conversation_id=conversation.id if conversation else None,
|
conversation_id=conversation.id if conversation else None,
|
||||||
inputs=conversation.inputs if conversation else self._get_cleaned_inputs(inputs, app_config),
|
inputs=conversation.inputs
|
||||||
|
if conversation
|
||||||
|
else self._prepare_user_inputs(user_inputs=inputs, app_config=app_config, user_id=user.id, role=role),
|
||||||
query=query,
|
query=query,
|
||||||
files=file_objs,
|
files=file_objs,
|
||||||
parent_message_id=args.get("parent_message_id"),
|
parent_message_id=args.get("parent_message_id"),
|
||||||
|
|
|
@ -1,35 +1,92 @@
|
||||||
from collections.abc import Mapping
|
from collections.abc import Mapping
|
||||||
from typing import Any, Optional
|
from typing import TYPE_CHECKING, Any, Optional
|
||||||
|
|
||||||
from core.app.app_config.entities import AppConfig, VariableEntity, VariableEntityType
|
from core.app.app_config.entities import VariableEntityType
|
||||||
|
from core.file import File, FileExtraConfig
|
||||||
|
from factories import file_factory
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from core.app.app_config.entities import AppConfig, VariableEntity
|
||||||
|
from enums import CreatedByRole
|
||||||
|
|
||||||
|
|
||||||
class BaseAppGenerator:
|
class BaseAppGenerator:
|
||||||
def _get_cleaned_inputs(self, user_inputs: Optional[Mapping[str, Any]], app_config: AppConfig) -> Mapping[str, Any]:
|
def _prepare_user_inputs(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
user_inputs: Optional[Mapping[str, Any]],
|
||||||
|
app_config: "AppConfig",
|
||||||
|
user_id: str,
|
||||||
|
role: "CreatedByRole",
|
||||||
|
) -> Mapping[str, Any]:
|
||||||
user_inputs = user_inputs or {}
|
user_inputs = user_inputs or {}
|
||||||
# Filter input variables from form configuration, handle required fields, default values, and option values
|
# Filter input variables from form configuration, handle required fields, default values, and option values
|
||||||
variables = app_config.variables
|
variables = app_config.variables
|
||||||
filtered_inputs = {var.variable: self._validate_input(inputs=user_inputs, var=var) for var in variables}
|
user_inputs = {var.variable: self._validate_input(inputs=user_inputs, var=var) for var in variables}
|
||||||
filtered_inputs = {k: self._sanitize_value(v) for k, v in filtered_inputs.items()}
|
user_inputs = {k: self._sanitize_value(v) for k, v in user_inputs.items()}
|
||||||
return filtered_inputs
|
# Convert files in inputs to File
|
||||||
|
entity_dictionary = {item.variable: item for item in app_config.variables}
|
||||||
|
# Convert single file to File
|
||||||
|
files_inputs = {
|
||||||
|
k: file_factory.build_from_mapping(
|
||||||
|
mapping=v,
|
||||||
|
tenant_id=app_config.tenant_id,
|
||||||
|
user_id=user_id,
|
||||||
|
role=role,
|
||||||
|
config=FileExtraConfig(
|
||||||
|
allowed_file_types=entity_dictionary[k].allowed_file_types,
|
||||||
|
allowed_extensions=entity_dictionary[k].allowed_file_extensions,
|
||||||
|
allowed_upload_methods=entity_dictionary[k].allowed_file_upload_methods,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
for k, v in user_inputs.items()
|
||||||
|
if isinstance(v, dict) and entity_dictionary[k].type == VariableEntityType.FILE
|
||||||
|
}
|
||||||
|
# Convert list of files to File
|
||||||
|
file_list_inputs = {
|
||||||
|
k: file_factory.build_from_mappings(
|
||||||
|
mappings=v,
|
||||||
|
tenant_id=app_config.tenant_id,
|
||||||
|
user_id=user_id,
|
||||||
|
role=role,
|
||||||
|
config=FileExtraConfig(
|
||||||
|
allowed_file_types=entity_dictionary[k].allowed_file_types,
|
||||||
|
allowed_extensions=entity_dictionary[k].allowed_file_extensions,
|
||||||
|
allowed_upload_methods=entity_dictionary[k].allowed_file_upload_methods,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
for k, v in user_inputs.items()
|
||||||
|
if isinstance(v, list)
|
||||||
|
# Ensure skip List<File>
|
||||||
|
and all(isinstance(item, dict) for item in v)
|
||||||
|
and entity_dictionary[k].type == VariableEntityType.FILE_LIST
|
||||||
|
}
|
||||||
|
# Merge all inputs
|
||||||
|
user_inputs = {**user_inputs, **files_inputs, **file_list_inputs}
|
||||||
|
|
||||||
def _validate_input(self, *, inputs: Mapping[str, Any], var: VariableEntity):
|
# Check if all files are converted to File
|
||||||
user_input_value = inputs.get(var.variable)
|
if any(filter(lambda v: isinstance(v, dict), user_inputs.values())):
|
||||||
if var.required and not user_input_value:
|
raise ValueError("Invalid input type")
|
||||||
raise ValueError(f"{var.variable} is required in input form")
|
if any(
|
||||||
if not var.required and not user_input_value:
|
filter(lambda v: isinstance(v, dict), filter(lambda item: isinstance(item, list), user_inputs.values()))
|
||||||
# TODO: should we return None here if the default value is None?
|
|
||||||
return var.default or ""
|
|
||||||
if (
|
|
||||||
var.type
|
|
||||||
in {
|
|
||||||
VariableEntityType.TEXT_INPUT,
|
|
||||||
VariableEntityType.SELECT,
|
|
||||||
VariableEntityType.PARAGRAPH,
|
|
||||||
}
|
|
||||||
and user_input_value
|
|
||||||
and not isinstance(user_input_value, str)
|
|
||||||
):
|
):
|
||||||
|
raise ValueError("Invalid input type")
|
||||||
|
|
||||||
|
return user_inputs
|
||||||
|
|
||||||
|
def _validate_input(self, *, inputs: Mapping[str, Any], var: "VariableEntity"):
|
||||||
|
user_input_value = inputs.get(var.variable)
|
||||||
|
if not user_input_value:
|
||||||
|
if var.required:
|
||||||
|
raise ValueError(f"{var.variable} is required in input form")
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
|
||||||
|
if var.type in {
|
||||||
|
VariableEntityType.TEXT_INPUT,
|
||||||
|
VariableEntityType.SELECT,
|
||||||
|
VariableEntityType.PARAGRAPH,
|
||||||
|
} and not isinstance(user_input_value, str):
|
||||||
raise ValueError(f"(type '{var.type}') {var.variable} in input form must be a string")
|
raise ValueError(f"(type '{var.type}') {var.variable} in input form must be a string")
|
||||||
if var.type == VariableEntityType.NUMBER and isinstance(user_input_value, str):
|
if var.type == VariableEntityType.NUMBER and isinstance(user_input_value, str):
|
||||||
# may raise ValueError if user_input_value is not a valid number
|
# may raise ValueError if user_input_value is not a valid number
|
||||||
|
@ -41,12 +98,24 @@ class BaseAppGenerator:
|
||||||
except ValueError:
|
except ValueError:
|
||||||
raise ValueError(f"{var.variable} in input form must be a valid number")
|
raise ValueError(f"{var.variable} in input form must be a valid number")
|
||||||
if var.type == VariableEntityType.SELECT:
|
if var.type == VariableEntityType.SELECT:
|
||||||
options = var.options or []
|
options = var.options
|
||||||
if user_input_value not in options:
|
if user_input_value not in options:
|
||||||
raise ValueError(f"{var.variable} in input form must be one of the following: {options}")
|
raise ValueError(f"{var.variable} in input form must be one of the following: {options}")
|
||||||
elif var.type in {VariableEntityType.TEXT_INPUT, VariableEntityType.PARAGRAPH}:
|
elif var.type in {VariableEntityType.TEXT_INPUT, VariableEntityType.PARAGRAPH}:
|
||||||
if var.max_length and user_input_value and len(user_input_value) > var.max_length:
|
if var.max_length and len(user_input_value) > var.max_length:
|
||||||
raise ValueError(f"{var.variable} in input form must be less than {var.max_length} characters")
|
raise ValueError(f"{var.variable} in input form must be less than {var.max_length} characters")
|
||||||
|
elif var.type == VariableEntityType.FILE:
|
||||||
|
if not isinstance(user_input_value, dict) and not isinstance(user_input_value, File):
|
||||||
|
raise ValueError(f"{var.variable} in input form must be a file")
|
||||||
|
elif var.type == VariableEntityType.FILE_LIST:
|
||||||
|
if not (
|
||||||
|
isinstance(user_input_value, list)
|
||||||
|
and (
|
||||||
|
all(isinstance(item, dict) for item in user_input_value)
|
||||||
|
or all(isinstance(item, File) for item in user_input_value)
|
||||||
|
)
|
||||||
|
):
|
||||||
|
raise ValueError(f"{var.variable} in input form must be a list of files")
|
||||||
|
|
||||||
return user_input_value
|
return user_input_value
|
||||||
|
|
||||||
|
|
|
@ -27,7 +27,7 @@ from core.prompt.simple_prompt_transform import ModelMode, SimplePromptTransform
|
||||||
from models.model import App, AppMode, Message, MessageAnnotation
|
from models.model import App, AppMode, Message, MessageAnnotation
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from core.file.file_obj import FileVar
|
from core.file.models import File
|
||||||
|
|
||||||
|
|
||||||
class AppRunner:
|
class AppRunner:
|
||||||
|
@ -37,7 +37,7 @@ class AppRunner:
|
||||||
model_config: ModelConfigWithCredentialsEntity,
|
model_config: ModelConfigWithCredentialsEntity,
|
||||||
prompt_template_entity: PromptTemplateEntity,
|
prompt_template_entity: PromptTemplateEntity,
|
||||||
inputs: dict[str, str],
|
inputs: dict[str, str],
|
||||||
files: list["FileVar"],
|
files: list["File"],
|
||||||
query: Optional[str] = None,
|
query: Optional[str] = None,
|
||||||
) -> int:
|
) -> int:
|
||||||
"""
|
"""
|
||||||
|
@ -137,7 +137,7 @@ class AppRunner:
|
||||||
model_config: ModelConfigWithCredentialsEntity,
|
model_config: ModelConfigWithCredentialsEntity,
|
||||||
prompt_template_entity: PromptTemplateEntity,
|
prompt_template_entity: PromptTemplateEntity,
|
||||||
inputs: dict[str, str],
|
inputs: dict[str, str],
|
||||||
files: list["FileVar"],
|
files: list["File"],
|
||||||
query: Optional[str] = None,
|
query: Optional[str] = None,
|
||||||
context: Optional[str] = None,
|
context: Optional[str] = None,
|
||||||
memory: Optional[TokenBufferMemory] = None,
|
memory: Optional[TokenBufferMemory] = None,
|
||||||
|
|
|
@ -17,10 +17,11 @@ from core.app.apps.chat.generate_response_converter import ChatAppGenerateRespon
|
||||||
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
|
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
|
||||||
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
|
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
|
||||||
from core.app.entities.app_invoke_entities import ChatAppGenerateEntity, InvokeFrom
|
from core.app.entities.app_invoke_entities import ChatAppGenerateEntity, InvokeFrom
|
||||||
from core.file.message_file_parser import MessageFileParser
|
|
||||||
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
|
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
|
||||||
from core.ops.ops_trace_manager import TraceQueueManager
|
from core.ops.ops_trace_manager import TraceQueueManager
|
||||||
|
from enums import CreatedByRole
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
|
from factories import file_factory
|
||||||
from models.account import Account
|
from models.account import Account
|
||||||
from models.model import App, EndUser
|
from models.model import App, EndUser
|
||||||
|
|
||||||
|
@ -99,12 +100,19 @@ class ChatAppGenerator(MessageBasedAppGenerator):
|
||||||
# always enable retriever resource in debugger mode
|
# always enable retriever resource in debugger mode
|
||||||
override_model_config_dict["retriever_resource"] = {"enabled": True}
|
override_model_config_dict["retriever_resource"] = {"enabled": True}
|
||||||
|
|
||||||
|
role = CreatedByRole.ACCOUNT if isinstance(user, Account) else CreatedByRole.END_USER
|
||||||
|
|
||||||
# parse files
|
# parse files
|
||||||
files = args["files"] if args.get("files") else []
|
files = args["files"] if args.get("files") else []
|
||||||
message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id)
|
|
||||||
file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict())
|
file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict())
|
||||||
if file_extra_config:
|
if file_extra_config:
|
||||||
file_objs = message_file_parser.validate_and_transform_files_arg(files, file_extra_config, user)
|
file_objs = file_factory.build_from_mappings(
|
||||||
|
mappings=files,
|
||||||
|
tenant_id=app_model.tenant_id,
|
||||||
|
user_id=user.id,
|
||||||
|
role=role,
|
||||||
|
config=file_extra_config,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
file_objs = []
|
file_objs = []
|
||||||
|
|
||||||
|
@ -117,7 +125,7 @@ class ChatAppGenerator(MessageBasedAppGenerator):
|
||||||
)
|
)
|
||||||
|
|
||||||
# get tracing instance
|
# get tracing instance
|
||||||
trace_manager = TraceQueueManager(app_model.id)
|
trace_manager = TraceQueueManager(app_id=app_model.id)
|
||||||
|
|
||||||
# init application generate entity
|
# init application generate entity
|
||||||
application_generate_entity = ChatAppGenerateEntity(
|
application_generate_entity = ChatAppGenerateEntity(
|
||||||
|
@ -125,15 +133,17 @@ class ChatAppGenerator(MessageBasedAppGenerator):
|
||||||
app_config=app_config,
|
app_config=app_config,
|
||||||
model_conf=ModelConfigConverter.convert(app_config),
|
model_conf=ModelConfigConverter.convert(app_config),
|
||||||
conversation_id=conversation.id if conversation else None,
|
conversation_id=conversation.id if conversation else None,
|
||||||
inputs=conversation.inputs if conversation else self._get_cleaned_inputs(inputs, app_config),
|
inputs=conversation.inputs
|
||||||
|
if conversation
|
||||||
|
else self._prepare_user_inputs(user_inputs=inputs, app_config=app_config, user_id=user.id, role=role),
|
||||||
query=query,
|
query=query,
|
||||||
files=file_objs,
|
files=file_objs,
|
||||||
parent_message_id=args.get("parent_message_id"),
|
parent_message_id=args.get("parent_message_id"),
|
||||||
user_id=user.id,
|
user_id=user.id,
|
||||||
stream=stream,
|
|
||||||
invoke_from=invoke_from,
|
invoke_from=invoke_from,
|
||||||
extras=extras,
|
extras=extras,
|
||||||
trace_manager=trace_manager,
|
trace_manager=trace_manager,
|
||||||
|
stream=stream,
|
||||||
)
|
)
|
||||||
|
|
||||||
# init generate records
|
# init generate records
|
||||||
|
|
|
@ -17,12 +17,12 @@ from core.app.apps.completion.generate_response_converter import CompletionAppGe
|
||||||
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
|
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
|
||||||
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
|
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
|
||||||
from core.app.entities.app_invoke_entities import CompletionAppGenerateEntity, InvokeFrom
|
from core.app.entities.app_invoke_entities import CompletionAppGenerateEntity, InvokeFrom
|
||||||
from core.file.message_file_parser import MessageFileParser
|
|
||||||
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
|
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
|
||||||
from core.ops.ops_trace_manager import TraceQueueManager
|
from core.ops.ops_trace_manager import TraceQueueManager
|
||||||
|
from enums import CreatedByRole
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from models.account import Account
|
from factories import file_factory
|
||||||
from models.model import App, EndUser, Message
|
from models import Account, App, EndUser, Message
|
||||||
from services.errors.app import MoreLikeThisDisabledError
|
from services.errors.app import MoreLikeThisDisabledError
|
||||||
from services.errors.message import MessageNotExistsError
|
from services.errors.message import MessageNotExistsError
|
||||||
|
|
||||||
|
@ -88,12 +88,19 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
|
||||||
tenant_id=app_model.tenant_id, config=args.get("model_config")
|
tenant_id=app_model.tenant_id, config=args.get("model_config")
|
||||||
)
|
)
|
||||||
|
|
||||||
|
role = CreatedByRole.ACCOUNT if isinstance(user, Account) else CreatedByRole.END_USER
|
||||||
|
|
||||||
# parse files
|
# parse files
|
||||||
files = args["files"] if args.get("files") else []
|
files = args["files"] if args.get("files") else []
|
||||||
message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id)
|
|
||||||
file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict())
|
file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict())
|
||||||
if file_extra_config:
|
if file_extra_config:
|
||||||
file_objs = message_file_parser.validate_and_transform_files_arg(files, file_extra_config, user)
|
file_objs = file_factory.build_from_mappings(
|
||||||
|
mappings=files,
|
||||||
|
tenant_id=app_model.tenant_id,
|
||||||
|
user_id=user.id,
|
||||||
|
role=role,
|
||||||
|
config=file_extra_config,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
file_objs = []
|
file_objs = []
|
||||||
|
|
||||||
|
@ -103,6 +110,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
|
||||||
)
|
)
|
||||||
|
|
||||||
# get tracing instance
|
# get tracing instance
|
||||||
|
user_id = user.id if isinstance(user, Account) else user.session_id
|
||||||
trace_manager = TraceQueueManager(app_model.id)
|
trace_manager = TraceQueueManager(app_model.id)
|
||||||
|
|
||||||
# init application generate entity
|
# init application generate entity
|
||||||
|
@ -110,7 +118,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
|
||||||
task_id=str(uuid.uuid4()),
|
task_id=str(uuid.uuid4()),
|
||||||
app_config=app_config,
|
app_config=app_config,
|
||||||
model_conf=ModelConfigConverter.convert(app_config),
|
model_conf=ModelConfigConverter.convert(app_config),
|
||||||
inputs=self._get_cleaned_inputs(inputs, app_config),
|
inputs=self._prepare_user_inputs(user_inputs=inputs, app_config=app_config, user_id=user.id, role=role),
|
||||||
query=query,
|
query=query,
|
||||||
files=file_objs,
|
files=file_objs,
|
||||||
user_id=user.id,
|
user_id=user.id,
|
||||||
|
@ -251,10 +259,16 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
|
||||||
override_model_config_dict["model"] = model_dict
|
override_model_config_dict["model"] = model_dict
|
||||||
|
|
||||||
# parse files
|
# parse files
|
||||||
message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id)
|
role = CreatedByRole.ACCOUNT if isinstance(user, Account) else CreatedByRole.END_USER
|
||||||
file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict())
|
file_extra_config = FileUploadConfigManager.convert(override_model_config_dict)
|
||||||
if file_extra_config:
|
if file_extra_config:
|
||||||
file_objs = message_file_parser.validate_and_transform_files_arg(message.files, file_extra_config, user)
|
file_objs = file_factory.build_from_mappings(
|
||||||
|
mappings=message.message_files,
|
||||||
|
tenant_id=app_model.tenant_id,
|
||||||
|
user_id=user.id,
|
||||||
|
role=role,
|
||||||
|
config=file_extra_config,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
file_objs = []
|
file_objs = []
|
||||||
|
|
||||||
|
|
|
@ -26,7 +26,7 @@ from core.app.entities.task_entities import (
|
||||||
from core.app.task_pipeline.easy_ui_based_generate_task_pipeline import EasyUIBasedGenerateTaskPipeline
|
from core.app.task_pipeline.easy_ui_based_generate_task_pipeline import EasyUIBasedGenerateTaskPipeline
|
||||||
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
|
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from models.account import Account
|
from models import Account
|
||||||
from models.model import App, AppMode, AppModelConfig, Conversation, EndUser, Message, MessageFile
|
from models.model import App, AppMode, AppModelConfig, Conversation, EndUser, Message, MessageFile
|
||||||
from services.errors.app_model_config import AppModelConfigBrokenError
|
from services.errors.app_model_config import AppModelConfigBrokenError
|
||||||
from services.errors.conversation import ConversationCompletedError, ConversationNotExistsError
|
from services.errors.conversation import ConversationCompletedError, ConversationNotExistsError
|
||||||
|
@ -235,13 +235,13 @@ class MessageBasedAppGenerator(BaseAppGenerator):
|
||||||
for file in application_generate_entity.files:
|
for file in application_generate_entity.files:
|
||||||
message_file = MessageFile(
|
message_file = MessageFile(
|
||||||
message_id=message.id,
|
message_id=message.id,
|
||||||
type=file.type.value,
|
type=file.type,
|
||||||
transfer_method=file.transfer_method.value,
|
transfer_method=file.transfer_method,
|
||||||
belongs_to="user",
|
belongs_to="user",
|
||||||
url=file.url,
|
url=file.remote_url,
|
||||||
upload_file_id=file.related_id,
|
upload_file_id=file.related_id,
|
||||||
created_by_role=("account" if account_id else "end_user"),
|
created_by_role=("account" if account_id else "end_user"),
|
||||||
created_by=account_id or end_user_id,
|
created_by=account_id or end_user_id or "",
|
||||||
)
|
)
|
||||||
db.session.add(message_file)
|
db.session.add(message_file)
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
|
|
|
@ -3,7 +3,7 @@ import logging
|
||||||
import os
|
import os
|
||||||
import threading
|
import threading
|
||||||
import uuid
|
import uuid
|
||||||
from collections.abc import Generator
|
from collections.abc import Generator, Mapping, Sequence
|
||||||
from typing import Any, Literal, Optional, Union, overload
|
from typing import Any, Literal, Optional, Union, overload
|
||||||
|
|
||||||
from flask import Flask, current_app
|
from flask import Flask, current_app
|
||||||
|
@ -20,13 +20,12 @@ from core.app.apps.workflow.generate_response_converter import WorkflowAppGenera
|
||||||
from core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTaskPipeline
|
from core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTaskPipeline
|
||||||
from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity
|
from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity
|
||||||
from core.app.entities.task_entities import WorkflowAppBlockingResponse, WorkflowAppStreamResponse
|
from core.app.entities.task_entities import WorkflowAppBlockingResponse, WorkflowAppStreamResponse
|
||||||
from core.file.message_file_parser import MessageFileParser
|
|
||||||
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
|
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
|
||||||
from core.ops.ops_trace_manager import TraceQueueManager
|
from core.ops.ops_trace_manager import TraceQueueManager
|
||||||
|
from enums import CreatedByRole
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from models.account import Account
|
from factories import file_factory
|
||||||
from models.model import App, EndUser
|
from models import Account, App, EndUser, Workflow
|
||||||
from models.workflow import Workflow
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@ -63,49 +62,46 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||||
app_model: App,
|
app_model: App,
|
||||||
workflow: Workflow,
|
workflow: Workflow,
|
||||||
user: Union[Account, EndUser],
|
user: Union[Account, EndUser],
|
||||||
args: dict,
|
args: Mapping[str, Any],
|
||||||
invoke_from: InvokeFrom,
|
invoke_from: InvokeFrom,
|
||||||
stream: bool = True,
|
stream: bool = True,
|
||||||
call_depth: int = 0,
|
call_depth: int = 0,
|
||||||
workflow_thread_pool_id: Optional[str] = None,
|
workflow_thread_pool_id: Optional[str] = None,
|
||||||
):
|
):
|
||||||
"""
|
files: Sequence[Mapping[str, Any]] = args.get("files") or []
|
||||||
Generate App response.
|
|
||||||
|
|
||||||
:param app_model: App
|
role = CreatedByRole.ACCOUNT if isinstance(user, Account) else CreatedByRole.END_USER
|
||||||
:param workflow: Workflow
|
|
||||||
:param user: account or end user
|
|
||||||
:param args: request args
|
|
||||||
:param invoke_from: invoke from source
|
|
||||||
:param stream: is stream
|
|
||||||
:param call_depth: call depth
|
|
||||||
:param workflow_thread_pool_id: workflow thread pool id
|
|
||||||
"""
|
|
||||||
inputs = args["inputs"]
|
|
||||||
|
|
||||||
# parse files
|
# parse files
|
||||||
files = args["files"] if args.get("files") else []
|
|
||||||
message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id)
|
|
||||||
file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False)
|
file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False)
|
||||||
if file_extra_config:
|
system_files = file_factory.build_from_mappings(
|
||||||
file_objs = message_file_parser.validate_and_transform_files_arg(files, file_extra_config, user)
|
mappings=files,
|
||||||
else:
|
tenant_id=app_model.tenant_id,
|
||||||
file_objs = []
|
user_id=user.id,
|
||||||
|
role=role,
|
||||||
|
config=file_extra_config,
|
||||||
|
)
|
||||||
|
|
||||||
# convert to app config
|
# convert to app config
|
||||||
app_config = WorkflowAppConfigManager.get_app_config(app_model=app_model, workflow=workflow)
|
app_config = WorkflowAppConfigManager.get_app_config(
|
||||||
|
app_model=app_model,
|
||||||
|
workflow=workflow,
|
||||||
|
)
|
||||||
|
|
||||||
# get tracing instance
|
# get tracing instance
|
||||||
user_id = user.id if isinstance(user, Account) else user.session_id
|
trace_manager = TraceQueueManager(
|
||||||
trace_manager = TraceQueueManager(app_model.id, user_id)
|
app_id=app_model.id,
|
||||||
|
user_id=user.id if isinstance(user, Account) else user.session_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
inputs: Mapping[str, Any] = args["inputs"]
|
||||||
workflow_run_id = str(uuid.uuid4())
|
workflow_run_id = str(uuid.uuid4())
|
||||||
# init application generate entity
|
# init application generate entity
|
||||||
application_generate_entity = WorkflowAppGenerateEntity(
|
application_generate_entity = WorkflowAppGenerateEntity(
|
||||||
task_id=str(uuid.uuid4()),
|
task_id=str(uuid.uuid4()),
|
||||||
app_config=app_config,
|
app_config=app_config,
|
||||||
inputs=self._get_cleaned_inputs(inputs, app_config),
|
inputs=self._prepare_user_inputs(user_inputs=inputs, app_config=app_config, user_id=user.id, role=role),
|
||||||
files=file_objs,
|
files=system_files,
|
||||||
user_id=user.id,
|
user_id=user.id,
|
||||||
stream=stream,
|
stream=stream,
|
||||||
invoke_from=invoke_from,
|
invoke_from=invoke_from,
|
||||||
|
|
|
@ -1,20 +1,19 @@
|
||||||
import logging
|
import logging
|
||||||
import os
|
|
||||||
from typing import Optional, cast
|
from typing import Optional, cast
|
||||||
|
|
||||||
|
from configs import dify_config
|
||||||
from core.app.apps.base_app_queue_manager import AppQueueManager
|
from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||||
from core.app.apps.workflow.app_config_manager import WorkflowAppConfig
|
from core.app.apps.workflow.app_config_manager import WorkflowAppConfig
|
||||||
from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner
|
from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner
|
||||||
from core.app.apps.workflow_logging_callback import WorkflowLoggingCallback
|
|
||||||
from core.app.entities.app_invoke_entities import (
|
from core.app.entities.app_invoke_entities import (
|
||||||
InvokeFrom,
|
InvokeFrom,
|
||||||
WorkflowAppGenerateEntity,
|
WorkflowAppGenerateEntity,
|
||||||
)
|
)
|
||||||
from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
|
from core.workflow.callbacks import WorkflowCallback, WorkflowLoggingCallback
|
||||||
from core.workflow.entities.node_entities import UserFrom
|
|
||||||
from core.workflow.entities.variable_pool import VariablePool
|
from core.workflow.entities.variable_pool import VariablePool
|
||||||
from core.workflow.enums import SystemVariableKey
|
from core.workflow.enums import SystemVariableKey
|
||||||
from core.workflow.workflow_entry import WorkflowEntry
|
from core.workflow.workflow_entry import WorkflowEntry
|
||||||
|
from enums import UserFrom
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from models.model import App, EndUser
|
from models.model import App, EndUser
|
||||||
from models.workflow import WorkflowType
|
from models.workflow import WorkflowType
|
||||||
|
@ -71,7 +70,7 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
|
||||||
db.session.close()
|
db.session.close()
|
||||||
|
|
||||||
workflow_callbacks: list[WorkflowCallback] = []
|
workflow_callbacks: list[WorkflowCallback] = []
|
||||||
if bool(os.environ.get("DEBUG", "False").lower() == "true"):
|
if dify_config.DEBUG:
|
||||||
workflow_callbacks.append(WorkflowLoggingCallback())
|
workflow_callbacks.append(WorkflowLoggingCallback())
|
||||||
|
|
||||||
# if only single iteration run is requested
|
# if only single iteration run is requested
|
||||||
|
|
|
@ -1,4 +1,3 @@
|
||||||
import json
|
|
||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
from collections.abc import Generator
|
from collections.abc import Generator
|
||||||
|
@ -334,9 +333,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
||||||
start_at=graph_runtime_state.start_at,
|
start_at=graph_runtime_state.start_at,
|
||||||
total_tokens=graph_runtime_state.total_tokens,
|
total_tokens=graph_runtime_state.total_tokens,
|
||||||
total_steps=graph_runtime_state.node_run_steps,
|
total_steps=graph_runtime_state.node_run_steps,
|
||||||
outputs=json.dumps(event.outputs)
|
outputs=event.outputs,
|
||||||
if isinstance(event, QueueWorkflowSucceededEvent) and event.outputs
|
|
||||||
else None,
|
|
||||||
conversation_id=None,
|
conversation_id=None,
|
||||||
trace_manager=trace_manager,
|
trace_manager=trace_manager,
|
||||||
)
|
)
|
||||||
|
|
|
@ -20,7 +20,6 @@ from core.app.entities.queue_entities import (
|
||||||
QueueWorkflowStartedEvent,
|
QueueWorkflowStartedEvent,
|
||||||
QueueWorkflowSucceededEvent,
|
QueueWorkflowSucceededEvent,
|
||||||
)
|
)
|
||||||
from core.workflow.entities.node_entities import NodeType
|
|
||||||
from core.workflow.entities.variable_pool import VariablePool
|
from core.workflow.entities.variable_pool import VariablePool
|
||||||
from core.workflow.graph_engine.entities.event import (
|
from core.workflow.graph_engine.entities.event import (
|
||||||
GraphEngineEvent,
|
GraphEngineEvent,
|
||||||
|
@ -45,6 +44,7 @@ from core.workflow.nodes.base_node import BaseNode
|
||||||
from core.workflow.nodes.iteration.entities import IterationNodeData
|
from core.workflow.nodes.iteration.entities import IterationNodeData
|
||||||
from core.workflow.nodes.node_mapping import node_classes
|
from core.workflow.nodes.node_mapping import node_classes
|
||||||
from core.workflow.workflow_entry import WorkflowEntry
|
from core.workflow.workflow_entry import WorkflowEntry
|
||||||
|
from enums import NodeType
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from models.model import App
|
from models.model import App
|
||||||
from models.workflow import Workflow
|
from models.workflow import Workflow
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
from collections.abc import Mapping
|
from collections.abc import Mapping, Sequence
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Any, Optional
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
@ -6,7 +6,7 @@ from pydantic import BaseModel, ConfigDict
|
||||||
|
|
||||||
from core.app.app_config.entities import AppConfig, EasyUIBasedAppConfig, WorkflowUIBasedAppConfig
|
from core.app.app_config.entities import AppConfig, EasyUIBasedAppConfig, WorkflowUIBasedAppConfig
|
||||||
from core.entities.provider_configuration import ProviderModelBundle
|
from core.entities.provider_configuration import ProviderModelBundle
|
||||||
from core.file.file_obj import FileVar
|
from core.file.models import File
|
||||||
from core.model_runtime.entities.model_entities import AIModelEntity
|
from core.model_runtime.entities.model_entities import AIModelEntity
|
||||||
from core.ops.ops_trace_manager import TraceQueueManager
|
from core.ops.ops_trace_manager import TraceQueueManager
|
||||||
|
|
||||||
|
@ -22,7 +22,7 @@ class InvokeFrom(Enum):
|
||||||
DEBUGGER = "debugger"
|
DEBUGGER = "debugger"
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def value_of(cls, value: str) -> "InvokeFrom":
|
def value_of(cls, value: str):
|
||||||
"""
|
"""
|
||||||
Get value of given mode.
|
Get value of given mode.
|
||||||
|
|
||||||
|
@ -81,7 +81,7 @@ class AppGenerateEntity(BaseModel):
|
||||||
app_config: AppConfig
|
app_config: AppConfig
|
||||||
|
|
||||||
inputs: Mapping[str, Any]
|
inputs: Mapping[str, Any]
|
||||||
files: list[FileVar] = []
|
files: Sequence[File]
|
||||||
user_id: str
|
user_id: str
|
||||||
|
|
||||||
# extras
|
# extras
|
||||||
|
|
|
@ -6,8 +6,9 @@ from pydantic import BaseModel, field_validator
|
||||||
|
|
||||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
|
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
|
||||||
from core.workflow.entities.base_node_data_entities import BaseNodeData
|
from core.workflow.entities.base_node_data_entities import BaseNodeData
|
||||||
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeType
|
from core.workflow.entities.node_entities import NodeRunMetadataKey
|
||||||
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
||||||
|
from enums import NodeType
|
||||||
|
|
||||||
|
|
||||||
class QueueEvent(str, Enum):
|
class QueueEvent(str, Enum):
|
||||||
|
|
|
@ -1,3 +1,4 @@
|
||||||
|
from collections.abc import Mapping, Sequence
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Any, Optional
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
@ -119,6 +120,7 @@ class MessageEndStreamResponse(StreamResponse):
|
||||||
event: StreamEvent = StreamEvent.MESSAGE_END
|
event: StreamEvent = StreamEvent.MESSAGE_END
|
||||||
id: str
|
id: str
|
||||||
metadata: dict = {}
|
metadata: dict = {}
|
||||||
|
files: Optional[Sequence[Mapping[str, Any]]] = None
|
||||||
|
|
||||||
|
|
||||||
class MessageFileStreamResponse(StreamResponse):
|
class MessageFileStreamResponse(StreamResponse):
|
||||||
|
|
|
@ -1,18 +0,0 @@
|
||||||
import re
|
|
||||||
|
|
||||||
from core.workflow.entities.variable_pool import VariablePool
|
|
||||||
|
|
||||||
from . import SegmentGroup, factory
|
|
||||||
|
|
||||||
VARIABLE_PATTERN = re.compile(r"\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10})#\}\}")
|
|
||||||
|
|
||||||
|
|
||||||
def convert_template(*, template: str, variable_pool: VariablePool):
|
|
||||||
parts = re.split(VARIABLE_PATTERN, template)
|
|
||||||
segments = []
|
|
||||||
for part in filter(lambda x: x, parts):
|
|
||||||
if "." in part and (value := variable_pool.get(part.split("."))):
|
|
||||||
segments.append(value)
|
|
||||||
else:
|
|
||||||
segments.append(factory.build_segment(part))
|
|
||||||
return SegmentGroup(value=segments)
|
|
|
@ -1,5 +1,6 @@
|
||||||
import json
|
import json
|
||||||
import time
|
import time
|
||||||
|
from collections.abc import Mapping, Sequence
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
from typing import Any, Optional, Union, cast
|
from typing import Any, Optional, Union, cast
|
||||||
|
|
||||||
|
@ -27,15 +28,15 @@ from core.app.entities.task_entities import (
|
||||||
WorkflowStartStreamResponse,
|
WorkflowStartStreamResponse,
|
||||||
WorkflowTaskState,
|
WorkflowTaskState,
|
||||||
)
|
)
|
||||||
from core.file.file_obj import FileVar
|
from core.file import FILE_MODEL_IDENTITY, File
|
||||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||||
from core.ops.entities.trace_entity import TraceTaskName
|
from core.ops.entities.trace_entity import TraceTaskName
|
||||||
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
|
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
|
||||||
from core.tools.tool_manager import ToolManager
|
from core.tools.tool_manager import ToolManager
|
||||||
from core.workflow.entities.node_entities import NodeType
|
|
||||||
from core.workflow.enums import SystemVariableKey
|
from core.workflow.enums import SystemVariableKey
|
||||||
from core.workflow.nodes.tool.entities import ToolNodeData
|
from core.workflow.nodes.tool.entities import ToolNodeData
|
||||||
from core.workflow.workflow_entry import WorkflowEntry
|
from core.workflow.workflow_entry import WorkflowEntry
|
||||||
|
from enums import NodeType, WorkflowRunTriggeredFrom
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from models.account import Account
|
from models.account import Account
|
||||||
from models.model import EndUser
|
from models.model import EndUser
|
||||||
|
@ -47,7 +48,6 @@ from models.workflow import (
|
||||||
WorkflowNodeExecutionTriggeredFrom,
|
WorkflowNodeExecutionTriggeredFrom,
|
||||||
WorkflowRun,
|
WorkflowRun,
|
||||||
WorkflowRunStatus,
|
WorkflowRunStatus,
|
||||||
WorkflowRunTriggeredFrom,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -117,7 +117,7 @@ class WorkflowCycleManage:
|
||||||
start_at: float,
|
start_at: float,
|
||||||
total_tokens: int,
|
total_tokens: int,
|
||||||
total_steps: int,
|
total_steps: int,
|
||||||
outputs: Optional[str] = None,
|
outputs: Mapping[str, Any] | None = None,
|
||||||
conversation_id: Optional[str] = None,
|
conversation_id: Optional[str] = None,
|
||||||
trace_manager: Optional[TraceQueueManager] = None,
|
trace_manager: Optional[TraceQueueManager] = None,
|
||||||
) -> WorkflowRun:
|
) -> WorkflowRun:
|
||||||
|
@ -133,8 +133,10 @@ class WorkflowCycleManage:
|
||||||
"""
|
"""
|
||||||
workflow_run = self._refetch_workflow_run(workflow_run.id)
|
workflow_run = self._refetch_workflow_run(workflow_run.id)
|
||||||
|
|
||||||
|
outputs = WorkflowEntry.handle_special_values(outputs)
|
||||||
|
|
||||||
workflow_run.status = WorkflowRunStatus.SUCCEEDED.value
|
workflow_run.status = WorkflowRunStatus.SUCCEEDED.value
|
||||||
workflow_run.outputs = outputs
|
workflow_run.outputs = json.dumps(outputs) if outputs else None
|
||||||
workflow_run.elapsed_time = time.perf_counter() - start_at
|
workflow_run.elapsed_time = time.perf_counter() - start_at
|
||||||
workflow_run.total_tokens = total_tokens
|
workflow_run.total_tokens = total_tokens
|
||||||
workflow_run.total_steps = total_steps
|
workflow_run.total_steps = total_steps
|
||||||
|
@ -286,10 +288,11 @@ class WorkflowCycleManage:
|
||||||
|
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
db.session.close()
|
db.session.close()
|
||||||
|
process_data = WorkflowEntry.handle_special_values(event.process_data)
|
||||||
|
|
||||||
workflow_node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED.value
|
workflow_node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED.value
|
||||||
workflow_node_execution.inputs = json.dumps(inputs) if inputs else None
|
workflow_node_execution.inputs = json.dumps(inputs) if inputs else None
|
||||||
workflow_node_execution.process_data = json.dumps(event.process_data) if event.process_data else None
|
workflow_node_execution.process_data = json.dumps(process_data) if process_data else None
|
||||||
workflow_node_execution.outputs = json.dumps(outputs) if outputs else None
|
workflow_node_execution.outputs = json.dumps(outputs) if outputs else None
|
||||||
workflow_node_execution.execution_metadata = execution_metadata
|
workflow_node_execution.execution_metadata = execution_metadata
|
||||||
workflow_node_execution.finished_at = finished_at
|
workflow_node_execution.finished_at = finished_at
|
||||||
|
@ -326,11 +329,12 @@ class WorkflowCycleManage:
|
||||||
|
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
db.session.close()
|
db.session.close()
|
||||||
|
process_data = WorkflowEntry.handle_special_values(event.process_data)
|
||||||
|
|
||||||
workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value
|
workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value
|
||||||
workflow_node_execution.error = event.error
|
workflow_node_execution.error = event.error
|
||||||
workflow_node_execution.inputs = json.dumps(inputs) if inputs else None
|
workflow_node_execution.inputs = json.dumps(inputs) if inputs else None
|
||||||
workflow_node_execution.process_data = json.dumps(event.process_data) if event.process_data else None
|
workflow_node_execution.process_data = json.dumps(process_data) if process_data else None
|
||||||
workflow_node_execution.outputs = json.dumps(outputs) if outputs else None
|
workflow_node_execution.outputs = json.dumps(outputs) if outputs else None
|
||||||
workflow_node_execution.finished_at = finished_at
|
workflow_node_execution.finished_at = finished_at
|
||||||
workflow_node_execution.elapsed_time = elapsed_time
|
workflow_node_execution.elapsed_time = elapsed_time
|
||||||
|
@ -637,7 +641,7 @@ class WorkflowCycleManage:
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
def _fetch_files_from_node_outputs(self, outputs_dict: dict) -> list[dict]:
|
def _fetch_files_from_node_outputs(self, outputs_dict: dict) -> Sequence[Mapping[str, Any]]:
|
||||||
"""
|
"""
|
||||||
Fetch files from node outputs
|
Fetch files from node outputs
|
||||||
:param outputs_dict: node outputs dict
|
:param outputs_dict: node outputs dict
|
||||||
|
@ -646,15 +650,15 @@ class WorkflowCycleManage:
|
||||||
if not outputs_dict:
|
if not outputs_dict:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
files = []
|
files = [self._fetch_files_from_variable_value(output_value) for output_value in outputs_dict.values()]
|
||||||
for output_var, output_value in outputs_dict.items():
|
# Remove None
|
||||||
file_vars = self._fetch_files_from_variable_value(output_value)
|
files = [file for file in files if file]
|
||||||
if file_vars:
|
# Flatten list
|
||||||
files.extend(file_vars)
|
files = [file for sublist in files for file in sublist]
|
||||||
|
|
||||||
return files
|
return files
|
||||||
|
|
||||||
def _fetch_files_from_variable_value(self, value: Union[dict, list]) -> list[dict]:
|
def _fetch_files_from_variable_value(self, value: Union[dict, list]) -> Sequence[Mapping[str, Any]]:
|
||||||
"""
|
"""
|
||||||
Fetch files from variable value
|
Fetch files from variable value
|
||||||
:param value: variable value
|
:param value: variable value
|
||||||
|
@ -666,17 +670,17 @@ class WorkflowCycleManage:
|
||||||
files = []
|
files = []
|
||||||
if isinstance(value, list):
|
if isinstance(value, list):
|
||||||
for item in value:
|
for item in value:
|
||||||
file_var = self._get_file_var_from_value(item)
|
file = self._get_file_var_from_value(item)
|
||||||
if file_var:
|
if file:
|
||||||
files.append(file_var)
|
files.append(file)
|
||||||
elif isinstance(value, dict):
|
elif isinstance(value, dict):
|
||||||
file_var = self._get_file_var_from_value(value)
|
file = self._get_file_var_from_value(value)
|
||||||
if file_var:
|
if file:
|
||||||
files.append(file_var)
|
files.append(file)
|
||||||
|
|
||||||
return files
|
return files
|
||||||
|
|
||||||
def _get_file_var_from_value(self, value: Union[dict, list]) -> Optional[dict]:
|
def _get_file_var_from_value(self, value: Union[dict, list]) -> Mapping[str, Any] | None:
|
||||||
"""
|
"""
|
||||||
Get file var from value
|
Get file var from value
|
||||||
:param value: variable value
|
:param value: variable value
|
||||||
|
@ -685,14 +689,11 @@ class WorkflowCycleManage:
|
||||||
if not value:
|
if not value:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
if isinstance(value, dict):
|
if isinstance(value, dict) and value.get("dify_model_identity") == FILE_MODEL_IDENTITY:
|
||||||
if "__variant" in value and value["__variant"] == FileVar.__name__:
|
return value
|
||||||
return value
|
elif isinstance(value, File):
|
||||||
elif isinstance(value, FileVar):
|
|
||||||
return value.to_dict()
|
return value.to_dict()
|
||||||
|
|
||||||
return None
|
|
||||||
|
|
||||||
def _refetch_workflow_run(self, workflow_run_id: str) -> WorkflowRun:
|
def _refetch_workflow_run(self, workflow_run_id: str) -> WorkflowRun:
|
||||||
"""
|
"""
|
||||||
Refetch workflow run
|
Refetch workflow run
|
||||||
|
|
|
@ -1,29 +0,0 @@
|
||||||
import enum
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from pydantic import BaseModel
|
|
||||||
|
|
||||||
|
|
||||||
class PromptMessageFileType(enum.Enum):
|
|
||||||
IMAGE = "image"
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def value_of(value):
|
|
||||||
for member in PromptMessageFileType:
|
|
||||||
if member.value == value:
|
|
||||||
return member
|
|
||||||
raise ValueError(f"No matching enum found for value '{value}'")
|
|
||||||
|
|
||||||
|
|
||||||
class PromptMessageFile(BaseModel):
|
|
||||||
type: PromptMessageFileType
|
|
||||||
data: Any = None
|
|
||||||
|
|
||||||
|
|
||||||
class ImagePromptMessageFile(PromptMessageFile):
|
|
||||||
class DETAIL(enum.Enum):
|
|
||||||
LOW = "low"
|
|
||||||
HIGH = "high"
|
|
||||||
|
|
||||||
type: PromptMessageFileType = PromptMessageFileType.IMAGE
|
|
||||||
detail: DETAIL = DETAIL.LOW
|
|
|
@ -0,0 +1,19 @@
|
||||||
|
from .constants import FILE_MODEL_IDENTITY
|
||||||
|
from .enums import ArrayFileAttribute, FileAttribute, FileBelongsTo, FileTransferMethod, FileType
|
||||||
|
from .models import (
|
||||||
|
File,
|
||||||
|
FileExtraConfig,
|
||||||
|
ImageConfig,
|
||||||
|
)
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"FileType",
|
||||||
|
"FileExtraConfig",
|
||||||
|
"FileTransferMethod",
|
||||||
|
"FileBelongsTo",
|
||||||
|
"File",
|
||||||
|
"ImageConfig",
|
||||||
|
"FileAttribute",
|
||||||
|
"ArrayFileAttribute",
|
||||||
|
"FILE_MODEL_IDENTITY",
|
||||||
|
]
|
1
api/core/file/constants.py
Normal file
1
api/core/file/constants.py
Normal file
|
@ -0,0 +1 @@
|
||||||
|
FILE_MODEL_IDENTITY = "__dify__file__"
|
55
api/core/file/enums.py
Normal file
55
api/core/file/enums.py
Normal file
|
@ -0,0 +1,55 @@
|
||||||
|
from enum import Enum
|
||||||
|
|
||||||
|
|
||||||
|
class FileType(str, Enum):
|
||||||
|
IMAGE = "image"
|
||||||
|
DOCUMENT = "document"
|
||||||
|
AUDIO = "audio"
|
||||||
|
VIDEO = "video"
|
||||||
|
CUSTOM = "custom"
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def value_of(value):
|
||||||
|
for member in FileType:
|
||||||
|
if member.value == value:
|
||||||
|
return member
|
||||||
|
raise ValueError(f"No matching enum found for value '{value}'")
|
||||||
|
|
||||||
|
|
||||||
|
class FileTransferMethod(str, Enum):
|
||||||
|
REMOTE_URL = "remote_url"
|
||||||
|
LOCAL_FILE = "local_file"
|
||||||
|
TOOL_FILE = "tool_file"
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def value_of(value):
|
||||||
|
for member in FileTransferMethod:
|
||||||
|
if member.value == value:
|
||||||
|
return member
|
||||||
|
raise ValueError(f"No matching enum found for value '{value}'")
|
||||||
|
|
||||||
|
|
||||||
|
class FileBelongsTo(str, Enum):
|
||||||
|
USER = "user"
|
||||||
|
ASSISTANT = "assistant"
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def value_of(value):
|
||||||
|
for member in FileBelongsTo:
|
||||||
|
if member.value == value:
|
||||||
|
return member
|
||||||
|
raise ValueError(f"No matching enum found for value '{value}'")
|
||||||
|
|
||||||
|
|
||||||
|
class FileAttribute(str, Enum):
|
||||||
|
TYPE = "type"
|
||||||
|
SIZE = "size"
|
||||||
|
NAME = "name"
|
||||||
|
MIME_TYPE = "mime_type"
|
||||||
|
TRANSFER_METHOD = "transfer_method"
|
||||||
|
URL = "url"
|
||||||
|
EXTENSION = "extension"
|
||||||
|
|
||||||
|
|
||||||
|
class ArrayFileAttribute(str, Enum):
|
||||||
|
LENGTH = "length"
|
136
api/core/file/file_manager.py
Normal file
136
api/core/file/file_manager.py
Normal file
|
@ -0,0 +1,136 @@
|
||||||
|
import base64
|
||||||
|
|
||||||
|
from configs import dify_config
|
||||||
|
from core.model_runtime.entities.message_entities import ImagePromptMessageContent
|
||||||
|
from extensions.ext_database import db
|
||||||
|
from extensions.ext_storage import storage
|
||||||
|
from models import UploadFile
|
||||||
|
|
||||||
|
from . import helpers
|
||||||
|
from .enums import FileAttribute
|
||||||
|
from .models import File, FileTransferMethod, FileType
|
||||||
|
from .tool_file_parser import ToolFileParser
|
||||||
|
|
||||||
|
|
||||||
|
def get_attr(*, file: "File", attr: "FileAttribute"):
|
||||||
|
match attr:
|
||||||
|
case FileAttribute.TYPE:
|
||||||
|
return file.type.value
|
||||||
|
case FileAttribute.SIZE:
|
||||||
|
return file.size
|
||||||
|
case FileAttribute.NAME:
|
||||||
|
return file.filename
|
||||||
|
case FileAttribute.MIME_TYPE:
|
||||||
|
return file.mime_type
|
||||||
|
case FileAttribute.TRANSFER_METHOD:
|
||||||
|
return file.transfer_method.value
|
||||||
|
case FileAttribute.URL:
|
||||||
|
return file.remote_url
|
||||||
|
case FileAttribute.EXTENSION:
|
||||||
|
return file.extension
|
||||||
|
case _:
|
||||||
|
raise ValueError(f"Invalid file attribute: {attr}")
|
||||||
|
|
||||||
|
|
||||||
|
def to_prompt_message_content(file: "File", /):
|
||||||
|
"""
|
||||||
|
Convert a File object to an ImagePromptMessageContent object.
|
||||||
|
|
||||||
|
This function takes a File object and converts it to an ImagePromptMessageContent
|
||||||
|
object, which can be used as a prompt for image-based AI models.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file (File): The File object to convert. Must be of type FileType.IMAGE.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ImagePromptMessageContent: An object containing the image data and detail level.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If the file is not an image or if the file data is missing.
|
||||||
|
|
||||||
|
Note:
|
||||||
|
The detail level of the image prompt is determined by the file's extra_config.
|
||||||
|
If not specified, it defaults to ImagePromptMessageContent.DETAIL.LOW.
|
||||||
|
"""
|
||||||
|
if file.type != FileType.IMAGE:
|
||||||
|
raise ValueError("Only image file can convert to prompt message content")
|
||||||
|
|
||||||
|
url_or_b64_data = _get_url_or_b64_data(file=file)
|
||||||
|
if url_or_b64_data is None:
|
||||||
|
raise ValueError("Missing file data")
|
||||||
|
|
||||||
|
# decide the detail of image prompt message content
|
||||||
|
if file._extra_config and file._extra_config.image_config and file._extra_config.image_config.detail:
|
||||||
|
detail = file._extra_config.image_config.detail
|
||||||
|
else:
|
||||||
|
detail = ImagePromptMessageContent.DETAIL.LOW
|
||||||
|
|
||||||
|
return ImagePromptMessageContent(data=url_or_b64_data, detail=detail)
|
||||||
|
|
||||||
|
|
||||||
|
def download(*, upload_file_id: str, tenant_id: str):
|
||||||
|
upload_file = (
|
||||||
|
db.session.query(UploadFile).filter(UploadFile.id == upload_file_id, UploadFile.tenant_id == tenant_id).first()
|
||||||
|
)
|
||||||
|
|
||||||
|
if not upload_file:
|
||||||
|
raise ValueError("upload file not found")
|
||||||
|
|
||||||
|
return _download(upload_file.key)
|
||||||
|
|
||||||
|
|
||||||
|
def _download(path: str, /):
|
||||||
|
"""
|
||||||
|
Download and return the contents of a file as bytes.
|
||||||
|
|
||||||
|
This function loads the file from storage and ensures it's in bytes format.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
path (str): The path to the file in storage.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bytes: The contents of the file as a bytes object.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If the loaded file is not a bytes object.
|
||||||
|
"""
|
||||||
|
data = storage.load(path, stream=False)
|
||||||
|
if not isinstance(data, bytes):
|
||||||
|
raise ValueError(f"file {path} is not a bytes object")
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
def _get_base64(*, upload_file_id: str, tenant_id: str) -> str | None:
|
||||||
|
upload_file = (
|
||||||
|
db.session.query(UploadFile).filter(UploadFile.id == upload_file_id, UploadFile.tenant_id == tenant_id).first()
|
||||||
|
)
|
||||||
|
|
||||||
|
if not upload_file:
|
||||||
|
return None
|
||||||
|
|
||||||
|
data = _download(upload_file.key)
|
||||||
|
if data is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
encoded_string = base64.b64encode(data).decode("utf-8")
|
||||||
|
return f"data:{upload_file.mime_type};base64,{encoded_string}"
|
||||||
|
|
||||||
|
|
||||||
|
def _get_url_or_b64_data(file: "File"):
|
||||||
|
if file.type == FileType.IMAGE:
|
||||||
|
if file.transfer_method == FileTransferMethod.REMOTE_URL:
|
||||||
|
return file.remote_url
|
||||||
|
elif file.transfer_method == FileTransferMethod.LOCAL_FILE:
|
||||||
|
if file.related_id is None:
|
||||||
|
raise ValueError("Missing file related_id")
|
||||||
|
|
||||||
|
if dify_config.MULTIMODAL_SEND_IMAGE_FORMAT == "url":
|
||||||
|
return helpers.get_signed_image_url(upload_file_id=file.related_id)
|
||||||
|
return _get_base64(upload_file_id=file.related_id, tenant_id=file.tenant_id)
|
||||||
|
elif file.transfer_method == FileTransferMethod.TOOL_FILE:
|
||||||
|
# add sign url
|
||||||
|
if file.related_id is None or file.extension is None:
|
||||||
|
raise ValueError("Missing file related_id or extension")
|
||||||
|
return ToolFileParser.get_tool_file_manager().sign_file(
|
||||||
|
tool_file_id=file.related_id, extension=file.extension
|
||||||
|
)
|
|
@ -1,145 +0,0 @@
|
||||||
import enum
|
|
||||||
from typing import Any, Optional
|
|
||||||
|
|
||||||
from pydantic import BaseModel
|
|
||||||
|
|
||||||
from core.file.tool_file_parser import ToolFileParser
|
|
||||||
from core.file.upload_file_parser import UploadFileParser
|
|
||||||
from core.model_runtime.entities.message_entities import ImagePromptMessageContent
|
|
||||||
from extensions.ext_database import db
|
|
||||||
|
|
||||||
|
|
||||||
class FileExtraConfig(BaseModel):
|
|
||||||
"""
|
|
||||||
File Upload Entity.
|
|
||||||
"""
|
|
||||||
|
|
||||||
image_config: Optional[dict[str, Any]] = None
|
|
||||||
|
|
||||||
|
|
||||||
class FileType(enum.Enum):
|
|
||||||
IMAGE = "image"
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def value_of(value):
|
|
||||||
for member in FileType:
|
|
||||||
if member.value == value:
|
|
||||||
return member
|
|
||||||
raise ValueError(f"No matching enum found for value '{value}'")
|
|
||||||
|
|
||||||
|
|
||||||
class FileTransferMethod(enum.Enum):
|
|
||||||
REMOTE_URL = "remote_url"
|
|
||||||
LOCAL_FILE = "local_file"
|
|
||||||
TOOL_FILE = "tool_file"
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def value_of(value):
|
|
||||||
for member in FileTransferMethod:
|
|
||||||
if member.value == value:
|
|
||||||
return member
|
|
||||||
raise ValueError(f"No matching enum found for value '{value}'")
|
|
||||||
|
|
||||||
|
|
||||||
class FileBelongsTo(enum.Enum):
|
|
||||||
USER = "user"
|
|
||||||
ASSISTANT = "assistant"
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def value_of(value):
|
|
||||||
for member in FileBelongsTo:
|
|
||||||
if member.value == value:
|
|
||||||
return member
|
|
||||||
raise ValueError(f"No matching enum found for value '{value}'")
|
|
||||||
|
|
||||||
|
|
||||||
class FileVar(BaseModel):
|
|
||||||
id: Optional[str] = None # message file id
|
|
||||||
tenant_id: str
|
|
||||||
type: FileType
|
|
||||||
transfer_method: FileTransferMethod
|
|
||||||
url: Optional[str] = None # remote url
|
|
||||||
related_id: Optional[str] = None
|
|
||||||
extra_config: Optional[FileExtraConfig] = None
|
|
||||||
filename: Optional[str] = None
|
|
||||||
extension: Optional[str] = None
|
|
||||||
mime_type: Optional[str] = None
|
|
||||||
|
|
||||||
def to_dict(self) -> dict:
|
|
||||||
return {
|
|
||||||
"__variant": self.__class__.__name__,
|
|
||||||
"tenant_id": self.tenant_id,
|
|
||||||
"type": self.type.value,
|
|
||||||
"transfer_method": self.transfer_method.value,
|
|
||||||
"url": self.preview_url,
|
|
||||||
"remote_url": self.url,
|
|
||||||
"related_id": self.related_id,
|
|
||||||
"filename": self.filename,
|
|
||||||
"extension": self.extension,
|
|
||||||
"mime_type": self.mime_type,
|
|
||||||
}
|
|
||||||
|
|
||||||
def to_markdown(self) -> str:
|
|
||||||
"""
|
|
||||||
Convert file to markdown
|
|
||||||
:return:
|
|
||||||
"""
|
|
||||||
preview_url = self.preview_url
|
|
||||||
if self.type == FileType.IMAGE:
|
|
||||||
text = f'![{self.filename or ""}]({preview_url})'
|
|
||||||
else:
|
|
||||||
text = f"[{self.filename or preview_url}]({preview_url})"
|
|
||||||
|
|
||||||
return text
|
|
||||||
|
|
||||||
@property
|
|
||||||
def data(self) -> Optional[str]:
|
|
||||||
"""
|
|
||||||
Get image data, file signed url or base64 data
|
|
||||||
depending on config MULTIMODAL_SEND_IMAGE_FORMAT
|
|
||||||
:return:
|
|
||||||
"""
|
|
||||||
return self._get_data()
|
|
||||||
|
|
||||||
@property
|
|
||||||
def preview_url(self) -> Optional[str]:
|
|
||||||
"""
|
|
||||||
Get signed preview url
|
|
||||||
:return:
|
|
||||||
"""
|
|
||||||
return self._get_data(force_url=True)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def prompt_message_content(self) -> ImagePromptMessageContent:
|
|
||||||
if self.type == FileType.IMAGE:
|
|
||||||
image_config = self.extra_config.image_config
|
|
||||||
|
|
||||||
return ImagePromptMessageContent(
|
|
||||||
data=self.data,
|
|
||||||
detail=ImagePromptMessageContent.DETAIL.HIGH
|
|
||||||
if image_config.get("detail") == "high"
|
|
||||||
else ImagePromptMessageContent.DETAIL.LOW,
|
|
||||||
)
|
|
||||||
|
|
||||||
def _get_data(self, force_url: bool = False) -> Optional[str]:
|
|
||||||
from models.model import UploadFile
|
|
||||||
|
|
||||||
if self.type == FileType.IMAGE:
|
|
||||||
if self.transfer_method == FileTransferMethod.REMOTE_URL:
|
|
||||||
return self.url
|
|
||||||
elif self.transfer_method == FileTransferMethod.LOCAL_FILE:
|
|
||||||
upload_file = (
|
|
||||||
db.session.query(UploadFile)
|
|
||||||
.filter(UploadFile.id == self.related_id, UploadFile.tenant_id == self.tenant_id)
|
|
||||||
.first()
|
|
||||||
)
|
|
||||||
|
|
||||||
return UploadFileParser.get_image_data(upload_file=upload_file, force_url=force_url)
|
|
||||||
elif self.transfer_method == FileTransferMethod.TOOL_FILE:
|
|
||||||
extension = self.extension
|
|
||||||
# add sign url
|
|
||||||
return ToolFileParser.get_tool_file_manager().sign_file(
|
|
||||||
tool_file_id=self.related_id, extension=extension
|
|
||||||
)
|
|
||||||
|
|
||||||
return None
|
|
61
api/core/file/helpers.py
Normal file
61
api/core/file/helpers.py
Normal file
|
@ -0,0 +1,61 @@
|
||||||
|
import base64
|
||||||
|
import hashlib
|
||||||
|
import hmac
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
|
||||||
|
from configs import dify_config
|
||||||
|
|
||||||
|
|
||||||
|
def get_signed_image_url(upload_file_id: str) -> str:
|
||||||
|
url = f"{dify_config.FILES_URL}/files/{upload_file_id}/image-preview"
|
||||||
|
|
||||||
|
timestamp = str(int(time.time()))
|
||||||
|
nonce = os.urandom(16).hex()
|
||||||
|
key = dify_config.SECRET_KEY.encode()
|
||||||
|
msg = f"image-preview|{upload_file_id}|{timestamp}|{nonce}"
|
||||||
|
sign = hmac.new(key, msg.encode(), hashlib.sha256).digest()
|
||||||
|
encoded_sign = base64.urlsafe_b64encode(sign).decode()
|
||||||
|
|
||||||
|
return f"{url}?timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}"
|
||||||
|
|
||||||
|
|
||||||
|
def get_signed_file_url(upload_file_id: str) -> str:
|
||||||
|
url = f"{dify_config.FILES_URL}/files/{upload_file_id}/file-preview"
|
||||||
|
|
||||||
|
timestamp = str(int(time.time()))
|
||||||
|
nonce = os.urandom(16).hex()
|
||||||
|
key = dify_config.SECRET_KEY.encode()
|
||||||
|
msg = f"file-preview|{upload_file_id}|{timestamp}|{nonce}"
|
||||||
|
sign = hmac.new(key, msg.encode(), hashlib.sha256).digest()
|
||||||
|
encoded_sign = base64.urlsafe_b64encode(sign).decode()
|
||||||
|
|
||||||
|
return f"{url}?timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}"
|
||||||
|
|
||||||
|
|
||||||
|
def verify_image_signature(*, upload_file_id: str, timestamp: str, nonce: str, sign: str) -> bool:
|
||||||
|
data_to_sign = f"image-preview|{upload_file_id}|{timestamp}|{nonce}"
|
||||||
|
secret_key = dify_config.SECRET_KEY.encode()
|
||||||
|
recalculated_sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
|
||||||
|
recalculated_encoded_sign = base64.urlsafe_b64encode(recalculated_sign).decode()
|
||||||
|
|
||||||
|
# verify signature
|
||||||
|
if sign != recalculated_encoded_sign:
|
||||||
|
return False
|
||||||
|
|
||||||
|
current_time = int(time.time())
|
||||||
|
return current_time - int(timestamp) <= dify_config.FILES_ACCESS_TIMEOUT
|
||||||
|
|
||||||
|
|
||||||
|
def verify_file_signature(*, upload_file_id: str, timestamp: str, nonce: str, sign: str) -> bool:
|
||||||
|
data_to_sign = f"file-preview|{upload_file_id}|{timestamp}|{nonce}"
|
||||||
|
secret_key = dify_config.SECRET_KEY.encode()
|
||||||
|
recalculated_sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
|
||||||
|
recalculated_encoded_sign = base64.urlsafe_b64encode(recalculated_sign).decode()
|
||||||
|
|
||||||
|
# verify signature
|
||||||
|
if sign != recalculated_encoded_sign:
|
||||||
|
return False
|
||||||
|
|
||||||
|
current_time = int(time.time())
|
||||||
|
return current_time - int(timestamp) <= dify_config.FILES_ACCESS_TIMEOUT
|
|
@ -1,243 +0,0 @@
|
||||||
import re
|
|
||||||
from collections.abc import Mapping, Sequence
|
|
||||||
from typing import Any, Union
|
|
||||||
from urllib.parse import parse_qs, urlparse
|
|
||||||
|
|
||||||
import requests
|
|
||||||
|
|
||||||
from core.file.file_obj import FileBelongsTo, FileExtraConfig, FileTransferMethod, FileType, FileVar
|
|
||||||
from extensions.ext_database import db
|
|
||||||
from models.account import Account
|
|
||||||
from models.model import EndUser, MessageFile, UploadFile
|
|
||||||
from services.file_service import IMAGE_EXTENSIONS
|
|
||||||
|
|
||||||
|
|
||||||
class MessageFileParser:
|
|
||||||
def __init__(self, tenant_id: str, app_id: str) -> None:
|
|
||||||
self.tenant_id = tenant_id
|
|
||||||
self.app_id = app_id
|
|
||||||
|
|
||||||
def validate_and_transform_files_arg(
|
|
||||||
self, files: Sequence[Mapping[str, Any]], file_extra_config: FileExtraConfig, user: Union[Account, EndUser]
|
|
||||||
) -> list[FileVar]:
|
|
||||||
"""
|
|
||||||
validate and transform files arg
|
|
||||||
|
|
||||||
:param files:
|
|
||||||
:param file_extra_config:
|
|
||||||
:param user:
|
|
||||||
:return:
|
|
||||||
"""
|
|
||||||
for file in files:
|
|
||||||
if not isinstance(file, dict):
|
|
||||||
raise ValueError("Invalid file format, must be dict")
|
|
||||||
if not file.get("type"):
|
|
||||||
raise ValueError("Missing file type")
|
|
||||||
FileType.value_of(file.get("type"))
|
|
||||||
if not file.get("transfer_method"):
|
|
||||||
raise ValueError("Missing file transfer method")
|
|
||||||
FileTransferMethod.value_of(file.get("transfer_method"))
|
|
||||||
if file.get("transfer_method") == FileTransferMethod.REMOTE_URL.value:
|
|
||||||
if not file.get("url"):
|
|
||||||
raise ValueError("Missing file url")
|
|
||||||
if not file.get("url").startswith("http"):
|
|
||||||
raise ValueError("Invalid file url")
|
|
||||||
if file.get("transfer_method") == FileTransferMethod.LOCAL_FILE.value and not file.get("upload_file_id"):
|
|
||||||
raise ValueError("Missing file upload_file_id")
|
|
||||||
if file.get("transform_method") == FileTransferMethod.TOOL_FILE.value and not file.get("tool_file_id"):
|
|
||||||
raise ValueError("Missing file tool_file_id")
|
|
||||||
|
|
||||||
# transform files to file objs
|
|
||||||
type_file_objs = self._to_file_objs(files, file_extra_config)
|
|
||||||
|
|
||||||
# validate files
|
|
||||||
new_files = []
|
|
||||||
for file_type, file_objs in type_file_objs.items():
|
|
||||||
if file_type == FileType.IMAGE:
|
|
||||||
# parse and validate files
|
|
||||||
image_config = file_extra_config.image_config
|
|
||||||
|
|
||||||
# check if image file feature is enabled
|
|
||||||
if not image_config:
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Validate number of files
|
|
||||||
if len(files) > image_config["number_limits"]:
|
|
||||||
raise ValueError(f"Number of image files exceeds the maximum limit {image_config['number_limits']}")
|
|
||||||
|
|
||||||
for file_obj in file_objs:
|
|
||||||
# Validate transfer method
|
|
||||||
if file_obj.transfer_method.value not in image_config["transfer_methods"]:
|
|
||||||
raise ValueError(f"Invalid transfer method: {file_obj.transfer_method.value}")
|
|
||||||
|
|
||||||
# Validate file type
|
|
||||||
if file_obj.type != FileType.IMAGE:
|
|
||||||
raise ValueError(f"Invalid file type: {file_obj.type}")
|
|
||||||
|
|
||||||
if file_obj.transfer_method == FileTransferMethod.REMOTE_URL:
|
|
||||||
# check remote url valid and is image
|
|
||||||
result, error = self._check_image_remote_url(file_obj.url)
|
|
||||||
if result is False:
|
|
||||||
raise ValueError(error)
|
|
||||||
elif file_obj.transfer_method == FileTransferMethod.LOCAL_FILE:
|
|
||||||
# get upload file from upload_file_id
|
|
||||||
upload_file = (
|
|
||||||
db.session.query(UploadFile)
|
|
||||||
.filter(
|
|
||||||
UploadFile.id == file_obj.related_id,
|
|
||||||
UploadFile.tenant_id == self.tenant_id,
|
|
||||||
UploadFile.created_by == user.id,
|
|
||||||
UploadFile.created_by_role == ("account" if isinstance(user, Account) else "end_user"),
|
|
||||||
UploadFile.extension.in_(IMAGE_EXTENSIONS),
|
|
||||||
)
|
|
||||||
.first()
|
|
||||||
)
|
|
||||||
|
|
||||||
# check upload file is belong to tenant and user
|
|
||||||
if not upload_file:
|
|
||||||
raise ValueError("Invalid upload file")
|
|
||||||
|
|
||||||
new_files.append(file_obj)
|
|
||||||
|
|
||||||
# return all file objs
|
|
||||||
return new_files
|
|
||||||
|
|
||||||
def transform_message_files(self, files: list[MessageFile], file_extra_config: FileExtraConfig):
|
|
||||||
"""
|
|
||||||
transform message files
|
|
||||||
|
|
||||||
:param files:
|
|
||||||
:param file_extra_config:
|
|
||||||
:return:
|
|
||||||
"""
|
|
||||||
# transform files to file objs
|
|
||||||
type_file_objs = self._to_file_objs(files, file_extra_config)
|
|
||||||
|
|
||||||
# return all file objs
|
|
||||||
return [file_obj for file_objs in type_file_objs.values() for file_obj in file_objs]
|
|
||||||
|
|
||||||
def _to_file_objs(
|
|
||||||
self, files: list[Union[dict, MessageFile]], file_extra_config: FileExtraConfig
|
|
||||||
) -> dict[FileType, list[FileVar]]:
|
|
||||||
"""
|
|
||||||
transform files to file objs
|
|
||||||
|
|
||||||
:param files:
|
|
||||||
:param file_extra_config:
|
|
||||||
:return:
|
|
||||||
"""
|
|
||||||
type_file_objs: dict[FileType, list[FileVar]] = {
|
|
||||||
# Currently only support image
|
|
||||||
FileType.IMAGE: []
|
|
||||||
}
|
|
||||||
|
|
||||||
if not files:
|
|
||||||
return type_file_objs
|
|
||||||
|
|
||||||
# group by file type and convert file args or message files to FileObj
|
|
||||||
for file in files:
|
|
||||||
if isinstance(file, MessageFile):
|
|
||||||
if file.belongs_to == FileBelongsTo.ASSISTANT.value:
|
|
||||||
continue
|
|
||||||
|
|
||||||
file_obj = self._to_file_obj(file, file_extra_config)
|
|
||||||
if file_obj.type not in type_file_objs:
|
|
||||||
continue
|
|
||||||
|
|
||||||
type_file_objs[file_obj.type].append(file_obj)
|
|
||||||
|
|
||||||
return type_file_objs
|
|
||||||
|
|
||||||
def _to_file_obj(self, file: Union[dict, MessageFile], file_extra_config: FileExtraConfig):
|
|
||||||
"""
|
|
||||||
transform file to file obj
|
|
||||||
|
|
||||||
:param file:
|
|
||||||
:return:
|
|
||||||
"""
|
|
||||||
if isinstance(file, dict):
|
|
||||||
transfer_method = FileTransferMethod.value_of(file.get("transfer_method"))
|
|
||||||
if transfer_method != FileTransferMethod.TOOL_FILE:
|
|
||||||
return FileVar(
|
|
||||||
tenant_id=self.tenant_id,
|
|
||||||
type=FileType.value_of(file.get("type")),
|
|
||||||
transfer_method=transfer_method,
|
|
||||||
url=file.get("url") if transfer_method == FileTransferMethod.REMOTE_URL else None,
|
|
||||||
related_id=file.get("upload_file_id") if transfer_method == FileTransferMethod.LOCAL_FILE else None,
|
|
||||||
extra_config=file_extra_config,
|
|
||||||
)
|
|
||||||
return FileVar(
|
|
||||||
tenant_id=self.tenant_id,
|
|
||||||
type=FileType.value_of(file.get("type")),
|
|
||||||
transfer_method=transfer_method,
|
|
||||||
url=None,
|
|
||||||
related_id=file.get("tool_file_id"),
|
|
||||||
extra_config=file_extra_config,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
return FileVar(
|
|
||||||
id=file.id,
|
|
||||||
tenant_id=self.tenant_id,
|
|
||||||
type=FileType.value_of(file.type),
|
|
||||||
transfer_method=FileTransferMethod.value_of(file.transfer_method),
|
|
||||||
url=file.url,
|
|
||||||
related_id=file.upload_file_id or None,
|
|
||||||
extra_config=file_extra_config,
|
|
||||||
)
|
|
||||||
|
|
||||||
def _check_image_remote_url(self, url):
|
|
||||||
try:
|
|
||||||
headers = {
|
|
||||||
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko)"
|
|
||||||
" Chrome/91.0.4472.124 Safari/537.36"
|
|
||||||
}
|
|
||||||
|
|
||||||
def is_s3_presigned_url(url):
|
|
||||||
try:
|
|
||||||
parsed_url = urlparse(url)
|
|
||||||
if "amazonaws.com" not in parsed_url.netloc:
|
|
||||||
return False
|
|
||||||
query_params = parse_qs(parsed_url.query)
|
|
||||||
|
|
||||||
def check_presign_v2(query_params):
|
|
||||||
required_params = ["Signature", "Expires"]
|
|
||||||
for param in required_params:
|
|
||||||
if param not in query_params:
|
|
||||||
return False
|
|
||||||
if not query_params["Expires"][0].isdigit():
|
|
||||||
return False
|
|
||||||
signature = query_params["Signature"][0]
|
|
||||||
if not re.match(r"^[A-Za-z0-9+/]+={0,2}$", signature):
|
|
||||||
return False
|
|
||||||
|
|
||||||
return True
|
|
||||||
|
|
||||||
def check_presign_v4(query_params):
|
|
||||||
required_params = ["X-Amz-Signature", "X-Amz-Expires"]
|
|
||||||
for param in required_params:
|
|
||||||
if param not in query_params:
|
|
||||||
return False
|
|
||||||
if not query_params["X-Amz-Expires"][0].isdigit():
|
|
||||||
return False
|
|
||||||
signature = query_params["X-Amz-Signature"][0]
|
|
||||||
if not re.match(r"^[A-Za-z0-9+/]+={0,2}$", signature):
|
|
||||||
return False
|
|
||||||
|
|
||||||
return True
|
|
||||||
|
|
||||||
return check_presign_v4(query_params) or check_presign_v2(query_params)
|
|
||||||
except Exception:
|
|
||||||
return False
|
|
||||||
|
|
||||||
if is_s3_presigned_url(url):
|
|
||||||
response = requests.get(url, headers=headers, allow_redirects=True)
|
|
||||||
if response.status_code in {200, 304}:
|
|
||||||
return True, ""
|
|
||||||
|
|
||||||
response = requests.head(url, headers=headers, allow_redirects=True)
|
|
||||||
if response.status_code in {200, 304}:
|
|
||||||
return True, ""
|
|
||||||
else:
|
|
||||||
return False, "URL does not exist."
|
|
||||||
except requests.RequestException as e:
|
|
||||||
return False, f"Error checking URL: {e}"
|
|
140
api/core/file/models.py
Normal file
140
api/core/file/models.py
Normal file
|
@ -0,0 +1,140 @@
|
||||||
|
from collections.abc import Mapping, Sequence
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field, model_validator
|
||||||
|
|
||||||
|
from core.model_runtime.entities.message_entities import ImagePromptMessageContent
|
||||||
|
|
||||||
|
from . import helpers
|
||||||
|
from .constants import FILE_MODEL_IDENTITY
|
||||||
|
from .enums import FileTransferMethod, FileType
|
||||||
|
from .tool_file_parser import ToolFileParser
|
||||||
|
|
||||||
|
|
||||||
|
class ImageConfig(BaseModel):
|
||||||
|
"""
|
||||||
|
NOTE: This part of validation is deprecated, but still used in app features "Image Upload".
|
||||||
|
"""
|
||||||
|
|
||||||
|
number_limits: int = 0
|
||||||
|
transfer_methods: Sequence[FileTransferMethod] = Field(default_factory=list)
|
||||||
|
detail: ImagePromptMessageContent.DETAIL | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class FileExtraConfig(BaseModel):
|
||||||
|
"""
|
||||||
|
File Upload Entity.
|
||||||
|
"""
|
||||||
|
|
||||||
|
image_config: Optional[ImageConfig] = None
|
||||||
|
allowed_file_types: Sequence[FileType] = Field(default_factory=list)
|
||||||
|
allowed_extensions: Sequence[str] = Field(default_factory=list)
|
||||||
|
allowed_upload_methods: Sequence[FileTransferMethod] = Field(default_factory=list)
|
||||||
|
number_limits: int = 0
|
||||||
|
|
||||||
|
|
||||||
|
class File(BaseModel):
|
||||||
|
dify_model_identity: str = FILE_MODEL_IDENTITY
|
||||||
|
|
||||||
|
id: Optional[str] = None # message file id
|
||||||
|
tenant_id: str
|
||||||
|
type: FileType
|
||||||
|
transfer_method: FileTransferMethod
|
||||||
|
remote_url: Optional[str] = None # remote url
|
||||||
|
related_id: Optional[str] = None
|
||||||
|
filename: Optional[str] = None
|
||||||
|
extension: Optional[str] = Field(default=None, description="File extension, should contains dot")
|
||||||
|
mime_type: Optional[str] = None
|
||||||
|
size: int = -1
|
||||||
|
_extra_config: FileExtraConfig | None = None
|
||||||
|
|
||||||
|
def to_dict(self) -> Mapping[str, str | int | None]:
|
||||||
|
data = self.model_dump(mode="json")
|
||||||
|
return {
|
||||||
|
**data,
|
||||||
|
"url": self.generate_url(),
|
||||||
|
}
|
||||||
|
|
||||||
|
@property
|
||||||
|
def markdown(self) -> str:
|
||||||
|
url = self.generate_url()
|
||||||
|
if self.type == FileType.IMAGE:
|
||||||
|
text = f'![{self.filename or ""}]({url})'
|
||||||
|
else:
|
||||||
|
text = f"[{self.filename or url}]({url})"
|
||||||
|
|
||||||
|
return text
|
||||||
|
|
||||||
|
def generate_url(self) -> Optional[str]:
|
||||||
|
if self.type == FileType.IMAGE:
|
||||||
|
if self.transfer_method == FileTransferMethod.REMOTE_URL:
|
||||||
|
return self.remote_url
|
||||||
|
elif self.transfer_method == FileTransferMethod.LOCAL_FILE:
|
||||||
|
if self.related_id is None:
|
||||||
|
raise ValueError("Missing file related_id")
|
||||||
|
return helpers.get_signed_image_url(upload_file_id=self.related_id)
|
||||||
|
elif self.transfer_method == FileTransferMethod.TOOL_FILE:
|
||||||
|
assert self.related_id is not None
|
||||||
|
assert self.extension is not None
|
||||||
|
return ToolFileParser.get_tool_file_manager().sign_file(
|
||||||
|
tool_file_id=self.related_id, extension=self.extension
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
if self.transfer_method == FileTransferMethod.REMOTE_URL:
|
||||||
|
return self.remote_url
|
||||||
|
elif self.transfer_method == FileTransferMethod.LOCAL_FILE:
|
||||||
|
if self.related_id is None:
|
||||||
|
raise ValueError("Missing file related_id")
|
||||||
|
return helpers.get_signed_file_url(upload_file_id=self.related_id)
|
||||||
|
elif self.transfer_method == FileTransferMethod.TOOL_FILE:
|
||||||
|
assert self.related_id is not None
|
||||||
|
assert self.extension is not None
|
||||||
|
return ToolFileParser.get_tool_file_manager().sign_file(
|
||||||
|
tool_file_id=self.related_id, extension=self.extension
|
||||||
|
)
|
||||||
|
|
||||||
|
@model_validator(mode="after")
|
||||||
|
def validate_after(self):
|
||||||
|
match self.transfer_method:
|
||||||
|
case FileTransferMethod.REMOTE_URL:
|
||||||
|
if not self.remote_url:
|
||||||
|
raise ValueError("Missing file url")
|
||||||
|
if not isinstance(self.remote_url, str) or not self.remote_url.startswith("http"):
|
||||||
|
raise ValueError("Invalid file url")
|
||||||
|
case FileTransferMethod.LOCAL_FILE:
|
||||||
|
if not self.related_id:
|
||||||
|
raise ValueError("Missing file related_id")
|
||||||
|
case FileTransferMethod.TOOL_FILE:
|
||||||
|
if not self.related_id:
|
||||||
|
raise ValueError("Missing file related_id")
|
||||||
|
|
||||||
|
# Validate the extra config.
|
||||||
|
if not self._extra_config:
|
||||||
|
return self
|
||||||
|
|
||||||
|
if self._extra_config.allowed_file_types:
|
||||||
|
if self.type not in self._extra_config.allowed_file_types and self.type != FileType.CUSTOM:
|
||||||
|
raise ValueError(f"Invalid file type: {self.type}")
|
||||||
|
|
||||||
|
if self._extra_config.allowed_extensions and self.extension not in self._extra_config.allowed_extensions:
|
||||||
|
raise ValueError(f"Invalid file extension: {self.extension}")
|
||||||
|
|
||||||
|
if (
|
||||||
|
self._extra_config.allowed_upload_methods
|
||||||
|
and self.transfer_method not in self._extra_config.allowed_upload_methods
|
||||||
|
):
|
||||||
|
raise ValueError(f"Invalid transfer method: {self.transfer_method}")
|
||||||
|
|
||||||
|
match self.type:
|
||||||
|
case FileType.IMAGE:
|
||||||
|
# NOTE: This part of validation is deprecated, but still used in app features "Image Upload".
|
||||||
|
if not self._extra_config.image_config:
|
||||||
|
return self
|
||||||
|
# TODO: skip check if transfer_methods is empty, because many test cases are not setting this field
|
||||||
|
if (
|
||||||
|
self._extra_config.image_config.transfer_methods
|
||||||
|
and self.transfer_method not in self._extra_config.image_config.transfer_methods
|
||||||
|
):
|
||||||
|
raise ValueError(f"Invalid transfer method: {self.transfer_method}")
|
||||||
|
|
||||||
|
return self
|
|
@ -1,4 +1,9 @@
|
||||||
tool_file_manager = {"manager": None}
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from core.tools.tool_file_manager import ToolFileManager
|
||||||
|
|
||||||
|
tool_file_manager: dict[str, Any] = {"manager": None}
|
||||||
|
|
||||||
|
|
||||||
class ToolFileParser:
|
class ToolFileParser:
|
||||||
|
|
|
@ -1,79 +0,0 @@
|
||||||
import base64
|
|
||||||
import hashlib
|
|
||||||
import hmac
|
|
||||||
import logging
|
|
||||||
import os
|
|
||||||
import time
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
from configs import dify_config
|
|
||||||
from extensions.ext_storage import storage
|
|
||||||
|
|
||||||
IMAGE_EXTENSIONS = ["jpg", "jpeg", "png", "webp", "gif", "svg"]
|
|
||||||
IMAGE_EXTENSIONS.extend([ext.upper() for ext in IMAGE_EXTENSIONS])
|
|
||||||
|
|
||||||
|
|
||||||
class UploadFileParser:
|
|
||||||
@classmethod
|
|
||||||
def get_image_data(cls, upload_file, force_url: bool = False) -> Optional[str]:
|
|
||||||
if not upload_file:
|
|
||||||
return None
|
|
||||||
|
|
||||||
if upload_file.extension not in IMAGE_EXTENSIONS:
|
|
||||||
return None
|
|
||||||
|
|
||||||
if dify_config.MULTIMODAL_SEND_IMAGE_FORMAT == "url" or force_url:
|
|
||||||
return cls.get_signed_temp_image_url(upload_file.id)
|
|
||||||
else:
|
|
||||||
# get image file base64
|
|
||||||
try:
|
|
||||||
data = storage.load(upload_file.key)
|
|
||||||
except FileNotFoundError:
|
|
||||||
logging.error(f"File not found: {upload_file.key}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
encoded_string = base64.b64encode(data).decode("utf-8")
|
|
||||||
return f"data:{upload_file.mime_type};base64,{encoded_string}"
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def get_signed_temp_image_url(cls, upload_file_id) -> str:
|
|
||||||
"""
|
|
||||||
get signed url from upload file
|
|
||||||
|
|
||||||
:param upload_file: UploadFile object
|
|
||||||
:return:
|
|
||||||
"""
|
|
||||||
base_url = dify_config.FILES_URL
|
|
||||||
image_preview_url = f"{base_url}/files/{upload_file_id}/image-preview"
|
|
||||||
|
|
||||||
timestamp = str(int(time.time()))
|
|
||||||
nonce = os.urandom(16).hex()
|
|
||||||
data_to_sign = f"image-preview|{upload_file_id}|{timestamp}|{nonce}"
|
|
||||||
secret_key = dify_config.SECRET_KEY.encode()
|
|
||||||
sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
|
|
||||||
encoded_sign = base64.urlsafe_b64encode(sign).decode()
|
|
||||||
|
|
||||||
return f"{image_preview_url}?timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}"
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def verify_image_file_signature(cls, upload_file_id: str, timestamp: str, nonce: str, sign: str) -> bool:
|
|
||||||
"""
|
|
||||||
verify signature
|
|
||||||
|
|
||||||
:param upload_file_id: file id
|
|
||||||
:param timestamp: timestamp
|
|
||||||
:param nonce: nonce
|
|
||||||
:param sign: signature
|
|
||||||
:return:
|
|
||||||
"""
|
|
||||||
data_to_sign = f"image-preview|{upload_file_id}|{timestamp}|{nonce}"
|
|
||||||
secret_key = dify_config.SECRET_KEY.encode()
|
|
||||||
recalculated_sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
|
|
||||||
recalculated_encoded_sign = base64.urlsafe_b64encode(recalculated_sign).decode()
|
|
||||||
|
|
||||||
# verify signature
|
|
||||||
if sign != recalculated_encoded_sign:
|
|
||||||
return False
|
|
||||||
|
|
||||||
current_time = int(time.time())
|
|
||||||
return current_time - int(timestamp) <= dify_config.FILES_ACCESS_TIMEOUT
|
|
|
@ -13,8 +13,11 @@ SSRF_PROXY_HTTP_URL = os.getenv("SSRF_PROXY_HTTP_URL", "")
|
||||||
SSRF_PROXY_HTTPS_URL = os.getenv("SSRF_PROXY_HTTPS_URL", "")
|
SSRF_PROXY_HTTPS_URL = os.getenv("SSRF_PROXY_HTTPS_URL", "")
|
||||||
SSRF_DEFAULT_MAX_RETRIES = int(os.getenv("SSRF_DEFAULT_MAX_RETRIES", "3"))
|
SSRF_DEFAULT_MAX_RETRIES = int(os.getenv("SSRF_DEFAULT_MAX_RETRIES", "3"))
|
||||||
|
|
||||||
proxies = (
|
proxy_mounts = (
|
||||||
{"http://": SSRF_PROXY_HTTP_URL, "https://": SSRF_PROXY_HTTPS_URL}
|
{
|
||||||
|
"http://": httpx.HTTPTransport(proxy=SSRF_PROXY_HTTP_URL),
|
||||||
|
"https://": httpx.HTTPTransport(proxy=SSRF_PROXY_HTTPS_URL),
|
||||||
|
}
|
||||||
if SSRF_PROXY_HTTP_URL and SSRF_PROXY_HTTPS_URL
|
if SSRF_PROXY_HTTP_URL and SSRF_PROXY_HTTPS_URL
|
||||||
else None
|
else None
|
||||||
)
|
)
|
||||||
|
@ -33,11 +36,14 @@ def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
|
||||||
while retries <= max_retries:
|
while retries <= max_retries:
|
||||||
try:
|
try:
|
||||||
if SSRF_PROXY_ALL_URL:
|
if SSRF_PROXY_ALL_URL:
|
||||||
response = httpx.request(method=method, url=url, proxy=SSRF_PROXY_ALL_URL, **kwargs)
|
with httpx.Client(proxy=SSRF_PROXY_ALL_URL) as client:
|
||||||
elif proxies:
|
response = client.request(method=method, url=url, **kwargs)
|
||||||
response = httpx.request(method=method, url=url, proxies=proxies, **kwargs)
|
elif proxy_mounts:
|
||||||
|
with httpx.Client(mounts=proxy_mounts) as client:
|
||||||
|
response = client.request(method=method, url=url, **kwargs)
|
||||||
else:
|
else:
|
||||||
response = httpx.request(method=method, url=url, **kwargs)
|
with httpx.Client() as client:
|
||||||
|
response = client.request(method=method, url=url, **kwargs)
|
||||||
|
|
||||||
if response.status_code not in STATUS_FORCELIST:
|
if response.status_code not in STATUS_FORCELIST:
|
||||||
return response
|
return response
|
||||||
|
|
|
@ -1,18 +1,20 @@
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
|
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
|
||||||
from core.file.message_file_parser import MessageFileParser
|
from core.file import file_manager
|
||||||
from core.model_manager import ModelInstance
|
from core.model_manager import ModelInstance
|
||||||
from core.model_runtime.entities.message_entities import (
|
from core.model_runtime.entities import (
|
||||||
AssistantPromptMessage,
|
AssistantPromptMessage,
|
||||||
ImagePromptMessageContent,
|
ImagePromptMessageContent,
|
||||||
PromptMessage,
|
PromptMessage,
|
||||||
|
PromptMessageContent,
|
||||||
PromptMessageRole,
|
PromptMessageRole,
|
||||||
TextPromptMessageContent,
|
TextPromptMessageContent,
|
||||||
UserPromptMessage,
|
UserPromptMessage,
|
||||||
)
|
)
|
||||||
from core.prompt.utils.extract_thread_messages import extract_thread_messages
|
from core.prompt.utils.extract_thread_messages import extract_thread_messages
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
|
from factories import file_factory
|
||||||
from models.model import AppMode, Conversation, Message, MessageFile
|
from models.model import AppMode, Conversation, Message, MessageFile
|
||||||
from models.workflow import WorkflowRun
|
from models.workflow import WorkflowRun
|
||||||
|
|
||||||
|
@ -65,7 +67,6 @@ class TokenBufferMemory:
|
||||||
|
|
||||||
messages = list(reversed(thread_messages))
|
messages = list(reversed(thread_messages))
|
||||||
|
|
||||||
message_file_parser = MessageFileParser(tenant_id=app_record.tenant_id, app_id=app_record.id)
|
|
||||||
prompt_messages = []
|
prompt_messages = []
|
||||||
for message in messages:
|
for message in messages:
|
||||||
files = db.session.query(MessageFile).filter(MessageFile.message_id == message.id).all()
|
files = db.session.query(MessageFile).filter(MessageFile.message_id == message.id).all()
|
||||||
|
@ -84,17 +85,20 @@ class TokenBufferMemory:
|
||||||
workflow_run.workflow.features_dict, is_vision=False
|
workflow_run.workflow.features_dict, is_vision=False
|
||||||
)
|
)
|
||||||
|
|
||||||
if file_extra_config:
|
if file_extra_config and app_record:
|
||||||
file_objs = message_file_parser.transform_message_files(files, file_extra_config)
|
file_objs = file_factory.build_from_message_files(
|
||||||
|
message_files=files, tenant_id=app_record.tenant_id, config=file_extra_config
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
file_objs = []
|
file_objs = []
|
||||||
|
|
||||||
if not file_objs:
|
if not file_objs:
|
||||||
prompt_messages.append(UserPromptMessage(content=message.query))
|
prompt_messages.append(UserPromptMessage(content=message.query))
|
||||||
else:
|
else:
|
||||||
prompt_message_contents = [TextPromptMessageContent(data=message.query)]
|
prompt_message_contents: list[PromptMessageContent] = []
|
||||||
|
prompt_message_contents.append(TextPromptMessageContent(data=message.query))
|
||||||
for file_obj in file_objs:
|
for file_obj in file_objs:
|
||||||
prompt_message_contents.append(file_obj.prompt_message_content)
|
prompt_message_contents.append(file_manager.to_prompt_message_content(file_obj))
|
||||||
|
|
||||||
prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
|
prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
|
||||||
else:
|
else:
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
from collections.abc import Callable, Generator, Sequence
|
from collections.abc import Callable, Generator, Iterable, Sequence
|
||||||
from typing import IO, Optional, Union, cast
|
from typing import IO, Any, Optional, Union, cast
|
||||||
|
|
||||||
from core.embedding.embedding_constant import EmbeddingInputType
|
from core.embedding.embedding_constant import EmbeddingInputType
|
||||||
from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle
|
from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle
|
||||||
|
@ -274,7 +274,7 @@ class ModelInstance:
|
||||||
user=user,
|
user=user,
|
||||||
)
|
)
|
||||||
|
|
||||||
def invoke_tts(self, content_text: str, tenant_id: str, voice: str, user: Optional[str] = None) -> str:
|
def invoke_tts(self, content_text: str, tenant_id: str, voice: str, user: Optional[str] = None) -> Iterable[bytes]:
|
||||||
"""
|
"""
|
||||||
Invoke large language tts model
|
Invoke large language tts model
|
||||||
|
|
||||||
|
@ -298,7 +298,7 @@ class ModelInstance:
|
||||||
voice=voice,
|
voice=voice,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _round_robin_invoke(self, function: Callable, *args, **kwargs):
|
def _round_robin_invoke(self, function: Callable[..., Any], *args, **kwargs):
|
||||||
"""
|
"""
|
||||||
Round-robin invoke
|
Round-robin invoke
|
||||||
:param function: function to invoke
|
:param function: function to invoke
|
||||||
|
|
|
@ -0,0 +1,36 @@
|
||||||
|
from .llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
|
||||||
|
from .message_entities import (
|
||||||
|
AssistantPromptMessage,
|
||||||
|
ImagePromptMessageContent,
|
||||||
|
PromptMessage,
|
||||||
|
PromptMessageContent,
|
||||||
|
PromptMessageContentType,
|
||||||
|
PromptMessageRole,
|
||||||
|
PromptMessageTool,
|
||||||
|
SystemPromptMessage,
|
||||||
|
TextPromptMessageContent,
|
||||||
|
ToolPromptMessage,
|
||||||
|
UserPromptMessage,
|
||||||
|
)
|
||||||
|
from .model_entities import ModelPropertyKey
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"ImagePromptMessageContent",
|
||||||
|
"PromptMessage",
|
||||||
|
"PromptMessageRole",
|
||||||
|
"LLMUsage",
|
||||||
|
"ModelPropertyKey",
|
||||||
|
"AssistantPromptMessage",
|
||||||
|
"PromptMessage",
|
||||||
|
"PromptMessageContent",
|
||||||
|
"PromptMessageRole",
|
||||||
|
"SystemPromptMessage",
|
||||||
|
"TextPromptMessageContent",
|
||||||
|
"UserPromptMessage",
|
||||||
|
"PromptMessageTool",
|
||||||
|
"ToolPromptMessage",
|
||||||
|
"PromptMessageContentType",
|
||||||
|
"LLMResult",
|
||||||
|
"LLMResultChunk",
|
||||||
|
"LLMResultChunkDelta",
|
||||||
|
]
|
|
@ -79,7 +79,7 @@ class ImagePromptMessageContent(PromptMessageContent):
|
||||||
Model class for image prompt message content.
|
Model class for image prompt message content.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
class DETAIL(Enum):
|
class DETAIL(str, Enum):
|
||||||
LOW = "low"
|
LOW = "low"
|
||||||
HIGH = "high"
|
HIGH = "high"
|
||||||
|
|
||||||
|
|
|
@ -1,5 +1,4 @@
|
||||||
import logging
|
import logging
|
||||||
import os
|
|
||||||
import re
|
import re
|
||||||
import time
|
import time
|
||||||
from abc import abstractmethod
|
from abc import abstractmethod
|
||||||
|
@ -8,6 +7,7 @@ from typing import Optional, Union
|
||||||
|
|
||||||
from pydantic import ConfigDict
|
from pydantic import ConfigDict
|
||||||
|
|
||||||
|
from configs import dify_config
|
||||||
from core.model_runtime.callbacks.base_callback import Callback
|
from core.model_runtime.callbacks.base_callback import Callback
|
||||||
from core.model_runtime.callbacks.logging_callback import LoggingCallback
|
from core.model_runtime.callbacks.logging_callback import LoggingCallback
|
||||||
from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
|
from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
|
||||||
|
@ -77,7 +77,7 @@ class LargeLanguageModel(AIModel):
|
||||||
|
|
||||||
callbacks = callbacks or []
|
callbacks = callbacks or []
|
||||||
|
|
||||||
if bool(os.environ.get("DEBUG", "False").lower() == "true"):
|
if dify_config.DEBUG:
|
||||||
callbacks.append(LoggingCallback())
|
callbacks.append(LoggingCallback())
|
||||||
|
|
||||||
# trigger before invoke callbacks
|
# trigger before invoke callbacks
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
from abc import abstractmethod
|
from abc import abstractmethod
|
||||||
|
from collections.abc import Iterable
|
||||||
from typing import Any, Optional
|
from typing import Any, Optional
|
||||||
|
|
||||||
from pydantic import ConfigDict
|
from pydantic import ConfigDict
|
||||||
|
@ -22,8 +23,14 @@ class TTSModel(AIModel):
|
||||||
model_config = ConfigDict(protected_namespaces=())
|
model_config = ConfigDict(protected_namespaces=())
|
||||||
|
|
||||||
def invoke(
|
def invoke(
|
||||||
self, model: str, tenant_id: str, credentials: dict, content_text: str, voice: str, user: Optional[str] = None
|
self,
|
||||||
):
|
model: str,
|
||||||
|
tenant_id: str,
|
||||||
|
credentials: dict,
|
||||||
|
content_text: str,
|
||||||
|
voice: str,
|
||||||
|
user: Optional[str] = None,
|
||||||
|
) -> Iterable[bytes]:
|
||||||
"""
|
"""
|
||||||
Invoke large language model
|
Invoke large language model
|
||||||
|
|
||||||
|
@ -50,8 +57,14 @@ class TTSModel(AIModel):
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def _invoke(
|
def _invoke(
|
||||||
self, model: str, tenant_id: str, credentials: dict, content_text: str, voice: str, user: Optional[str] = None
|
self,
|
||||||
):
|
model: str,
|
||||||
|
tenant_id: str,
|
||||||
|
credentials: dict,
|
||||||
|
content_text: str,
|
||||||
|
voice: str,
|
||||||
|
user: Optional[str] = None,
|
||||||
|
) -> Iterable[bytes]:
|
||||||
"""
|
"""
|
||||||
Invoke large language model
|
Invoke large language model
|
||||||
|
|
||||||
|
@ -68,25 +81,25 @@ class TTSModel(AIModel):
|
||||||
|
|
||||||
def get_tts_model_voices(self, model: str, credentials: dict, language: Optional[str] = None) -> list:
|
def get_tts_model_voices(self, model: str, credentials: dict, language: Optional[str] = None) -> list:
|
||||||
"""
|
"""
|
||||||
Get voice for given tts model voices
|
Retrieves the list of voices supported by a given text-to-speech (TTS) model.
|
||||||
|
|
||||||
:param language: tts language
|
:param language: The language for which the voices are requested.
|
||||||
:param model: model name
|
:param model: The name of the TTS model.
|
||||||
:param credentials: model credentials
|
:param credentials: The credentials required to access the TTS model.
|
||||||
:return: voices lists
|
:return: A list of voices supported by the TTS model.
|
||||||
"""
|
"""
|
||||||
model_schema = self.get_model_schema(model, credentials)
|
model_schema = self.get_model_schema(model, credentials)
|
||||||
|
|
||||||
if model_schema and ModelPropertyKey.VOICES in model_schema.model_properties:
|
if not model_schema or ModelPropertyKey.VOICES not in model_schema.model_properties:
|
||||||
voices = model_schema.model_properties[ModelPropertyKey.VOICES]
|
raise ValueError("this model does not support voice")
|
||||||
if language:
|
|
||||||
return [
|
voices = model_schema.model_properties[ModelPropertyKey.VOICES]
|
||||||
{"name": d["name"], "value": d["mode"]}
|
if language:
|
||||||
for d in voices
|
return [
|
||||||
if language and language in d.get("language")
|
{"name": d["name"], "value": d["mode"]} for d in voices if language and language in d.get("language")
|
||||||
]
|
]
|
||||||
else:
|
else:
|
||||||
return [{"name": d["name"], "value": d["mode"]} for d in voices]
|
return [{"name": d["name"], "value": d["mode"]} for d in voices]
|
||||||
|
|
||||||
def _get_model_default_voice(self, model: str, credentials: dict) -> Any:
|
def _get_model_default_voice(self, model: str, credentials: dict) -> Any:
|
||||||
"""
|
"""
|
||||||
|
@ -111,8 +124,10 @@ class TTSModel(AIModel):
|
||||||
"""
|
"""
|
||||||
model_schema = self.get_model_schema(model, credentials)
|
model_schema = self.get_model_schema(model, credentials)
|
||||||
|
|
||||||
if model_schema and ModelPropertyKey.AUDIO_TYPE in model_schema.model_properties:
|
if not model_schema or ModelPropertyKey.AUDIO_TYPE not in model_schema.model_properties:
|
||||||
return model_schema.model_properties[ModelPropertyKey.AUDIO_TYPE]
|
raise ValueError("this model does not support audio type")
|
||||||
|
|
||||||
|
return model_schema.model_properties[ModelPropertyKey.AUDIO_TYPE]
|
||||||
|
|
||||||
def _get_model_word_limit(self, model: str, credentials: dict) -> int:
|
def _get_model_word_limit(self, model: str, credentials: dict) -> int:
|
||||||
"""
|
"""
|
||||||
|
@ -121,8 +136,10 @@ class TTSModel(AIModel):
|
||||||
"""
|
"""
|
||||||
model_schema = self.get_model_schema(model, credentials)
|
model_schema = self.get_model_schema(model, credentials)
|
||||||
|
|
||||||
if model_schema and ModelPropertyKey.WORD_LIMIT in model_schema.model_properties:
|
if not model_schema or ModelPropertyKey.WORD_LIMIT not in model_schema.model_properties:
|
||||||
return model_schema.model_properties[ModelPropertyKey.WORD_LIMIT]
|
raise ValueError("this model does not support word limit")
|
||||||
|
|
||||||
|
return model_schema.model_properties[ModelPropertyKey.WORD_LIMIT]
|
||||||
|
|
||||||
def _get_model_workers_limit(self, model: str, credentials: dict) -> int:
|
def _get_model_workers_limit(self, model: str, credentials: dict) -> int:
|
||||||
"""
|
"""
|
||||||
|
@ -131,8 +148,10 @@ class TTSModel(AIModel):
|
||||||
"""
|
"""
|
||||||
model_schema = self.get_model_schema(model, credentials)
|
model_schema = self.get_model_schema(model, credentials)
|
||||||
|
|
||||||
if model_schema and ModelPropertyKey.MAX_WORKERS in model_schema.model_properties:
|
if not model_schema or ModelPropertyKey.MAX_WORKERS not in model_schema.model_properties:
|
||||||
return model_schema.model_properties[ModelPropertyKey.MAX_WORKERS]
|
raise ValueError("this model does not support max workers")
|
||||||
|
|
||||||
|
return model_schema.model_properties[ModelPropertyKey.MAX_WORKERS]
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _split_text_into_sentences(org_text, max_length=2000, pattern=r"[。.!?]"):
|
def _split_text_into_sentences(org_text, max_length=2000, pattern=r"[。.!?]"):
|
||||||
|
|
|
@ -1,12 +1,15 @@
|
||||||
from typing import Optional, Union
|
from collections.abc import Sequence
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
|
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
|
||||||
from core.file.file_obj import FileVar
|
from core.file import file_manager
|
||||||
|
from core.file.models import File
|
||||||
from core.helper.code_executor.jinja2.jinja2_formatter import Jinja2Formatter
|
from core.helper.code_executor.jinja2.jinja2_formatter import Jinja2Formatter
|
||||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||||
from core.model_runtime.entities.message_entities import (
|
from core.model_runtime.entities import (
|
||||||
AssistantPromptMessage,
|
AssistantPromptMessage,
|
||||||
PromptMessage,
|
PromptMessage,
|
||||||
|
PromptMessageContent,
|
||||||
PromptMessageRole,
|
PromptMessageRole,
|
||||||
SystemPromptMessage,
|
SystemPromptMessage,
|
||||||
TextPromptMessageContent,
|
TextPromptMessageContent,
|
||||||
|
@ -14,7 +17,6 @@ from core.model_runtime.entities.message_entities import (
|
||||||
)
|
)
|
||||||
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig
|
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig
|
||||||
from core.prompt.prompt_transform import PromptTransform
|
from core.prompt.prompt_transform import PromptTransform
|
||||||
from core.prompt.simple_prompt_transform import ModelMode
|
|
||||||
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
|
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
|
||||||
|
|
||||||
|
|
||||||
|
@ -28,22 +30,19 @@ class AdvancedPromptTransform(PromptTransform):
|
||||||
|
|
||||||
def get_prompt(
|
def get_prompt(
|
||||||
self,
|
self,
|
||||||
prompt_template: Union[list[ChatModelMessage], CompletionModelPromptTemplate],
|
*,
|
||||||
inputs: dict,
|
prompt_template: Sequence[ChatModelMessage] | CompletionModelPromptTemplate,
|
||||||
|
inputs: dict[str, str],
|
||||||
query: str,
|
query: str,
|
||||||
files: list[FileVar],
|
files: Sequence[File],
|
||||||
context: Optional[str],
|
context: Optional[str],
|
||||||
memory_config: Optional[MemoryConfig],
|
memory_config: Optional[MemoryConfig],
|
||||||
memory: Optional[TokenBufferMemory],
|
memory: Optional[TokenBufferMemory],
|
||||||
model_config: ModelConfigWithCredentialsEntity,
|
model_config: ModelConfigWithCredentialsEntity,
|
||||||
query_prompt_template: Optional[str] = None,
|
|
||||||
) -> list[PromptMessage]:
|
) -> list[PromptMessage]:
|
||||||
inputs = {key: str(value) for key, value in inputs.items()}
|
|
||||||
|
|
||||||
prompt_messages = []
|
prompt_messages = []
|
||||||
|
|
||||||
model_mode = ModelMode.value_of(model_config.mode)
|
if isinstance(prompt_template, CompletionModelPromptTemplate):
|
||||||
if model_mode == ModelMode.COMPLETION:
|
|
||||||
prompt_messages = self._get_completion_model_prompt_messages(
|
prompt_messages = self._get_completion_model_prompt_messages(
|
||||||
prompt_template=prompt_template,
|
prompt_template=prompt_template,
|
||||||
inputs=inputs,
|
inputs=inputs,
|
||||||
|
@ -54,12 +53,11 @@ class AdvancedPromptTransform(PromptTransform):
|
||||||
memory=memory,
|
memory=memory,
|
||||||
model_config=model_config,
|
model_config=model_config,
|
||||||
)
|
)
|
||||||
elif model_mode == ModelMode.CHAT:
|
elif isinstance(prompt_template, list) and all(isinstance(item, ChatModelMessage) for item in prompt_template):
|
||||||
prompt_messages = self._get_chat_model_prompt_messages(
|
prompt_messages = self._get_chat_model_prompt_messages(
|
||||||
prompt_template=prompt_template,
|
prompt_template=prompt_template,
|
||||||
inputs=inputs,
|
inputs=inputs,
|
||||||
query=query,
|
query=query,
|
||||||
query_prompt_template=query_prompt_template,
|
|
||||||
files=files,
|
files=files,
|
||||||
context=context,
|
context=context,
|
||||||
memory_config=memory_config,
|
memory_config=memory_config,
|
||||||
|
@ -74,7 +72,7 @@ class AdvancedPromptTransform(PromptTransform):
|
||||||
prompt_template: CompletionModelPromptTemplate,
|
prompt_template: CompletionModelPromptTemplate,
|
||||||
inputs: dict,
|
inputs: dict,
|
||||||
query: Optional[str],
|
query: Optional[str],
|
||||||
files: list[FileVar],
|
files: Sequence[File],
|
||||||
context: Optional[str],
|
context: Optional[str],
|
||||||
memory_config: Optional[MemoryConfig],
|
memory_config: Optional[MemoryConfig],
|
||||||
memory: Optional[TokenBufferMemory],
|
memory: Optional[TokenBufferMemory],
|
||||||
|
@ -88,10 +86,10 @@ class AdvancedPromptTransform(PromptTransform):
|
||||||
prompt_messages = []
|
prompt_messages = []
|
||||||
|
|
||||||
if prompt_template.edition_type == "basic" or not prompt_template.edition_type:
|
if prompt_template.edition_type == "basic" or not prompt_template.edition_type:
|
||||||
prompt_template = PromptTemplateParser(template=raw_prompt, with_variable_tmpl=self.with_variable_tmpl)
|
parser = PromptTemplateParser(template=raw_prompt, with_variable_tmpl=self.with_variable_tmpl)
|
||||||
prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}
|
prompt_inputs = {k: inputs[k] for k in parser.variable_keys if k in inputs}
|
||||||
|
|
||||||
prompt_inputs = self._set_context_variable(context, prompt_template, prompt_inputs)
|
prompt_inputs = self._set_context_variable(context, parser, prompt_inputs)
|
||||||
|
|
||||||
if memory and memory_config:
|
if memory and memory_config:
|
||||||
role_prefix = memory_config.role_prefix
|
role_prefix = memory_config.role_prefix
|
||||||
|
@ -100,15 +98,15 @@ class AdvancedPromptTransform(PromptTransform):
|
||||||
memory_config=memory_config,
|
memory_config=memory_config,
|
||||||
raw_prompt=raw_prompt,
|
raw_prompt=raw_prompt,
|
||||||
role_prefix=role_prefix,
|
role_prefix=role_prefix,
|
||||||
prompt_template=prompt_template,
|
parser=parser,
|
||||||
prompt_inputs=prompt_inputs,
|
prompt_inputs=prompt_inputs,
|
||||||
model_config=model_config,
|
model_config=model_config,
|
||||||
)
|
)
|
||||||
|
|
||||||
if query:
|
if query:
|
||||||
prompt_inputs = self._set_query_variable(query, prompt_template, prompt_inputs)
|
prompt_inputs = self._set_query_variable(query, parser, prompt_inputs)
|
||||||
|
|
||||||
prompt = prompt_template.format(prompt_inputs)
|
prompt = parser.format(prompt_inputs)
|
||||||
else:
|
else:
|
||||||
prompt = raw_prompt
|
prompt = raw_prompt
|
||||||
prompt_inputs = inputs
|
prompt_inputs = inputs
|
||||||
|
@ -116,9 +114,10 @@ class AdvancedPromptTransform(PromptTransform):
|
||||||
prompt = Jinja2Formatter.format(prompt, prompt_inputs)
|
prompt = Jinja2Formatter.format(prompt, prompt_inputs)
|
||||||
|
|
||||||
if files:
|
if files:
|
||||||
prompt_message_contents = [TextPromptMessageContent(data=prompt)]
|
prompt_message_contents: list[PromptMessageContent] = []
|
||||||
|
prompt_message_contents.append(TextPromptMessageContent(data=prompt))
|
||||||
for file in files:
|
for file in files:
|
||||||
prompt_message_contents.append(file.prompt_message_content)
|
prompt_message_contents.append(file_manager.to_prompt_message_content(file))
|
||||||
|
|
||||||
prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
|
prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
|
||||||
else:
|
else:
|
||||||
|
@ -131,35 +130,28 @@ class AdvancedPromptTransform(PromptTransform):
|
||||||
prompt_template: list[ChatModelMessage],
|
prompt_template: list[ChatModelMessage],
|
||||||
inputs: dict,
|
inputs: dict,
|
||||||
query: Optional[str],
|
query: Optional[str],
|
||||||
files: list[FileVar],
|
files: Sequence[File],
|
||||||
context: Optional[str],
|
context: Optional[str],
|
||||||
memory_config: Optional[MemoryConfig],
|
memory_config: Optional[MemoryConfig],
|
||||||
memory: Optional[TokenBufferMemory],
|
memory: Optional[TokenBufferMemory],
|
||||||
model_config: ModelConfigWithCredentialsEntity,
|
model_config: ModelConfigWithCredentialsEntity,
|
||||||
query_prompt_template: Optional[str] = None,
|
|
||||||
) -> list[PromptMessage]:
|
) -> list[PromptMessage]:
|
||||||
"""
|
"""
|
||||||
Get chat model prompt messages.
|
Get chat model prompt messages.
|
||||||
"""
|
"""
|
||||||
raw_prompt_list = prompt_template
|
|
||||||
|
|
||||||
prompt_messages = []
|
prompt_messages = []
|
||||||
|
for prompt_item in prompt_template:
|
||||||
for prompt_item in raw_prompt_list:
|
|
||||||
raw_prompt = prompt_item.text
|
raw_prompt = prompt_item.text
|
||||||
|
|
||||||
if prompt_item.edition_type == "basic" or not prompt_item.edition_type:
|
if prompt_item.edition_type == "basic" or not prompt_item.edition_type:
|
||||||
prompt_template = PromptTemplateParser(template=raw_prompt, with_variable_tmpl=self.with_variable_tmpl)
|
parser = PromptTemplateParser(template=raw_prompt, with_variable_tmpl=self.with_variable_tmpl)
|
||||||
prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}
|
prompt_inputs = {k: inputs[k] for k in parser.variable_keys if k in inputs}
|
||||||
|
prompt_inputs = self._set_context_variable(context=context, parser=parser, prompt_inputs=prompt_inputs)
|
||||||
prompt_inputs = self._set_context_variable(context, prompt_template, prompt_inputs)
|
prompt = parser.format(prompt_inputs)
|
||||||
|
|
||||||
prompt = prompt_template.format(prompt_inputs)
|
|
||||||
elif prompt_item.edition_type == "jinja2":
|
elif prompt_item.edition_type == "jinja2":
|
||||||
prompt = raw_prompt
|
prompt = raw_prompt
|
||||||
prompt_inputs = inputs
|
prompt_inputs = inputs
|
||||||
|
prompt = Jinja2Formatter.format(template=prompt, inputs=prompt_inputs)
|
||||||
prompt = Jinja2Formatter.format(prompt, prompt_inputs)
|
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Invalid edition type: {prompt_item.edition_type}")
|
raise ValueError(f"Invalid edition type: {prompt_item.edition_type}")
|
||||||
|
|
||||||
|
@ -170,25 +162,25 @@ class AdvancedPromptTransform(PromptTransform):
|
||||||
elif prompt_item.role == PromptMessageRole.ASSISTANT:
|
elif prompt_item.role == PromptMessageRole.ASSISTANT:
|
||||||
prompt_messages.append(AssistantPromptMessage(content=prompt))
|
prompt_messages.append(AssistantPromptMessage(content=prompt))
|
||||||
|
|
||||||
if query and query_prompt_template:
|
if query and memory_config and memory_config.query_prompt_template:
|
||||||
prompt_template = PromptTemplateParser(
|
parser = PromptTemplateParser(
|
||||||
template=query_prompt_template, with_variable_tmpl=self.with_variable_tmpl
|
template=memory_config.query_prompt_template, with_variable_tmpl=self.with_variable_tmpl
|
||||||
)
|
)
|
||||||
prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}
|
prompt_inputs = {k: inputs[k] for k in parser.variable_keys if k in inputs}
|
||||||
prompt_inputs["#sys.query#"] = query
|
prompt_inputs["#sys.query#"] = query
|
||||||
|
|
||||||
prompt_inputs = self._set_context_variable(context, prompt_template, prompt_inputs)
|
prompt_inputs = self._set_context_variable(context, parser, prompt_inputs)
|
||||||
|
|
||||||
query = prompt_template.format(prompt_inputs)
|
query = parser.format(prompt_inputs)
|
||||||
|
|
||||||
if memory and memory_config:
|
if memory and memory_config:
|
||||||
prompt_messages = self._append_chat_histories(memory, memory_config, prompt_messages, model_config)
|
prompt_messages = self._append_chat_histories(memory, memory_config, prompt_messages, model_config)
|
||||||
|
|
||||||
if files:
|
if files and query is not None:
|
||||||
prompt_message_contents = [TextPromptMessageContent(data=query)]
|
prompt_message_contents: list[PromptMessageContent] = []
|
||||||
|
prompt_message_contents.append(TextPromptMessageContent(data=query))
|
||||||
for file in files:
|
for file in files:
|
||||||
prompt_message_contents.append(file.prompt_message_content)
|
prompt_message_contents.append(file_manager.to_prompt_message_content(file))
|
||||||
|
|
||||||
prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
|
prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
|
||||||
else:
|
else:
|
||||||
prompt_messages.append(UserPromptMessage(content=query))
|
prompt_messages.append(UserPromptMessage(content=query))
|
||||||
|
@ -200,19 +192,19 @@ class AdvancedPromptTransform(PromptTransform):
|
||||||
# get last user message content and add files
|
# get last user message content and add files
|
||||||
prompt_message_contents = [TextPromptMessageContent(data=last_message.content)]
|
prompt_message_contents = [TextPromptMessageContent(data=last_message.content)]
|
||||||
for file in files:
|
for file in files:
|
||||||
prompt_message_contents.append(file.prompt_message_content)
|
prompt_message_contents.append(file_manager.to_prompt_message_content(file))
|
||||||
|
|
||||||
last_message.content = prompt_message_contents
|
last_message.content = prompt_message_contents
|
||||||
else:
|
else:
|
||||||
prompt_message_contents = [TextPromptMessageContent(data="")] # not for query
|
prompt_message_contents = [TextPromptMessageContent(data="")] # not for query
|
||||||
for file in files:
|
for file in files:
|
||||||
prompt_message_contents.append(file.prompt_message_content)
|
prompt_message_contents.append(file_manager.to_prompt_message_content(file))
|
||||||
|
|
||||||
prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
|
prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
|
||||||
else:
|
else:
|
||||||
prompt_message_contents = [TextPromptMessageContent(data=query)]
|
prompt_message_contents = [TextPromptMessageContent(data=query)]
|
||||||
for file in files:
|
for file in files:
|
||||||
prompt_message_contents.append(file.prompt_message_content)
|
prompt_message_contents.append(file_manager.to_prompt_message_content(file))
|
||||||
|
|
||||||
prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
|
prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
|
||||||
elif query:
|
elif query:
|
||||||
|
@ -220,8 +212,8 @@ class AdvancedPromptTransform(PromptTransform):
|
||||||
|
|
||||||
return prompt_messages
|
return prompt_messages
|
||||||
|
|
||||||
def _set_context_variable(self, context: str, prompt_template: PromptTemplateParser, prompt_inputs: dict) -> dict:
|
def _set_context_variable(self, context: str | None, parser: PromptTemplateParser, prompt_inputs: dict) -> dict:
|
||||||
if "#context#" in prompt_template.variable_keys:
|
if "#context#" in parser.variable_keys:
|
||||||
if context:
|
if context:
|
||||||
prompt_inputs["#context#"] = context
|
prompt_inputs["#context#"] = context
|
||||||
else:
|
else:
|
||||||
|
@ -229,8 +221,8 @@ class AdvancedPromptTransform(PromptTransform):
|
||||||
|
|
||||||
return prompt_inputs
|
return prompt_inputs
|
||||||
|
|
||||||
def _set_query_variable(self, query: str, prompt_template: PromptTemplateParser, prompt_inputs: dict) -> dict:
|
def _set_query_variable(self, query: str, parser: PromptTemplateParser, prompt_inputs: dict) -> dict:
|
||||||
if "#query#" in prompt_template.variable_keys:
|
if "#query#" in parser.variable_keys:
|
||||||
if query:
|
if query:
|
||||||
prompt_inputs["#query#"] = query
|
prompt_inputs["#query#"] = query
|
||||||
else:
|
else:
|
||||||
|
@ -244,16 +236,16 @@ class AdvancedPromptTransform(PromptTransform):
|
||||||
memory_config: MemoryConfig,
|
memory_config: MemoryConfig,
|
||||||
raw_prompt: str,
|
raw_prompt: str,
|
||||||
role_prefix: MemoryConfig.RolePrefix,
|
role_prefix: MemoryConfig.RolePrefix,
|
||||||
prompt_template: PromptTemplateParser,
|
parser: PromptTemplateParser,
|
||||||
prompt_inputs: dict,
|
prompt_inputs: dict,
|
||||||
model_config: ModelConfigWithCredentialsEntity,
|
model_config: ModelConfigWithCredentialsEntity,
|
||||||
) -> dict:
|
) -> dict:
|
||||||
if "#histories#" in prompt_template.variable_keys:
|
if "#histories#" in parser.variable_keys:
|
||||||
if memory:
|
if memory:
|
||||||
inputs = {"#histories#": "", **prompt_inputs}
|
inputs = {"#histories#": "", **prompt_inputs}
|
||||||
prompt_template = PromptTemplateParser(template=raw_prompt, with_variable_tmpl=self.with_variable_tmpl)
|
parser = PromptTemplateParser(template=raw_prompt, with_variable_tmpl=self.with_variable_tmpl)
|
||||||
prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}
|
prompt_inputs = {k: inputs[k] for k in parser.variable_keys if k in inputs}
|
||||||
tmp_human_message = UserPromptMessage(content=prompt_template.format(prompt_inputs))
|
tmp_human_message = UserPromptMessage(content=parser.format(prompt_inputs))
|
||||||
|
|
||||||
rest_tokens = self._calculate_rest_token([tmp_human_message], model_config)
|
rest_tokens = self._calculate_rest_token([tmp_human_message], model_config)
|
||||||
|
|
||||||
|
|
|
@ -5,9 +5,11 @@ from typing import TYPE_CHECKING, Optional
|
||||||
|
|
||||||
from core.app.app_config.entities import PromptTemplateEntity
|
from core.app.app_config.entities import PromptTemplateEntity
|
||||||
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
|
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
|
||||||
|
from core.file import file_manager
|
||||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||||
from core.model_runtime.entities.message_entities import (
|
from core.model_runtime.entities.message_entities import (
|
||||||
PromptMessage,
|
PromptMessage,
|
||||||
|
PromptMessageContent,
|
||||||
SystemPromptMessage,
|
SystemPromptMessage,
|
||||||
TextPromptMessageContent,
|
TextPromptMessageContent,
|
||||||
UserPromptMessage,
|
UserPromptMessage,
|
||||||
|
@ -18,7 +20,7 @@ from core.prompt.utils.prompt_template_parser import PromptTemplateParser
|
||||||
from models.model import AppMode
|
from models.model import AppMode
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from core.file.file_obj import FileVar
|
from core.file.models import File
|
||||||
|
|
||||||
|
|
||||||
class ModelMode(enum.Enum):
|
class ModelMode(enum.Enum):
|
||||||
|
@ -53,7 +55,7 @@ class SimplePromptTransform(PromptTransform):
|
||||||
prompt_template_entity: PromptTemplateEntity,
|
prompt_template_entity: PromptTemplateEntity,
|
||||||
inputs: dict,
|
inputs: dict,
|
||||||
query: str,
|
query: str,
|
||||||
files: list["FileVar"],
|
files: list["File"],
|
||||||
context: Optional[str],
|
context: Optional[str],
|
||||||
memory: Optional[TokenBufferMemory],
|
memory: Optional[TokenBufferMemory],
|
||||||
model_config: ModelConfigWithCredentialsEntity,
|
model_config: ModelConfigWithCredentialsEntity,
|
||||||
|
@ -169,7 +171,7 @@ class SimplePromptTransform(PromptTransform):
|
||||||
inputs: dict,
|
inputs: dict,
|
||||||
query: str,
|
query: str,
|
||||||
context: Optional[str],
|
context: Optional[str],
|
||||||
files: list["FileVar"],
|
files: list["File"],
|
||||||
memory: Optional[TokenBufferMemory],
|
memory: Optional[TokenBufferMemory],
|
||||||
model_config: ModelConfigWithCredentialsEntity,
|
model_config: ModelConfigWithCredentialsEntity,
|
||||||
) -> tuple[list[PromptMessage], Optional[list[str]]]:
|
) -> tuple[list[PromptMessage], Optional[list[str]]]:
|
||||||
|
@ -214,7 +216,7 @@ class SimplePromptTransform(PromptTransform):
|
||||||
inputs: dict,
|
inputs: dict,
|
||||||
query: str,
|
query: str,
|
||||||
context: Optional[str],
|
context: Optional[str],
|
||||||
files: list["FileVar"],
|
files: list["File"],
|
||||||
memory: Optional[TokenBufferMemory],
|
memory: Optional[TokenBufferMemory],
|
||||||
model_config: ModelConfigWithCredentialsEntity,
|
model_config: ModelConfigWithCredentialsEntity,
|
||||||
) -> tuple[list[PromptMessage], Optional[list[str]]]:
|
) -> tuple[list[PromptMessage], Optional[list[str]]]:
|
||||||
|
@ -261,11 +263,12 @@ class SimplePromptTransform(PromptTransform):
|
||||||
|
|
||||||
return [self.get_last_user_message(prompt, files)], stops
|
return [self.get_last_user_message(prompt, files)], stops
|
||||||
|
|
||||||
def get_last_user_message(self, prompt: str, files: list["FileVar"]) -> UserPromptMessage:
|
def get_last_user_message(self, prompt: str, files: list["File"]) -> UserPromptMessage:
|
||||||
if files:
|
if files:
|
||||||
prompt_message_contents = [TextPromptMessageContent(data=prompt)]
|
prompt_message_contents: list[PromptMessageContent] = []
|
||||||
|
prompt_message_contents.append(TextPromptMessageContent(data=prompt))
|
||||||
for file in files:
|
for file in files:
|
||||||
prompt_message_contents.append(file.prompt_message_content)
|
prompt_message_contents.append(file_manager.to_prompt_message_content(file))
|
||||||
|
|
||||||
prompt_message = UserPromptMessage(content=prompt_message_contents)
|
prompt_message = UserPromptMessage(content=prompt_message_contents)
|
||||||
else:
|
else:
|
||||||
|
|
|
@ -32,8 +32,8 @@ class UserToolProvider(BaseModel):
|
||||||
original_credentials: Optional[dict] = None
|
original_credentials: Optional[dict] = None
|
||||||
is_team_authorization: bool = False
|
is_team_authorization: bool = False
|
||||||
allow_delete: bool = True
|
allow_delete: bool = True
|
||||||
tools: list[UserTool] = None
|
tools: list[UserTool] | None = None
|
||||||
labels: list[str] = None
|
labels: list[str] | None = None
|
||||||
|
|
||||||
def to_dict(self) -> dict:
|
def to_dict(self) -> dict:
|
||||||
# -------------
|
# -------------
|
||||||
|
@ -42,7 +42,7 @@ class UserToolProvider(BaseModel):
|
||||||
for tool in tools:
|
for tool in tools:
|
||||||
if tool.get("parameters"):
|
if tool.get("parameters"):
|
||||||
for parameter in tool.get("parameters"):
|
for parameter in tool.get("parameters"):
|
||||||
if parameter.get("type") == ToolParameter.ToolParameterType.FILE.value:
|
if parameter.get("type") == ToolParameter.ToolParameterType.SYSTEM_FILES.value:
|
||||||
parameter["type"] = "files"
|
parameter["type"] = "files"
|
||||||
# -------------
|
# -------------
|
||||||
|
|
||||||
|
|
|
@ -104,14 +104,15 @@ class ToolInvokeMessage(BaseModel):
|
||||||
BLOB = "blob"
|
BLOB = "blob"
|
||||||
JSON = "json"
|
JSON = "json"
|
||||||
IMAGE_LINK = "image_link"
|
IMAGE_LINK = "image_link"
|
||||||
FILE_VAR = "file_var"
|
FILE = "file"
|
||||||
|
|
||||||
type: MessageType = MessageType.TEXT
|
type: MessageType = MessageType.TEXT
|
||||||
"""
|
"""
|
||||||
plain text, image url or link url
|
plain text, image url or link url
|
||||||
"""
|
"""
|
||||||
message: str | bytes | dict | None = None
|
message: str | bytes | dict | None = None
|
||||||
meta: dict[str, Any] | None = None
|
# TODO: Use a BaseModel for meta
|
||||||
|
meta: dict[str, Any] = Field(default_factory=dict)
|
||||||
save_as: str = ""
|
save_as: str = ""
|
||||||
|
|
||||||
|
|
||||||
|
@ -143,6 +144,67 @@ class ToolParameter(BaseModel):
|
||||||
SELECT = "select"
|
SELECT = "select"
|
||||||
SECRET_INPUT = "secret-input"
|
SECRET_INPUT = "secret-input"
|
||||||
FILE = "file"
|
FILE = "file"
|
||||||
|
FILES = "files"
|
||||||
|
|
||||||
|
# deprecated, should not use.
|
||||||
|
SYSTEM_FILES = "systme-files"
|
||||||
|
|
||||||
|
def as_normal_type(self):
|
||||||
|
if self in {
|
||||||
|
ToolParameter.ToolParameterType.SECRET_INPUT,
|
||||||
|
ToolParameter.ToolParameterType.SELECT,
|
||||||
|
}:
|
||||||
|
return "string"
|
||||||
|
return self.value
|
||||||
|
|
||||||
|
def cast_value(self, value: Any, /):
|
||||||
|
try:
|
||||||
|
match self:
|
||||||
|
case (
|
||||||
|
ToolParameter.ToolParameterType.STRING
|
||||||
|
| ToolParameter.ToolParameterType.SECRET_INPUT
|
||||||
|
| ToolParameter.ToolParameterType.SELECT
|
||||||
|
):
|
||||||
|
if value is None:
|
||||||
|
return ""
|
||||||
|
else:
|
||||||
|
return value if isinstance(value, str) else str(value)
|
||||||
|
|
||||||
|
case ToolParameter.ToolParameterType.BOOLEAN:
|
||||||
|
if value is None:
|
||||||
|
return False
|
||||||
|
elif isinstance(value, str):
|
||||||
|
# Allowed YAML boolean value strings: https://yaml.org/type/bool.html
|
||||||
|
# and also '0' for False and '1' for True
|
||||||
|
match value.lower():
|
||||||
|
case "true" | "yes" | "y" | "1":
|
||||||
|
return True
|
||||||
|
case "false" | "no" | "n" | "0":
|
||||||
|
return False
|
||||||
|
case _:
|
||||||
|
return bool(value)
|
||||||
|
else:
|
||||||
|
return value if isinstance(value, bool) else bool(value)
|
||||||
|
|
||||||
|
case ToolParameter.ToolParameterType.NUMBER:
|
||||||
|
if isinstance(value, int | float):
|
||||||
|
return value
|
||||||
|
elif isinstance(value, str) and value:
|
||||||
|
if "." in value:
|
||||||
|
return float(value)
|
||||||
|
else:
|
||||||
|
return int(value)
|
||||||
|
case (
|
||||||
|
ToolParameter.ToolParameterType.SYSTEM_FILES
|
||||||
|
| ToolParameter.ToolParameterType.FILE
|
||||||
|
| ToolParameter.ToolParameterType.FILES
|
||||||
|
):
|
||||||
|
return value
|
||||||
|
case _:
|
||||||
|
return str(value)
|
||||||
|
|
||||||
|
except Exception:
|
||||||
|
raise ValueError(f"The tool parameter value {value} is not in correct type of {parameter_type}.")
|
||||||
|
|
||||||
class ToolParameterForm(Enum):
|
class ToolParameterForm(Enum):
|
||||||
SCHEMA = "schema" # should be set while adding tool
|
SCHEMA = "schema" # should be set while adding tool
|
||||||
|
|
|
@ -66,7 +66,7 @@ class DallE3Tool(BuiltinTool):
|
||||||
for image in response.data:
|
for image in response.data:
|
||||||
mime_type, blob_image = DallE3Tool._decode_image(image.b64_json)
|
mime_type, blob_image = DallE3Tool._decode_image(image.b64_json)
|
||||||
blob_message = self.create_blob_message(
|
blob_message = self.create_blob_message(
|
||||||
blob=blob_image, meta={"mime_type": mime_type}, save_as=self.VariableKey.IMAGE.value
|
blob=blob_image, meta={"mime_type": mime_type}, save_as=self.VariableKey.IMAGE
|
||||||
)
|
)
|
||||||
result.append(blob_message)
|
result.append(blob_message)
|
||||||
return result
|
return result
|
||||||
|
|
|
@ -2,7 +2,7 @@ from typing import Any
|
||||||
|
|
||||||
from duckduckgo_search import DDGS
|
from duckduckgo_search import DDGS
|
||||||
|
|
||||||
from core.file.file_obj import FileTransferMethod
|
from core.file.models import FileTransferMethod
|
||||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||||
from core.tools.tool.builtin_tool import BuiltinTool
|
from core.tools.tool.builtin_tool import BuiltinTool
|
||||||
|
|
||||||
|
|
|
@ -13,7 +13,6 @@ from core.tools.errors import (
|
||||||
from core.tools.provider.tool_provider import ToolProviderController
|
from core.tools.provider.tool_provider import ToolProviderController
|
||||||
from core.tools.tool.builtin_tool import BuiltinTool
|
from core.tools.tool.builtin_tool import BuiltinTool
|
||||||
from core.tools.tool.tool import Tool
|
from core.tools.tool.tool import Tool
|
||||||
from core.tools.utils.tool_parameter_converter import ToolParameterConverter
|
|
||||||
from core.tools.utils.yaml_utils import load_yaml_file
|
from core.tools.utils.yaml_utils import load_yaml_file
|
||||||
|
|
||||||
|
|
||||||
|
@ -208,9 +207,7 @@ class BuiltinToolProviderController(ToolProviderController):
|
||||||
|
|
||||||
# the parameter is not set currently, set the default value if needed
|
# the parameter is not set currently, set the default value if needed
|
||||||
if parameter_schema.default is not None:
|
if parameter_schema.default is not None:
|
||||||
default_value = ToolParameterConverter.cast_parameter_by_type(
|
default_value = parameter_schema.type.cast_value(parameter_schema.default)
|
||||||
parameter_schema.default, parameter_schema.type
|
|
||||||
)
|
|
||||||
tool_parameters[parameter] = default_value
|
tool_parameters[parameter] = default_value
|
||||||
|
|
||||||
def validate_credentials(self, credentials: dict[str, Any]) -> None:
|
def validate_credentials(self, credentials: dict[str, Any]) -> None:
|
||||||
|
|
|
@ -11,7 +11,6 @@ from core.tools.entities.tool_entities import (
|
||||||
)
|
)
|
||||||
from core.tools.errors import ToolNotFoundError, ToolParameterValidationError, ToolProviderCredentialValidationError
|
from core.tools.errors import ToolNotFoundError, ToolParameterValidationError, ToolProviderCredentialValidationError
|
||||||
from core.tools.tool.tool import Tool
|
from core.tools.tool.tool import Tool
|
||||||
from core.tools.utils.tool_parameter_converter import ToolParameterConverter
|
|
||||||
|
|
||||||
|
|
||||||
class ToolProviderController(BaseModel, ABC):
|
class ToolProviderController(BaseModel, ABC):
|
||||||
|
@ -127,9 +126,7 @@ class ToolProviderController(BaseModel, ABC):
|
||||||
|
|
||||||
# the parameter is not set currently, set the default value if needed
|
# the parameter is not set currently, set the default value if needed
|
||||||
if parameter_schema.default is not None:
|
if parameter_schema.default is not None:
|
||||||
tool_parameters[parameter] = ToolParameterConverter.cast_parameter_by_type(
|
tool_parameters[parameter] = parameter_schema.type.cast_value(parameter_schema.default)
|
||||||
parameter_schema.default, parameter_schema.type
|
|
||||||
)
|
|
||||||
|
|
||||||
def validate_credentials_format(self, credentials: dict[str, Any]) -> None:
|
def validate_credentials_format(self, credentials: dict[str, Any]) -> None:
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from core.app.app_config.entities import VariableEntity, VariableEntityType
|
from core.app.app_config.entities import VariableEntityType
|
||||||
from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
|
from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
|
||||||
from core.tools.entities.common_entities import I18nObject
|
from core.tools.entities.common_entities import I18nObject
|
||||||
from core.tools.entities.tool_entities import (
|
from core.tools.entities.tool_entities import (
|
||||||
|
@ -23,6 +23,8 @@ VARIABLE_TO_PARAMETER_TYPE_MAPPING = {
|
||||||
VariableEntityType.PARAGRAPH: ToolParameter.ToolParameterType.STRING,
|
VariableEntityType.PARAGRAPH: ToolParameter.ToolParameterType.STRING,
|
||||||
VariableEntityType.SELECT: ToolParameter.ToolParameterType.SELECT,
|
VariableEntityType.SELECT: ToolParameter.ToolParameterType.SELECT,
|
||||||
VariableEntityType.NUMBER: ToolParameter.ToolParameterType.NUMBER,
|
VariableEntityType.NUMBER: ToolParameter.ToolParameterType.NUMBER,
|
||||||
|
VariableEntityType.FILE: ToolParameter.ToolParameterType.FILE,
|
||||||
|
VariableEntityType.FILE_LIST: ToolParameter.ToolParameterType.FILES,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -36,8 +38,8 @@ class WorkflowToolProviderController(ToolProviderController):
|
||||||
if not app:
|
if not app:
|
||||||
raise ValueError("app not found")
|
raise ValueError("app not found")
|
||||||
|
|
||||||
controller = WorkflowToolProviderController(
|
controller = WorkflowToolProviderController.model_validate(
|
||||||
**{
|
{
|
||||||
"identity": {
|
"identity": {
|
||||||
"author": db_provider.user.name if db_provider.user_id and db_provider.user else "",
|
"author": db_provider.user.name if db_provider.user_id and db_provider.user else "",
|
||||||
"name": db_provider.label,
|
"name": db_provider.label,
|
||||||
|
@ -67,7 +69,7 @@ class WorkflowToolProviderController(ToolProviderController):
|
||||||
:param app: the app
|
:param app: the app
|
||||||
:return: the tool
|
:return: the tool
|
||||||
"""
|
"""
|
||||||
workflow: Workflow = (
|
workflow = (
|
||||||
db.session.query(Workflow)
|
db.session.query(Workflow)
|
||||||
.filter(Workflow.app_id == db_provider.app_id, Workflow.version == db_provider.version)
|
.filter(Workflow.app_id == db_provider.app_id, Workflow.version == db_provider.version)
|
||||||
.first()
|
.first()
|
||||||
|
@ -76,14 +78,14 @@ class WorkflowToolProviderController(ToolProviderController):
|
||||||
raise ValueError("workflow not found")
|
raise ValueError("workflow not found")
|
||||||
|
|
||||||
# fetch start node
|
# fetch start node
|
||||||
graph: dict = workflow.graph_dict
|
graph = workflow.graph_dict
|
||||||
features_dict: dict = workflow.features_dict
|
features_dict = workflow.features_dict
|
||||||
features = WorkflowAppConfigManager.convert_features(config_dict=features_dict, app_mode=AppMode.WORKFLOW)
|
features = WorkflowAppConfigManager.convert_features(config_dict=features_dict, app_mode=AppMode.WORKFLOW)
|
||||||
|
|
||||||
parameters = db_provider.parameter_configurations
|
parameters = db_provider.parameter_configurations
|
||||||
variables = WorkflowToolConfigurationUtils.get_workflow_graph_variables(graph)
|
variables = WorkflowToolConfigurationUtils.get_workflow_graph_variables(graph)
|
||||||
|
|
||||||
def fetch_workflow_variable(variable_name: str) -> VariableEntity:
|
def fetch_workflow_variable(variable_name: str):
|
||||||
return next(filter(lambda x: x.variable == variable_name, variables), None)
|
return next(filter(lambda x: x.variable == variable_name, variables), None)
|
||||||
|
|
||||||
user = db_provider.user
|
user = db_provider.user
|
||||||
|
@ -114,7 +116,6 @@ class WorkflowToolProviderController(ToolProviderController):
|
||||||
llm_description=parameter.description,
|
llm_description=parameter.description,
|
||||||
required=variable.required,
|
required=variable.required,
|
||||||
options=options,
|
options=options,
|
||||||
default=variable.default,
|
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
elif features.file_upload:
|
elif features.file_upload:
|
||||||
|
@ -123,7 +124,7 @@ class WorkflowToolProviderController(ToolProviderController):
|
||||||
name=parameter.name,
|
name=parameter.name,
|
||||||
label=I18nObject(en_US=parameter.name, zh_Hans=parameter.name),
|
label=I18nObject(en_US=parameter.name, zh_Hans=parameter.name),
|
||||||
human_description=I18nObject(en_US=parameter.description, zh_Hans=parameter.description),
|
human_description=I18nObject(en_US=parameter.description, zh_Hans=parameter.description),
|
||||||
type=ToolParameter.ToolParameterType.FILE,
|
type=ToolParameter.ToolParameterType.SYSTEM_FILES,
|
||||||
llm_description=parameter.description,
|
llm_description=parameter.description,
|
||||||
required=False,
|
required=False,
|
||||||
form=parameter.form,
|
form=parameter.form,
|
||||||
|
|
|
@ -20,10 +20,9 @@ from core.tools.entities.tool_entities import (
|
||||||
ToolRuntimeVariablePool,
|
ToolRuntimeVariablePool,
|
||||||
)
|
)
|
||||||
from core.tools.tool_file_manager import ToolFileManager
|
from core.tools.tool_file_manager import ToolFileManager
|
||||||
from core.tools.utils.tool_parameter_converter import ToolParameterConverter
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from core.file.file_obj import FileVar
|
from core.file.models import File
|
||||||
|
|
||||||
|
|
||||||
class Tool(BaseModel, ABC):
|
class Tool(BaseModel, ABC):
|
||||||
|
@ -63,8 +62,12 @@ class Tool(BaseModel, ABC):
|
||||||
def __init__(self, **data: Any):
|
def __init__(self, **data: Any):
|
||||||
super().__init__(**data)
|
super().__init__(**data)
|
||||||
|
|
||||||
class VariableKey(Enum):
|
class VariableKey(str, Enum):
|
||||||
IMAGE = "image"
|
IMAGE = "image"
|
||||||
|
DOCUMENT = "document"
|
||||||
|
VIDEO = "video"
|
||||||
|
AUDIO = "audio"
|
||||||
|
CUSTOM = "custom"
|
||||||
|
|
||||||
def fork_tool_runtime(self, runtime: dict[str, Any]) -> "Tool":
|
def fork_tool_runtime(self, runtime: dict[str, Any]) -> "Tool":
|
||||||
"""
|
"""
|
||||||
|
@ -221,9 +224,7 @@ class Tool(BaseModel, ABC):
|
||||||
result = deepcopy(tool_parameters)
|
result = deepcopy(tool_parameters)
|
||||||
for parameter in self.parameters or []:
|
for parameter in self.parameters or []:
|
||||||
if parameter.name in tool_parameters:
|
if parameter.name in tool_parameters:
|
||||||
result[parameter.name] = ToolParameterConverter.cast_parameter_by_type(
|
result[parameter.name] = parameter.type.cast_value(tool_parameters[parameter.name])
|
||||||
tool_parameters[parameter.name], parameter.type
|
|
||||||
)
|
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
@ -295,10 +296,8 @@ class Tool(BaseModel, ABC):
|
||||||
"""
|
"""
|
||||||
return ToolInvokeMessage(type=ToolInvokeMessage.MessageType.IMAGE, message=image, save_as=save_as)
|
return ToolInvokeMessage(type=ToolInvokeMessage.MessageType.IMAGE, message=image, save_as=save_as)
|
||||||
|
|
||||||
def create_file_var_message(self, file_var: "FileVar") -> ToolInvokeMessage:
|
def create_file_message(self, file: "File") -> ToolInvokeMessage:
|
||||||
return ToolInvokeMessage(
|
return ToolInvokeMessage(type=ToolInvokeMessage.MessageType.FILE, message="", meta={"file": file}, save_as="")
|
||||||
type=ToolInvokeMessage.MessageType.FILE_VAR, message="", meta={"file_var": file_var}, save_as=""
|
|
||||||
)
|
|
||||||
|
|
||||||
def create_link_message(self, link: str, save_as: str = "") -> ToolInvokeMessage:
|
def create_link_message(self, link: str, save_as: str = "") -> ToolInvokeMessage:
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -3,7 +3,7 @@ import logging
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from typing import Any, Optional, Union
|
from typing import Any, Optional, Union
|
||||||
|
|
||||||
from core.file.file_obj import FileTransferMethod, FileVar
|
from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod
|
||||||
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter, ToolProviderType
|
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter, ToolProviderType
|
||||||
from core.tools.tool.tool import Tool
|
from core.tools.tool.tool import Tool
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
|
@ -45,11 +45,13 @@ class WorkflowTool(Tool):
|
||||||
workflow = self._get_workflow(app_id=self.workflow_app_id, version=self.version)
|
workflow = self._get_workflow(app_id=self.workflow_app_id, version=self.version)
|
||||||
|
|
||||||
# transform the tool parameters
|
# transform the tool parameters
|
||||||
tool_parameters, files = self._transform_args(tool_parameters)
|
tool_parameters, files = self._transform_args(tool_parameters=tool_parameters)
|
||||||
|
|
||||||
from core.app.apps.workflow.app_generator import WorkflowAppGenerator
|
from core.app.apps.workflow.app_generator import WorkflowAppGenerator
|
||||||
|
|
||||||
generator = WorkflowAppGenerator()
|
generator = WorkflowAppGenerator()
|
||||||
|
assert self.runtime is not None
|
||||||
|
assert self.runtime.invoke_from is not None
|
||||||
result = generator.generate(
|
result = generator.generate(
|
||||||
app_model=app,
|
app_model=app,
|
||||||
workflow=workflow,
|
workflow=workflow,
|
||||||
|
@ -74,7 +76,7 @@ class WorkflowTool(Tool):
|
||||||
else:
|
else:
|
||||||
outputs, files = self._extract_files(outputs)
|
outputs, files = self._extract_files(outputs)
|
||||||
for file in files:
|
for file in files:
|
||||||
result.append(self.create_file_var_message(file))
|
result.append(self.create_file_message(file))
|
||||||
|
|
||||||
result.append(self.create_text_message(json.dumps(outputs, ensure_ascii=False)))
|
result.append(self.create_text_message(json.dumps(outputs, ensure_ascii=False)))
|
||||||
result.append(self.create_json_message(outputs))
|
result.append(self.create_json_message(outputs))
|
||||||
|
@ -154,22 +156,22 @@ class WorkflowTool(Tool):
|
||||||
parameters_result = {}
|
parameters_result = {}
|
||||||
files = []
|
files = []
|
||||||
for parameter in parameter_rules:
|
for parameter in parameter_rules:
|
||||||
if parameter.type == ToolParameter.ToolParameterType.FILE:
|
if parameter.type == ToolParameter.ToolParameterType.SYSTEM_FILES:
|
||||||
file = tool_parameters.get(parameter.name)
|
file = tool_parameters.get(parameter.name)
|
||||||
if file:
|
if file:
|
||||||
try:
|
try:
|
||||||
file_var_list = [FileVar(**f) for f in file]
|
file_var_list = [File.model_validate(f) for f in file]
|
||||||
for file_var in file_var_list:
|
for file in file_var_list:
|
||||||
file_dict = {
|
file_dict: dict[str, str | None] = {
|
||||||
"transfer_method": file_var.transfer_method.value,
|
"transfer_method": file.transfer_method.value,
|
||||||
"type": file_var.type.value,
|
"type": file.type.value,
|
||||||
}
|
}
|
||||||
if file_var.transfer_method == FileTransferMethod.TOOL_FILE:
|
if file.transfer_method == FileTransferMethod.TOOL_FILE:
|
||||||
file_dict["tool_file_id"] = file_var.related_id
|
file_dict["tool_file_id"] = file.related_id
|
||||||
elif file_var.transfer_method == FileTransferMethod.LOCAL_FILE:
|
elif file.transfer_method == FileTransferMethod.LOCAL_FILE:
|
||||||
file_dict["upload_file_id"] = file_var.related_id
|
file_dict["upload_file_id"] = file.related_id
|
||||||
elif file_var.transfer_method == FileTransferMethod.REMOTE_URL:
|
elif file.transfer_method == FileTransferMethod.REMOTE_URL:
|
||||||
file_dict["url"] = file_var.preview_url
|
file_dict["url"] = file.generate_url()
|
||||||
|
|
||||||
files.append(file_dict)
|
files.append(file_dict)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
@ -179,7 +181,7 @@ class WorkflowTool(Tool):
|
||||||
|
|
||||||
return parameters_result, files
|
return parameters_result, files
|
||||||
|
|
||||||
def _extract_files(self, outputs: dict) -> tuple[dict, list[FileVar]]:
|
def _extract_files(self, outputs: dict) -> tuple[dict, list[File]]:
|
||||||
"""
|
"""
|
||||||
extract files from the result
|
extract files from the result
|
||||||
|
|
||||||
|
@ -190,17 +192,13 @@ class WorkflowTool(Tool):
|
||||||
result = {}
|
result = {}
|
||||||
for key, value in outputs.items():
|
for key, value in outputs.items():
|
||||||
if isinstance(value, list):
|
if isinstance(value, list):
|
||||||
has_file = False
|
|
||||||
for item in value:
|
for item in value:
|
||||||
if isinstance(item, dict) and item.get("__variant") == "FileVar":
|
if isinstance(item, dict) and item.get("dify_model_identity") == FILE_MODEL_IDENTITY:
|
||||||
try:
|
file = File.model_validate(item)
|
||||||
files.append(FileVar(**item))
|
files.append(file)
|
||||||
has_file = True
|
elif isinstance(value, dict) and value.get("dify_model_identity") == FILE_MODEL_IDENTITY:
|
||||||
except Exception as e:
|
file = File.model_validate(value)
|
||||||
pass
|
files.append(file)
|
||||||
if has_file:
|
|
||||||
continue
|
|
||||||
|
|
||||||
result[key] = value
|
result[key] = value
|
||||||
|
|
||||||
return result, files
|
return result, files
|
||||||
|
|
|
@ -10,7 +10,8 @@ from yarl import URL
|
||||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||||
from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler
|
from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler
|
||||||
from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler
|
from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler
|
||||||
from core.file.file_obj import FileTransferMethod
|
from core.file import FileType
|
||||||
|
from core.file.models import FileTransferMethod
|
||||||
from core.ops.ops_trace_manager import TraceQueueManager
|
from core.ops.ops_trace_manager import TraceQueueManager
|
||||||
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolInvokeMessageBinary, ToolInvokeMeta, ToolParameter
|
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolInvokeMessageBinary, ToolInvokeMeta, ToolParameter
|
||||||
from core.tools.errors import (
|
from core.tools.errors import (
|
||||||
|
@ -25,6 +26,7 @@ from core.tools.errors import (
|
||||||
from core.tools.tool.tool import Tool
|
from core.tools.tool.tool import Tool
|
||||||
from core.tools.tool.workflow_tool import WorkflowTool
|
from core.tools.tool.workflow_tool import WorkflowTool
|
||||||
from core.tools.utils.message_transformer import ToolFileMessageTransformer
|
from core.tools.utils.message_transformer import ToolFileMessageTransformer
|
||||||
|
from enums import CreatedByRole
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from models.model import Message, MessageFile
|
from models.model import Message, MessageFile
|
||||||
|
|
||||||
|
@ -128,6 +130,7 @@ class ToolEngine:
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
# hit the callback handler
|
# hit the callback handler
|
||||||
|
assert tool.identity is not None
|
||||||
workflow_tool_callback.on_tool_start(tool_name=tool.identity.name, tool_inputs=tool_parameters)
|
workflow_tool_callback.on_tool_start(tool_name=tool.identity.name, tool_inputs=tool_parameters)
|
||||||
|
|
||||||
if isinstance(tool, WorkflowTool):
|
if isinstance(tool, WorkflowTool):
|
||||||
|
@ -258,7 +261,10 @@ class ToolEngine:
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _create_message_files(
|
def _create_message_files(
|
||||||
tool_messages: list[ToolInvokeMessageBinary], agent_message: Message, invoke_from: InvokeFrom, user_id: str
|
tool_messages: list[ToolInvokeMessageBinary],
|
||||||
|
agent_message: Message,
|
||||||
|
invoke_from: InvokeFrom,
|
||||||
|
user_id: str,
|
||||||
) -> list[tuple[Any, str]]:
|
) -> list[tuple[Any, str]]:
|
||||||
"""
|
"""
|
||||||
Create message file
|
Create message file
|
||||||
|
@ -269,29 +275,31 @@ class ToolEngine:
|
||||||
result = []
|
result = []
|
||||||
|
|
||||||
for message in tool_messages:
|
for message in tool_messages:
|
||||||
file_type = "bin"
|
|
||||||
if "image" in message.mimetype:
|
if "image" in message.mimetype:
|
||||||
file_type = "image"
|
file_type = FileType.IMAGE
|
||||||
elif "video" in message.mimetype:
|
elif "video" in message.mimetype:
|
||||||
file_type = "video"
|
file_type = FileType.VIDEO
|
||||||
elif "audio" in message.mimetype:
|
elif "audio" in message.mimetype:
|
||||||
file_type = "audio"
|
file_type = FileType.AUDIO
|
||||||
elif "text" in message.mimetype:
|
elif "text" in message.mimetype or "pdf" in message.mimetype:
|
||||||
file_type = "text"
|
file_type = FileType.DOCUMENT
|
||||||
elif "pdf" in message.mimetype:
|
else:
|
||||||
file_type = "pdf"
|
file_type = FileType.CUSTOM
|
||||||
elif "zip" in message.mimetype:
|
|
||||||
file_type = "archive"
|
|
||||||
# ...
|
|
||||||
|
|
||||||
|
# extract tool file id from url
|
||||||
|
tool_file_id = message.url.split("/")[-1].split(".")[0]
|
||||||
message_file = MessageFile(
|
message_file = MessageFile(
|
||||||
message_id=agent_message.id,
|
message_id=agent_message.id,
|
||||||
type=file_type,
|
type=file_type,
|
||||||
transfer_method=FileTransferMethod.TOOL_FILE.value,
|
transfer_method=FileTransferMethod.TOOL_FILE,
|
||||||
belongs_to="assistant",
|
belongs_to="assistant",
|
||||||
url=message.url,
|
url=message.url,
|
||||||
upload_file_id=None,
|
upload_file_id=tool_file_id,
|
||||||
created_by_role=("account" if invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} else "end_user"),
|
created_by_role=(
|
||||||
|
CreatedByRole.ACCOUNT
|
||||||
|
if invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER}
|
||||||
|
else CreatedByRole.END_USER
|
||||||
|
),
|
||||||
created_by=user_id,
|
created_by=user_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -57,22 +57,32 @@ class ToolFileManager:
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def create_file_by_raw(
|
def create_file_by_raw(
|
||||||
user_id: str, tenant_id: str, conversation_id: Optional[str], file_binary: bytes, mimetype: str
|
*,
|
||||||
|
user_id: str,
|
||||||
|
tenant_id: str,
|
||||||
|
conversation_id: Optional[str],
|
||||||
|
file_binary: bytes,
|
||||||
|
mimetype: str,
|
||||||
) -> ToolFile:
|
) -> ToolFile:
|
||||||
"""
|
|
||||||
create file
|
|
||||||
"""
|
|
||||||
extension = guess_extension(mimetype) or ".bin"
|
extension = guess_extension(mimetype) or ".bin"
|
||||||
unique_name = uuid4().hex
|
unique_name = uuid4().hex
|
||||||
filename = f"tools/{tenant_id}/{unique_name}{extension}"
|
filename = f"{unique_name}{extension}"
|
||||||
storage.save(filename, file_binary)
|
filepath = f"tools/{tenant_id}/{filename}"
|
||||||
|
storage.save(filepath, file_binary)
|
||||||
|
|
||||||
tool_file = ToolFile(
|
tool_file = ToolFile(
|
||||||
user_id=user_id, tenant_id=tenant_id, conversation_id=conversation_id, file_key=filename, mimetype=mimetype
|
user_id=user_id,
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
conversation_id=conversation_id,
|
||||||
|
file_key=filepath,
|
||||||
|
mimetype=mimetype,
|
||||||
|
name=filename,
|
||||||
|
size=len(file_binary),
|
||||||
)
|
)
|
||||||
|
|
||||||
db.session.add(tool_file)
|
db.session.add(tool_file)
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
|
db.session.refresh(tool_file)
|
||||||
|
|
||||||
return tool_file
|
return tool_file
|
||||||
|
|
||||||
|
@ -80,29 +90,34 @@ class ToolFileManager:
|
||||||
def create_file_by_url(
|
def create_file_by_url(
|
||||||
user_id: str,
|
user_id: str,
|
||||||
tenant_id: str,
|
tenant_id: str,
|
||||||
conversation_id: str,
|
conversation_id: str | None,
|
||||||
file_url: str,
|
file_url: str,
|
||||||
) -> ToolFile:
|
) -> ToolFile:
|
||||||
"""
|
|
||||||
create file
|
|
||||||
"""
|
|
||||||
# try to download image
|
# try to download image
|
||||||
response = get(file_url)
|
try:
|
||||||
response.raise_for_status()
|
response = get(file_url)
|
||||||
blob = response.content
|
response.raise_for_status()
|
||||||
|
blob = response.content
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to download file from {file_url}: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
mimetype = guess_type(file_url)[0] or "octet/stream"
|
mimetype = guess_type(file_url)[0] or "octet/stream"
|
||||||
extension = guess_extension(mimetype) or ".bin"
|
extension = guess_extension(mimetype) or ".bin"
|
||||||
unique_name = uuid4().hex
|
unique_name = uuid4().hex
|
||||||
filename = f"tools/{tenant_id}/{unique_name}{extension}"
|
filename = f"{unique_name}{extension}"
|
||||||
storage.save(filename, blob)
|
filepath = f"tools/{tenant_id}/{filename}"
|
||||||
|
storage.save(filepath, blob)
|
||||||
|
|
||||||
tool_file = ToolFile(
|
tool_file = ToolFile(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
tenant_id=tenant_id,
|
tenant_id=tenant_id,
|
||||||
conversation_id=conversation_id,
|
conversation_id=conversation_id,
|
||||||
file_key=filename,
|
file_key=filepath,
|
||||||
mimetype=mimetype,
|
mimetype=mimetype,
|
||||||
original_url=file_url,
|
original_url=file_url,
|
||||||
|
name=filename,
|
||||||
|
size=len(blob),
|
||||||
)
|
)
|
||||||
|
|
||||||
db.session.add(tool_file)
|
db.session.add(tool_file)
|
||||||
|
@ -110,18 +125,6 @@ class ToolFileManager:
|
||||||
|
|
||||||
return tool_file
|
return tool_file
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def create_file_by_key(
|
|
||||||
user_id: str, tenant_id: str, conversation_id: str, file_key: str, mimetype: str
|
|
||||||
) -> ToolFile:
|
|
||||||
"""
|
|
||||||
create file
|
|
||||||
"""
|
|
||||||
tool_file = ToolFile(
|
|
||||||
user_id=user_id, tenant_id=tenant_id, conversation_id=conversation_id, file_key=file_key, mimetype=mimetype
|
|
||||||
)
|
|
||||||
return tool_file
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_file_binary(id: str) -> Union[tuple[bytes, str], None]:
|
def get_file_binary(id: str) -> Union[tuple[bytes, str], None]:
|
||||||
"""
|
"""
|
||||||
|
@ -131,7 +134,7 @@ class ToolFileManager:
|
||||||
|
|
||||||
:return: the binary of the file, mime type
|
:return: the binary of the file, mime type
|
||||||
"""
|
"""
|
||||||
tool_file: ToolFile = (
|
tool_file = (
|
||||||
db.session.query(ToolFile)
|
db.session.query(ToolFile)
|
||||||
.filter(
|
.filter(
|
||||||
ToolFile.id == id,
|
ToolFile.id == id,
|
||||||
|
@ -155,7 +158,7 @@ class ToolFileManager:
|
||||||
|
|
||||||
:return: the binary of the file, mime type
|
:return: the binary of the file, mime type
|
||||||
"""
|
"""
|
||||||
message_file: MessageFile = (
|
message_file = (
|
||||||
db.session.query(MessageFile)
|
db.session.query(MessageFile)
|
||||||
.filter(
|
.filter(
|
||||||
MessageFile.id == id,
|
MessageFile.id == id,
|
||||||
|
@ -166,13 +169,16 @@ class ToolFileManager:
|
||||||
# Check if message_file is not None
|
# Check if message_file is not None
|
||||||
if message_file is not None:
|
if message_file is not None:
|
||||||
# get tool file id
|
# get tool file id
|
||||||
tool_file_id = message_file.url.split("/")[-1]
|
if message_file.url is not None:
|
||||||
# trim extension
|
tool_file_id = message_file.url.split("/")[-1]
|
||||||
tool_file_id = tool_file_id.split(".")[0]
|
# trim extension
|
||||||
|
tool_file_id = tool_file_id.split(".")[0]
|
||||||
|
else:
|
||||||
|
tool_file_id = None
|
||||||
else:
|
else:
|
||||||
tool_file_id = None
|
tool_file_id = None
|
||||||
|
|
||||||
tool_file: ToolFile = (
|
tool_file = (
|
||||||
db.session.query(ToolFile)
|
db.session.query(ToolFile)
|
||||||
.filter(
|
.filter(
|
||||||
ToolFile.id == tool_file_id,
|
ToolFile.id == tool_file_id,
|
||||||
|
@ -196,7 +202,7 @@ class ToolFileManager:
|
||||||
|
|
||||||
:return: the binary of the file, mime type
|
:return: the binary of the file, mime type
|
||||||
"""
|
"""
|
||||||
tool_file: ToolFile = (
|
tool_file = (
|
||||||
db.session.query(ToolFile)
|
db.session.query(ToolFile)
|
||||||
.filter(
|
.filter(
|
||||||
ToolFile.id == tool_file_id,
|
ToolFile.id == tool_file_id,
|
||||||
|
|
|
@ -24,7 +24,6 @@ from core.tools.tool.builtin_tool import BuiltinTool
|
||||||
from core.tools.tool.tool import Tool
|
from core.tools.tool.tool import Tool
|
||||||
from core.tools.tool_label_manager import ToolLabelManager
|
from core.tools.tool_label_manager import ToolLabelManager
|
||||||
from core.tools.utils.configuration import ToolConfigurationManager, ToolParameterConfigurationManager
|
from core.tools.utils.configuration import ToolConfigurationManager, ToolParameterConfigurationManager
|
||||||
from core.tools.utils.tool_parameter_converter import ToolParameterConverter
|
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from models.tools import ApiToolProvider, BuiltinToolProvider, WorkflowToolProvider
|
from models.tools import ApiToolProvider, BuiltinToolProvider, WorkflowToolProvider
|
||||||
from services.tools.tools_transform_service import ToolTransformService
|
from services.tools.tools_transform_service import ToolTransformService
|
||||||
|
@ -203,7 +202,7 @@ class ToolManager:
|
||||||
raise ToolProviderNotFoundError(f"provider type {provider_type} not found")
|
raise ToolProviderNotFoundError(f"provider type {provider_type} not found")
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _init_runtime_parameter(cls, parameter_rule: ToolParameter, parameters: dict) -> Union[str, int, float, bool]:
|
def _init_runtime_parameter(cls, parameter_rule: ToolParameter, parameters: dict):
|
||||||
"""
|
"""
|
||||||
init runtime parameter
|
init runtime parameter
|
||||||
"""
|
"""
|
||||||
|
@ -222,7 +221,7 @@ class ToolManager:
|
||||||
f"tool parameter {parameter_rule.name} value {parameter_value} not in options {options}"
|
f"tool parameter {parameter_rule.name} value {parameter_value} not in options {options}"
|
||||||
)
|
)
|
||||||
|
|
||||||
return ToolParameterConverter.cast_parameter_by_type(parameter_value, parameter_rule.type)
|
return parameter_rule.type.cast_value(parameter_value)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_agent_tool_runtime(
|
def get_agent_tool_runtime(
|
||||||
|
@ -243,7 +242,11 @@ class ToolManager:
|
||||||
parameters = tool_entity.get_all_runtime_parameters()
|
parameters = tool_entity.get_all_runtime_parameters()
|
||||||
for parameter in parameters:
|
for parameter in parameters:
|
||||||
# check file types
|
# check file types
|
||||||
if parameter.type == ToolParameter.ToolParameterType.FILE:
|
if parameter.type in {
|
||||||
|
ToolParameter.ToolParameterType.SYSTEM_FILES,
|
||||||
|
ToolParameter.ToolParameterType.FILE,
|
||||||
|
ToolParameter.ToolParameterType.FILES,
|
||||||
|
}:
|
||||||
raise ValueError(f"file type parameter {parameter.name} not supported in agent")
|
raise ValueError(f"file type parameter {parameter.name} not supported in agent")
|
||||||
|
|
||||||
if parameter.form == ToolParameter.ToolParameterForm.FORM:
|
if parameter.form == ToolParameter.ToolParameterForm.FORM:
|
||||||
|
|
|
@ -1,7 +1,8 @@
|
||||||
import logging
|
import logging
|
||||||
from mimetypes import guess_extension
|
from mimetypes import guess_extension
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
from core.file.file_obj import FileTransferMethod, FileType
|
from core.file import File, FileTransferMethod, FileType
|
||||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||||
from core.tools.tool_file_manager import ToolFileManager
|
from core.tools.tool_file_manager import ToolFileManager
|
||||||
|
|
||||||
|
@ -11,7 +12,7 @@ logger = logging.getLogger(__name__)
|
||||||
class ToolFileMessageTransformer:
|
class ToolFileMessageTransformer:
|
||||||
@classmethod
|
@classmethod
|
||||||
def transform_tool_invoke_messages(
|
def transform_tool_invoke_messages(
|
||||||
cls, messages: list[ToolInvokeMessage], user_id: str, tenant_id: str, conversation_id: str
|
cls, messages: list[ToolInvokeMessage], user_id: str, tenant_id: str, conversation_id: str | None
|
||||||
) -> list[ToolInvokeMessage]:
|
) -> list[ToolInvokeMessage]:
|
||||||
"""
|
"""
|
||||||
Transform tool message and handle file download
|
Transform tool message and handle file download
|
||||||
|
@ -21,7 +22,7 @@ class ToolFileMessageTransformer:
|
||||||
for message in messages:
|
for message in messages:
|
||||||
if message.type in {ToolInvokeMessage.MessageType.TEXT, ToolInvokeMessage.MessageType.LINK}:
|
if message.type in {ToolInvokeMessage.MessageType.TEXT, ToolInvokeMessage.MessageType.LINK}:
|
||||||
result.append(message)
|
result.append(message)
|
||||||
elif message.type == ToolInvokeMessage.MessageType.IMAGE:
|
elif message.type == ToolInvokeMessage.MessageType.IMAGE and isinstance(message.message, str):
|
||||||
# try to download image
|
# try to download image
|
||||||
try:
|
try:
|
||||||
file = ToolFileManager.create_file_by_url(
|
file = ToolFileManager.create_file_by_url(
|
||||||
|
@ -50,11 +51,14 @@ class ToolFileMessageTransformer:
|
||||||
)
|
)
|
||||||
elif message.type == ToolInvokeMessage.MessageType.BLOB:
|
elif message.type == ToolInvokeMessage.MessageType.BLOB:
|
||||||
# get mime type and save blob to storage
|
# get mime type and save blob to storage
|
||||||
|
assert message.meta is not None
|
||||||
mimetype = message.meta.get("mime_type", "octet/stream")
|
mimetype = message.meta.get("mime_type", "octet/stream")
|
||||||
# if message is str, encode it to bytes
|
# if message is str, encode it to bytes
|
||||||
if isinstance(message.message, str):
|
if isinstance(message.message, str):
|
||||||
message.message = message.message.encode("utf-8")
|
message.message = message.message.encode("utf-8")
|
||||||
|
|
||||||
|
# FIXME: should do a type check here.
|
||||||
|
assert isinstance(message.message, bytes)
|
||||||
file = ToolFileManager.create_file_by_raw(
|
file = ToolFileManager.create_file_by_raw(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
tenant_id=tenant_id,
|
tenant_id=tenant_id,
|
||||||
|
@ -63,7 +67,7 @@ class ToolFileMessageTransformer:
|
||||||
mimetype=mimetype,
|
mimetype=mimetype,
|
||||||
)
|
)
|
||||||
|
|
||||||
url = cls.get_tool_file_url(file.id, guess_extension(file.mimetype))
|
url = cls.get_tool_file_url(tool_file_id=file.id, extension=guess_extension(file.mimetype))
|
||||||
|
|
||||||
# check if file is image
|
# check if file is image
|
||||||
if "image" in mimetype:
|
if "image" in mimetype:
|
||||||
|
@ -84,12 +88,14 @@ class ToolFileMessageTransformer:
|
||||||
meta=message.meta.copy() if message.meta is not None else {},
|
meta=message.meta.copy() if message.meta is not None else {},
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
elif message.type == ToolInvokeMessage.MessageType.FILE_VAR:
|
elif message.type == ToolInvokeMessage.MessageType.FILE:
|
||||||
file_var = message.meta.get("file_var")
|
assert message.meta is not None
|
||||||
if file_var:
|
file = message.meta.get("file")
|
||||||
if file_var.transfer_method == FileTransferMethod.TOOL_FILE:
|
if isinstance(file, File):
|
||||||
url = cls.get_tool_file_url(file_var.related_id, file_var.extension)
|
if file.transfer_method == FileTransferMethod.TOOL_FILE:
|
||||||
if file_var.type == FileType.IMAGE:
|
assert file.related_id is not None
|
||||||
|
url = cls.get_tool_file_url(tool_file_id=file.related_id, extension=file.extension)
|
||||||
|
if file.type == FileType.IMAGE:
|
||||||
result.append(
|
result.append(
|
||||||
ToolInvokeMessage(
|
ToolInvokeMessage(
|
||||||
type=ToolInvokeMessage.MessageType.IMAGE_LINK,
|
type=ToolInvokeMessage.MessageType.IMAGE_LINK,
|
||||||
|
@ -107,11 +113,13 @@ class ToolFileMessageTransformer:
|
||||||
meta=message.meta.copy() if message.meta is not None else {},
|
meta=message.meta.copy() if message.meta is not None else {},
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
result.append(message)
|
||||||
else:
|
else:
|
||||||
result.append(message)
|
result.append(message)
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_tool_file_url(cls, tool_file_id: str, extension: str) -> str:
|
def get_tool_file_url(cls, tool_file_id: str, extension: Optional[str]) -> str:
|
||||||
return f'/files/tools/{tool_file_id}{extension or ".bin"}'
|
return f'/files/tools/{tool_file_id}{extension or ".bin"}'
|
||||||
|
|
|
@ -1,71 +0,0 @@
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from core.tools.entities.tool_entities import ToolParameter
|
|
||||||
|
|
||||||
|
|
||||||
class ToolParameterConverter:
|
|
||||||
@staticmethod
|
|
||||||
def get_parameter_type(parameter_type: str | ToolParameter.ToolParameterType) -> str:
|
|
||||||
match parameter_type:
|
|
||||||
case (
|
|
||||||
ToolParameter.ToolParameterType.STRING
|
|
||||||
| ToolParameter.ToolParameterType.SECRET_INPUT
|
|
||||||
| ToolParameter.ToolParameterType.SELECT
|
|
||||||
):
|
|
||||||
return "string"
|
|
||||||
|
|
||||||
case ToolParameter.ToolParameterType.BOOLEAN:
|
|
||||||
return "boolean"
|
|
||||||
|
|
||||||
case ToolParameter.ToolParameterType.NUMBER:
|
|
||||||
return "number"
|
|
||||||
|
|
||||||
case _:
|
|
||||||
raise ValueError(f"Unsupported parameter type {parameter_type}")
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def cast_parameter_by_type(value: Any, parameter_type: str) -> Any:
|
|
||||||
# convert tool parameter config to correct type
|
|
||||||
try:
|
|
||||||
match parameter_type:
|
|
||||||
case (
|
|
||||||
ToolParameter.ToolParameterType.STRING
|
|
||||||
| ToolParameter.ToolParameterType.SECRET_INPUT
|
|
||||||
| ToolParameter.ToolParameterType.SELECT
|
|
||||||
):
|
|
||||||
if value is None:
|
|
||||||
return ""
|
|
||||||
else:
|
|
||||||
return value if isinstance(value, str) else str(value)
|
|
||||||
|
|
||||||
case ToolParameter.ToolParameterType.BOOLEAN:
|
|
||||||
if value is None:
|
|
||||||
return False
|
|
||||||
elif isinstance(value, str):
|
|
||||||
# Allowed YAML boolean value strings: https://yaml.org/type/bool.html
|
|
||||||
# and also '0' for False and '1' for True
|
|
||||||
match value.lower():
|
|
||||||
case "true" | "yes" | "y" | "1":
|
|
||||||
return True
|
|
||||||
case "false" | "no" | "n" | "0":
|
|
||||||
return False
|
|
||||||
case _:
|
|
||||||
return bool(value)
|
|
||||||
else:
|
|
||||||
return value if isinstance(value, bool) else bool(value)
|
|
||||||
|
|
||||||
case ToolParameter.ToolParameterType.NUMBER:
|
|
||||||
if isinstance(value, int) | isinstance(value, float):
|
|
||||||
return value
|
|
||||||
elif isinstance(value, str) and value != "":
|
|
||||||
if "." in value:
|
|
||||||
return float(value)
|
|
||||||
else:
|
|
||||||
return int(value)
|
|
||||||
case ToolParameter.ToolParameterType.FILE:
|
|
||||||
return value
|
|
||||||
case _:
|
|
||||||
return str(value)
|
|
||||||
|
|
||||||
except Exception:
|
|
||||||
raise ValueError(f"The tool parameter value {value} is not in correct type of {parameter_type}.")
|
|
|
@ -1,19 +1,18 @@
|
||||||
|
from collections.abc import Mapping, Sequence
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
from core.app.app_config.entities import VariableEntity
|
from core.app.app_config.entities import VariableEntity
|
||||||
from core.tools.entities.tool_entities import WorkflowToolParameterConfiguration
|
from core.tools.entities.tool_entities import WorkflowToolParameterConfiguration
|
||||||
|
|
||||||
|
|
||||||
class WorkflowToolConfigurationUtils:
|
class WorkflowToolConfigurationUtils:
|
||||||
@classmethod
|
@classmethod
|
||||||
def check_parameter_configurations(cls, configurations: list[dict]):
|
def check_parameter_configurations(cls, configurations: Mapping[str, Any]):
|
||||||
"""
|
|
||||||
check parameter configurations
|
|
||||||
"""
|
|
||||||
for configuration in configurations:
|
for configuration in configurations:
|
||||||
if not WorkflowToolParameterConfiguration(**configuration):
|
WorkflowToolParameterConfiguration.model_validate(configuration)
|
||||||
raise ValueError("invalid parameter configuration")
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_workflow_graph_variables(cls, graph: dict) -> list[VariableEntity]:
|
def get_workflow_graph_variables(cls, graph: Mapping[str, Any]) -> Sequence[VariableEntity]:
|
||||||
"""
|
"""
|
||||||
get workflow graph variables
|
get workflow graph variables
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
import logging
|
import logging
|
||||||
|
from pathlib import Path
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import yaml
|
import yaml
|
||||||
|
@ -17,15 +18,18 @@ def load_yaml_file(file_path: str, ignore_error: bool = True, default_value: Any
|
||||||
:param default_value: the value returned when errors ignored
|
:param default_value: the value returned when errors ignored
|
||||||
:return: an object of the YAML content
|
:return: an object of the YAML content
|
||||||
"""
|
"""
|
||||||
try:
|
if not file_path or not Path(file_path).exists():
|
||||||
with open(file_path, encoding="utf-8") as yaml_file:
|
|
||||||
try:
|
|
||||||
yaml_content = yaml.safe_load(yaml_file)
|
|
||||||
return yaml_content or default_value
|
|
||||||
except Exception as e:
|
|
||||||
raise YAMLError(f"Failed to load YAML file {file_path}: {e}")
|
|
||||||
except Exception as e:
|
|
||||||
if ignore_error:
|
if ignore_error:
|
||||||
return default_value
|
return default_value
|
||||||
else:
|
else:
|
||||||
raise e
|
raise FileNotFoundError(f"File not found: {file_path}")
|
||||||
|
|
||||||
|
with open(file_path, encoding="utf-8") as yaml_file:
|
||||||
|
try:
|
||||||
|
yaml_content = yaml.safe_load(yaml_file)
|
||||||
|
return yaml_content or default_value
|
||||||
|
except Exception as e:
|
||||||
|
if ignore_error:
|
||||||
|
return default_value
|
||||||
|
else:
|
||||||
|
raise YAMLError(f"Failed to load YAML file {file_path}: {e}") from e
|
||||||
|
|
|
@ -1,7 +1,12 @@
|
||||||
from .segment_group import SegmentGroup
|
from .segment_group import SegmentGroup
|
||||||
from .segments import (
|
from .segments import (
|
||||||
ArrayAnySegment,
|
ArrayAnySegment,
|
||||||
|
ArrayFileSegment,
|
||||||
|
ArrayNumberSegment,
|
||||||
|
ArrayObjectSegment,
|
||||||
ArraySegment,
|
ArraySegment,
|
||||||
|
ArrayStringSegment,
|
||||||
|
FileSegment,
|
||||||
FloatSegment,
|
FloatSegment,
|
||||||
IntegerSegment,
|
IntegerSegment,
|
||||||
NoneSegment,
|
NoneSegment,
|
||||||
|
@ -15,6 +20,7 @@ from .variables import (
|
||||||
ArrayNumberVariable,
|
ArrayNumberVariable,
|
||||||
ArrayObjectVariable,
|
ArrayObjectVariable,
|
||||||
ArrayStringVariable,
|
ArrayStringVariable,
|
||||||
|
FileVariable,
|
||||||
FloatVariable,
|
FloatVariable,
|
||||||
IntegerVariable,
|
IntegerVariable,
|
||||||
NoneVariable,
|
NoneVariable,
|
||||||
|
@ -46,4 +52,10 @@ __all__ = [
|
||||||
"ArrayNumberVariable",
|
"ArrayNumberVariable",
|
||||||
"ArrayObjectVariable",
|
"ArrayObjectVariable",
|
||||||
"ArraySegment",
|
"ArraySegment",
|
||||||
|
"ArrayFileSegment",
|
||||||
|
"ArrayNumberSegment",
|
||||||
|
"ArrayObjectSegment",
|
||||||
|
"ArrayStringSegment",
|
||||||
|
"FileSegment",
|
||||||
|
"FileVariable",
|
||||||
]
|
]
|
|
@ -5,6 +5,8 @@ from typing import Any
|
||||||
|
|
||||||
from pydantic import BaseModel, ConfigDict, field_validator
|
from pydantic import BaseModel, ConfigDict, field_validator
|
||||||
|
|
||||||
|
from core.file import File
|
||||||
|
|
||||||
from .types import SegmentType
|
from .types import SegmentType
|
||||||
|
|
||||||
|
|
||||||
|
@ -39,6 +41,9 @@ class Segment(BaseModel):
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def size(self) -> int:
|
def size(self) -> int:
|
||||||
|
"""
|
||||||
|
Return the size of the value in bytes.
|
||||||
|
"""
|
||||||
return sys.getsizeof(self.value)
|
return sys.getsizeof(self.value)
|
||||||
|
|
||||||
def to_object(self) -> Any:
|
def to_object(self) -> Any:
|
||||||
|
@ -99,13 +104,27 @@ class ArraySegment(Segment):
|
||||||
def markdown(self) -> str:
|
def markdown(self) -> str:
|
||||||
items = []
|
items = []
|
||||||
for item in self.value:
|
for item in self.value:
|
||||||
if hasattr(item, "to_markdown"):
|
items.append(str(item))
|
||||||
items.append(item.to_markdown())
|
|
||||||
else:
|
|
||||||
items.append(str(item))
|
|
||||||
return "\n".join(items)
|
return "\n".join(items)
|
||||||
|
|
||||||
|
|
||||||
|
class FileSegment(Segment):
|
||||||
|
value_type: SegmentType = SegmentType.FILE
|
||||||
|
value: File
|
||||||
|
|
||||||
|
@property
|
||||||
|
def markdown(self) -> str:
|
||||||
|
return self.value.markdown
|
||||||
|
|
||||||
|
@property
|
||||||
|
def log(self) -> str:
|
||||||
|
return str(self.value)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def text(self) -> str:
|
||||||
|
return str(self.value)
|
||||||
|
|
||||||
|
|
||||||
class ArrayAnySegment(ArraySegment):
|
class ArrayAnySegment(ArraySegment):
|
||||||
value_type: SegmentType = SegmentType.ARRAY_ANY
|
value_type: SegmentType = SegmentType.ARRAY_ANY
|
||||||
value: Sequence[Any]
|
value: Sequence[Any]
|
||||||
|
@ -124,3 +143,15 @@ class ArrayNumberSegment(ArraySegment):
|
||||||
class ArrayObjectSegment(ArraySegment):
|
class ArrayObjectSegment(ArraySegment):
|
||||||
value_type: SegmentType = SegmentType.ARRAY_OBJECT
|
value_type: SegmentType = SegmentType.ARRAY_OBJECT
|
||||||
value: Sequence[Mapping[str, Any]]
|
value: Sequence[Mapping[str, Any]]
|
||||||
|
|
||||||
|
|
||||||
|
class ArrayFileSegment(ArraySegment):
|
||||||
|
value_type: SegmentType = SegmentType.ARRAY_FILE
|
||||||
|
value: Sequence[File]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def markdown(self) -> str:
|
||||||
|
items = []
|
||||||
|
for item in self.value:
|
||||||
|
items.append(item.markdown)
|
||||||
|
return "\n".join(items)
|
|
@ -11,5 +11,7 @@ class SegmentType(str, Enum):
|
||||||
ARRAY_NUMBER = "array[number]"
|
ARRAY_NUMBER = "array[number]"
|
||||||
ARRAY_OBJECT = "array[object]"
|
ARRAY_OBJECT = "array[object]"
|
||||||
OBJECT = "object"
|
OBJECT = "object"
|
||||||
|
FILE = "file"
|
||||||
|
ARRAY_FILE = "array[file]"
|
||||||
|
|
||||||
GROUP = "group"
|
GROUP = "group"
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user