diff --git a/api/app.py b/api/app.py index 5e1c814e1c..a3efabf06c 100644 --- a/api/app.py +++ b/api/app.py @@ -10,44 +10,20 @@ if os.environ.get("DEBUG", "false").lower() != "true": grpc.experimental.gevent.init_gevent() import json -import logging -import sys import threading import time import warnings -from logging.handlers import RotatingFileHandler -from flask import Flask, Response, request -from flask_cors import CORS -from werkzeug.exceptions import Unauthorized +from flask import Response -import contexts -from commands import register_commands -from configs import dify_config +from app_factory import create_app # DO NOT REMOVE BELOW from events import event_handlers # noqa: F401 -from extensions import ( - ext_celery, - ext_code_based_extension, - ext_compress, - ext_database, - ext_hosting_provider, - ext_login, - ext_mail, - ext_migrate, - ext_proxy_fix, - ext_redis, - ext_sentry, - ext_storage, -) from extensions.ext_database import db -from extensions.ext_login import login_manager -from libs.passport import PassportService # TODO: Find a way to avoid importing models here from models import account, dataset, model, source, task, tool, tools, web # noqa: F401 -from services.account_service import AccountService # DO NOT REMOVE ABOVE @@ -60,189 +36,12 @@ if hasattr(time, "tzset"): time.tzset() -class DifyApp(Flask): - pass - - # ------------- # Configuration # ------------- - - config_type = os.getenv("EDITION", default="SELF_HOSTED") # ce edition first -# ---------------------------- -# Application Factory Function -# ---------------------------- - - -def create_flask_app_with_configs() -> Flask: - """ - create a raw flask app - with configs loaded from .env file - """ - dify_app = DifyApp(__name__) - dify_app.config.from_mapping(dify_config.model_dump()) - - # populate configs into system environment variables - for key, value in dify_app.config.items(): - if isinstance(value, str): - os.environ[key] = value - elif isinstance(value, int | float | 
bool): - os.environ[key] = str(value) - elif value is None: - os.environ[key] = "" - - return dify_app - - -def create_app() -> Flask: - app = create_flask_app_with_configs() - - app.secret_key = app.config["SECRET_KEY"] - - log_handlers = None - log_file = app.config.get("LOG_FILE") - if log_file: - log_dir = os.path.dirname(log_file) - os.makedirs(log_dir, exist_ok=True) - log_handlers = [ - RotatingFileHandler( - filename=log_file, - maxBytes=1024 * 1024 * 1024, - backupCount=5, - ), - logging.StreamHandler(sys.stdout), - ] - - logging.basicConfig( - level=app.config.get("LOG_LEVEL"), - format=app.config["LOG_FORMAT"], - datefmt=app.config.get("LOG_DATEFORMAT"), - handlers=log_handlers, - force=True, - ) - log_tz = app.config.get("LOG_TZ") - if log_tz: - from datetime import datetime - - import pytz - - timezone = pytz.timezone(log_tz) - - def time_converter(seconds): - return datetime.utcfromtimestamp(seconds).astimezone(timezone).timetuple() - - for handler in logging.root.handlers: - assert handler.formatter - handler.formatter.converter = time_converter - initialize_extensions(app) - register_blueprints(app) - register_commands(app) - - return app - - -def initialize_extensions(app): - # Since the application instance is now created, pass it to each Flask - # extension instance to bind it to the Flask application instance (app) - ext_compress.init_app(app) - ext_code_based_extension.init() - ext_database.init_app(app) - ext_migrate.init(app, db) - ext_redis.init_app(app) - ext_storage.init_app(app) - ext_celery.init_app(app) - ext_login.init_app(app) - ext_mail.init_app(app) - ext_hosting_provider.init_app(app) - ext_sentry.init_app(app) - ext_proxy_fix.init_app(app) - - -# Flask-Login configuration -@login_manager.request_loader -def load_user_from_request(request_from_flask_login): - """Load user based on the request.""" - if request.blueprint not in {"console", "inner_api"}: - return None - # Check if the user_id contains a dot, indicating the old format 
- auth_header = request.headers.get("Authorization", "") - if not auth_header: - auth_token = request.args.get("_token") - if not auth_token: - raise Unauthorized("Invalid Authorization token.") - else: - if " " not in auth_header: - raise Unauthorized("Invalid Authorization header format. Expected 'Bearer ' format.") - auth_scheme, auth_token = auth_header.split(None, 1) - auth_scheme = auth_scheme.lower() - if auth_scheme != "bearer": - raise Unauthorized("Invalid Authorization header format. Expected 'Bearer ' format.") - - decoded = PassportService().verify(auth_token) - user_id = decoded.get("user_id") - - logged_in_account = AccountService.load_logged_in_account(account_id=user_id) - if logged_in_account: - contexts.tenant_id.set(logged_in_account.current_tenant_id) - return logged_in_account - - -@login_manager.unauthorized_handler -def unauthorized_handler(): - """Handle unauthorized requests.""" - return Response( - json.dumps({"code": "unauthorized", "message": "Unauthorized."}), - status=401, - content_type="application/json", - ) - - -# register blueprint routers -def register_blueprints(app): - from controllers.console import bp as console_app_bp - from controllers.files import bp as files_bp - from controllers.inner_api import bp as inner_api_bp - from controllers.service_api import bp as service_api_bp - from controllers.web import bp as web_bp - - CORS( - service_api_bp, - allow_headers=["Content-Type", "Authorization", "X-App-Code"], - methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"], - ) - app.register_blueprint(service_api_bp) - - CORS( - web_bp, - resources={r"/*": {"origins": app.config["WEB_API_CORS_ALLOW_ORIGINS"]}}, - supports_credentials=True, - allow_headers=["Content-Type", "Authorization", "X-App-Code"], - methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"], - expose_headers=["X-Version", "X-Env"], - ) - - app.register_blueprint(web_bp) - - CORS( - console_app_bp, - resources={r"/*": {"origins": 
app.config["CONSOLE_CORS_ALLOW_ORIGINS"]}}, - supports_credentials=True, - allow_headers=["Content-Type", "Authorization"], - methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"], - expose_headers=["X-Version", "X-Env"], - ) - - app.register_blueprint(console_app_bp) - - CORS(files_bp, allow_headers=["Content-Type"], methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"]) - app.register_blueprint(files_bp) - - app.register_blueprint(inner_api_bp) - - # create app app = create_app() celery = app.extensions["celery"] diff --git a/api/app_factory.py b/api/app_factory.py new file mode 100644 index 0000000000..04654c2699 --- /dev/null +++ b/api/app_factory.py @@ -0,0 +1,213 @@ +import os + +if os.environ.get("DEBUG", "false").lower() != "true": + from gevent import monkey + + monkey.patch_all() + + import grpc.experimental.gevent + + grpc.experimental.gevent.init_gevent() + +import json +import logging +import sys +from logging.handlers import RotatingFileHandler + +from flask import Flask, Response, request +from flask_cors import CORS +from werkzeug.exceptions import Unauthorized + +import contexts +from commands import register_commands +from configs import dify_config +from extensions import ( + ext_celery, + ext_code_based_extension, + ext_compress, + ext_database, + ext_hosting_provider, + ext_login, + ext_mail, + ext_migrate, + ext_proxy_fix, + ext_redis, + ext_sentry, + ext_storage, +) +from extensions.ext_database import db +from extensions.ext_login import login_manager +from libs.passport import PassportService +from services.account_service import AccountService + + +class DifyApp(Flask): + pass + + +# ---------------------------- +# Application Factory Function +# ---------------------------- +def create_flask_app_with_configs() -> Flask: + """ + create a raw flask app + with configs loaded from .env file + """ + dify_app = DifyApp(__name__) + dify_app.config.from_mapping(dify_config.model_dump()) + + # populate configs into system environment 
variables + for key, value in dify_app.config.items(): + if isinstance(value, str): + os.environ[key] = value + elif isinstance(value, int | float | bool): + os.environ[key] = str(value) + elif value is None: + os.environ[key] = "" + + return dify_app + + +def create_app() -> Flask: + app = create_flask_app_with_configs() + + app.secret_key = app.config["SECRET_KEY"] + + log_handlers = None + log_file = app.config.get("LOG_FILE") + if log_file: + log_dir = os.path.dirname(log_file) + os.makedirs(log_dir, exist_ok=True) + log_handlers = [ + RotatingFileHandler( + filename=log_file, + maxBytes=1024 * 1024 * 1024, + backupCount=5, + ), + logging.StreamHandler(sys.stdout), + ] + + logging.basicConfig( + level=app.config.get("LOG_LEVEL"), + format=app.config.get("LOG_FORMAT"), + datefmt=app.config.get("LOG_DATEFORMAT"), + handlers=log_handlers, + force=True, + ) + log_tz = app.config.get("LOG_TZ") + if log_tz: + from datetime import datetime + + import pytz + + timezone = pytz.timezone(log_tz) + + def time_converter(seconds): + return datetime.utcfromtimestamp(seconds).astimezone(timezone).timetuple() + + for handler in logging.root.handlers: + handler.formatter.converter = time_converter + initialize_extensions(app) + register_blueprints(app) + register_commands(app) + + return app + + +def initialize_extensions(app): + # Since the application instance is now created, pass it to each Flask + # extension instance to bind it to the Flask application instance (app) + ext_compress.init_app(app) + ext_code_based_extension.init() + ext_database.init_app(app) + ext_migrate.init(app, db) + ext_redis.init_app(app) + ext_storage.init_app(app) + ext_celery.init_app(app) + ext_login.init_app(app) + ext_mail.init_app(app) + ext_hosting_provider.init_app(app) + ext_sentry.init_app(app) + ext_proxy_fix.init_app(app) + + +# Flask-Login configuration +@login_manager.request_loader +def load_user_from_request(request_from_flask_login): + """Load user based on the request.""" + if 
request.blueprint not in {"console", "inner_api"}: + return None + # Check if the user_id contains a dot, indicating the old format + auth_header = request.headers.get("Authorization", "") + if not auth_header: + auth_token = request.args.get("_token") + if not auth_token: + raise Unauthorized("Invalid Authorization token.") + else: + if " " not in auth_header: + raise Unauthorized("Invalid Authorization header format. Expected 'Bearer ' format.") + auth_scheme, auth_token = auth_header.split(None, 1) + auth_scheme = auth_scheme.lower() + if auth_scheme != "bearer": + raise Unauthorized("Invalid Authorization header format. Expected 'Bearer ' format.") + + decoded = PassportService().verify(auth_token) + user_id = decoded.get("user_id") + + logged_in_account = AccountService.load_logged_in_account(account_id=user_id) + if logged_in_account: + contexts.tenant_id.set(logged_in_account.current_tenant_id) + return logged_in_account + + +@login_manager.unauthorized_handler +def unauthorized_handler(): + """Handle unauthorized requests.""" + return Response( + json.dumps({"code": "unauthorized", "message": "Unauthorized."}), + status=401, + content_type="application/json", + ) + + +# register blueprint routers +def register_blueprints(app): + from controllers.console import bp as console_app_bp + from controllers.files import bp as files_bp + from controllers.inner_api import bp as inner_api_bp + from controllers.service_api import bp as service_api_bp + from controllers.web import bp as web_bp + + CORS( + service_api_bp, + allow_headers=["Content-Type", "Authorization", "X-App-Code"], + methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"], + ) + app.register_blueprint(service_api_bp) + + CORS( + web_bp, + resources={r"/*": {"origins": app.config["WEB_API_CORS_ALLOW_ORIGINS"]}}, + supports_credentials=True, + allow_headers=["Content-Type", "Authorization", "X-App-Code"], + methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"], + 
expose_headers=["X-Version", "X-Env"], + ) + + app.register_blueprint(web_bp) + + CORS( + console_app_bp, + resources={r"/*": {"origins": app.config["CONSOLE_CORS_ALLOW_ORIGINS"]}}, + supports_credentials=True, + allow_headers=["Content-Type", "Authorization"], + methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"], + expose_headers=["X-Version", "X-Env"], + ) + + app.register_blueprint(console_app_bp) + + CORS(files_bp, allow_headers=["Content-Type"], methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"]) + app.register_blueprint(files_bp) + + app.register_blueprint(inner_api_bp) diff --git a/api/commands.py b/api/commands.py index 1b1f376574..f2809be8e7 100644 --- a/api/commands.py +++ b/api/commands.py @@ -259,6 +259,25 @@ def migrate_knowledge_vector_database(): skipped_count = 0 total_count = 0 vector_type = dify_config.VECTOR_STORE + upper_collection_vector_types = { + VectorType.MILVUS, + VectorType.PGVECTOR, + VectorType.RELYT, + VectorType.WEAVIATE, + VectorType.ORACLE, + VectorType.ELASTICSEARCH, + } + lower_collection_vector_types = { + VectorType.ANALYTICDB, + VectorType.CHROMA, + VectorType.MYSCALE, + VectorType.PGVECTO_RS, + VectorType.TIDB_VECTOR, + VectorType.OPENSEARCH, + VectorType.TENCENT, + VectorType.BAIDU, + VectorType.VIKINGDB, + } page = 1 while True: try: @@ -284,11 +303,9 @@ def migrate_knowledge_vector_database(): skipped_count = skipped_count + 1 continue collection_name = "" - if vector_type == VectorType.WEAVIATE: - dataset_id = dataset.id + dataset_id = dataset.id + if vector_type in upper_collection_vector_types: collection_name = Dataset.gen_collection_name_by_id(dataset_id) - index_struct_dict = {"type": VectorType.WEAVIATE, "vector_store": {"class_prefix": collection_name}} - dataset.index_struct = json.dumps(index_struct_dict) elif vector_type == VectorType.QDRANT: if dataset.collection_binding_id: dataset_collection_binding = ( @@ -301,63 +318,15 @@ def migrate_knowledge_vector_database(): else: raise 
ValueError("Dataset Collection Binding not found") else: - dataset_id = dataset.id collection_name = Dataset.gen_collection_name_by_id(dataset_id) - index_struct_dict = {"type": VectorType.QDRANT, "vector_store": {"class_prefix": collection_name}} - dataset.index_struct = json.dumps(index_struct_dict) - elif vector_type == VectorType.MILVUS: - dataset_id = dataset.id - collection_name = Dataset.gen_collection_name_by_id(dataset_id) - index_struct_dict = {"type": VectorType.MILVUS, "vector_store": {"class_prefix": collection_name}} - dataset.index_struct = json.dumps(index_struct_dict) - elif vector_type == VectorType.RELYT: - dataset_id = dataset.id - collection_name = Dataset.gen_collection_name_by_id(dataset_id) - index_struct_dict = {"type": "relyt", "vector_store": {"class_prefix": collection_name}} - dataset.index_struct = json.dumps(index_struct_dict) - elif vector_type == VectorType.TENCENT: - dataset_id = dataset.id - collection_name = Dataset.gen_collection_name_by_id(dataset_id) - index_struct_dict = {"type": VectorType.TENCENT, "vector_store": {"class_prefix": collection_name}} - dataset.index_struct = json.dumps(index_struct_dict) - elif vector_type == VectorType.PGVECTOR: - dataset_id = dataset.id - collection_name = Dataset.gen_collection_name_by_id(dataset_id) - index_struct_dict = {"type": VectorType.PGVECTOR, "vector_store": {"class_prefix": collection_name}} - dataset.index_struct = json.dumps(index_struct_dict) - elif vector_type == VectorType.OPENSEARCH: - dataset_id = dataset.id - collection_name = Dataset.gen_collection_name_by_id(dataset_id) - index_struct_dict = { - "type": VectorType.OPENSEARCH, - "vector_store": {"class_prefix": collection_name}, - } - dataset.index_struct = json.dumps(index_struct_dict) - elif vector_type == VectorType.ANALYTICDB: - dataset_id = dataset.id - collection_name = Dataset.gen_collection_name_by_id(dataset_id) - index_struct_dict = { - "type": VectorType.ANALYTICDB, - "vector_store": {"class_prefix": 
collection_name}, - } - dataset.index_struct = json.dumps(index_struct_dict) - elif vector_type == VectorType.ELASTICSEARCH: - dataset_id = dataset.id - index_name = Dataset.gen_collection_name_by_id(dataset_id) - index_struct_dict = {"type": "elasticsearch", "vector_store": {"class_prefix": index_name}} - dataset.index_struct = json.dumps(index_struct_dict) - elif vector_type == VectorType.BAIDU: - dataset_id = dataset.id - collection_name = Dataset.gen_collection_name_by_id(dataset_id) - index_struct_dict = { - "type": VectorType.BAIDU, - "vector_store": {"class_prefix": collection_name}, - } - dataset.index_struct = json.dumps(index_struct_dict) + elif vector_type in lower_collection_vector_types: + collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower() else: raise ValueError(f"Vector store {vector_type} is not supported.") + index_struct_dict = {"type": vector_type, "vector_store": {"class_prefix": collection_name}} + dataset.index_struct = json.dumps(index_struct_dict) vector = Vector(dataset) click.echo(f"Migrating dataset {dataset.id}.") diff --git a/api/core/embedding/embedding_constant.py b/api/core/entities/embedding_type.py similarity index 100% rename from api/core/embedding/embedding_constant.py rename to api/core/entities/embedding_type.py diff --git a/api/core/model_manager.py b/api/core/model_manager.py index 183fe4a0dd..e21449ec24 100644 --- a/api/core/model_manager.py +++ b/api/core/model_manager.py @@ -3,7 +3,7 @@ import os from collections.abc import Callable, Generator, Iterable, Sequence from typing import IO, Any, Optional, Union, cast -from core.embedding.embedding_constant import EmbeddingInputType +from core.entities.embedding_type import EmbeddingInputType from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle from core.entities.provider_entities import ModelLoadBalancingConfiguration from core.errors.error import ProviderTokenNotInitError diff --git 
a/api/core/model_runtime/model_providers/__base/text_embedding_model.py b/api/core/model_runtime/model_providers/__base/text_embedding_model.py index a948dca20d..2d38fba955 100644 --- a/api/core/model_runtime/model_providers/__base/text_embedding_model.py +++ b/api/core/model_runtime/model_providers/__base/text_embedding_model.py @@ -4,7 +4,7 @@ from typing import Optional from pydantic import ConfigDict -from core.embedding.embedding_constant import EmbeddingInputType +from core.entities.embedding_type import EmbeddingInputType from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult from core.model_runtime.model_providers.__base.ai_model import AIModel diff --git a/api/core/model_runtime/model_providers/azure_openai/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/azure_openai/text_embedding/text_embedding.py index 8701a38050..c45ce87ea7 100644 --- a/api/core/model_runtime/model_providers/azure_openai/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/azure_openai/text_embedding/text_embedding.py @@ -7,7 +7,7 @@ import numpy as np import tiktoken from openai import AzureOpenAI -from core.embedding.embedding_constant import EmbeddingInputType +from core.entities.embedding_type import EmbeddingInputType from core.model_runtime.entities.model_entities import AIModelEntity, PriceType from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult from core.model_runtime.errors.validate import CredentialsValidateFailedError diff --git a/api/core/model_runtime/model_providers/baichuan/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/baichuan/text_embedding/text_embedding.py index 56b9be1c36..1ace68d2b9 100644 --- a/api/core/model_runtime/model_providers/baichuan/text_embedding/text_embedding.py +++ 
b/api/core/model_runtime/model_providers/baichuan/text_embedding/text_embedding.py @@ -4,7 +4,7 @@ from typing import Optional from requests import post -from core.embedding.embedding_constant import EmbeddingInputType +from core.entities.embedding_type import EmbeddingInputType from core.model_runtime.entities.model_entities import PriceType from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult from core.model_runtime.errors.invoke import ( diff --git a/api/core/model_runtime/model_providers/bedrock/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/bedrock/text_embedding/text_embedding.py index d9c5726592..2f998d8bda 100644 --- a/api/core/model_runtime/model_providers/bedrock/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/bedrock/text_embedding/text_embedding.py @@ -13,7 +13,7 @@ from botocore.exceptions import ( UnknownServiceError, ) -from core.embedding.embedding_constant import EmbeddingInputType +from core.entities.embedding_type import EmbeddingInputType from core.model_runtime.entities.model_entities import PriceType from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult from core.model_runtime.errors.invoke import ( diff --git a/api/core/model_runtime/model_providers/cohere/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/cohere/text_embedding/text_embedding.py index 4da2080690..5fd4d637be 100644 --- a/api/core/model_runtime/model_providers/cohere/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/cohere/text_embedding/text_embedding.py @@ -5,7 +5,7 @@ import cohere import numpy as np from cohere.core import RequestOptions -from core.embedding.embedding_constant import EmbeddingInputType +from core.entities.embedding_type import EmbeddingInputType from core.model_runtime.entities.model_entities import PriceType from 
core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult from core.model_runtime.errors.invoke import ( diff --git a/api/core/model_runtime/model_providers/fireworks/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/fireworks/text_embedding/text_embedding.py index cdce69ff38..c745a7e978 100644 --- a/api/core/model_runtime/model_providers/fireworks/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/fireworks/text_embedding/text_embedding.py @@ -5,7 +5,7 @@ from typing import Optional, Union import numpy as np from openai import OpenAI -from core.embedding.embedding_constant import EmbeddingInputType +from core.entities.embedding_type import EmbeddingInputType from core.model_runtime.entities.model_entities import PriceType from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult from core.model_runtime.errors.validate import CredentialsValidateFailedError diff --git a/api/core/model_runtime/model_providers/groq/groq.yaml b/api/core/model_runtime/model_providers/groq/groq.yaml index db17cc8bdd..d6534e1bf1 100644 --- a/api/core/model_runtime/model_providers/groq/groq.yaml +++ b/api/core/model_runtime/model_providers/groq/groq.yaml @@ -18,6 +18,7 @@ help: en_US: https://console.groq.com/ supported_model_types: - llm + - speech2text configurate_methods: - predefined-model provider_credential_schema: diff --git a/api/core/model_runtime/model_providers/groq/llm/llama-3.2-11b-vision-preview.yaml b/api/core/model_runtime/model_providers/groq/llm/llama-3.2-11b-vision-preview.yaml new file mode 100644 index 0000000000..5632218797 --- /dev/null +++ b/api/core/model_runtime/model_providers/groq/llm/llama-3.2-11b-vision-preview.yaml @@ -0,0 +1,26 @@ +model: llama-3.2-11b-vision-preview +label: + zh_Hans: Llama 3.2 11B Vision (Preview) + en_US: Llama 3.2 11B Vision (Preview) +model_type: llm +features: + - agent-thought + - vision +model_properties: 
+ mode: chat + context_size: 131072 +parameter_rules: + - name: temperature + use_template: temperature + - name: top_p + use_template: top_p + - name: max_tokens + use_template: max_tokens + default: 512 + min: 1 + max: 8192 +pricing: + input: '0.05' + output: '0.1' + unit: '0.000001' + currency: USD diff --git a/api/core/model_runtime/model_providers/groq/llm/llama-3.2-90b-vision-preview.yaml b/api/core/model_runtime/model_providers/groq/llm/llama-3.2-90b-vision-preview.yaml new file mode 100644 index 0000000000..e7b93101e8 --- /dev/null +++ b/api/core/model_runtime/model_providers/groq/llm/llama-3.2-90b-vision-preview.yaml @@ -0,0 +1,26 @@ +model: llama-3.2-90b-vision-preview +label: + zh_Hans: Llama 3.2 90B Vision (Preview) + en_US: Llama 3.2 90B Vision (Preview) +model_type: llm +features: + - agent-thought + - vision +model_properties: + mode: chat + context_size: 131072 +parameter_rules: + - name: temperature + use_template: temperature + - name: top_p + use_template: top_p + - name: max_tokens + use_template: max_tokens + default: 512 + min: 1 + max: 8192 +pricing: + input: '0.05' + output: '0.1' + unit: '0.000001' + currency: USD diff --git a/api/core/model_runtime/model_providers/groq/speech2text/__init__.py b/api/core/model_runtime/model_providers/groq/speech2text/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/model_runtime/model_providers/groq/speech2text/distil-whisper-large-v3-en.yaml b/api/core/model_runtime/model_providers/groq/speech2text/distil-whisper-large-v3-en.yaml new file mode 100644 index 0000000000..202d006a66 --- /dev/null +++ b/api/core/model_runtime/model_providers/groq/speech2text/distil-whisper-large-v3-en.yaml @@ -0,0 +1,5 @@ +model: distil-whisper-large-v3-en +model_type: speech2text +model_properties: + file_upload_limit: 1 + supported_file_extensions: flac,mp3,mp4,mpeg,mpga,m4a,ogg,wav,webm diff --git a/api/core/model_runtime/model_providers/groq/speech2text/speech2text.py 
b/api/core/model_runtime/model_providers/groq/speech2text/speech2text.py new file mode 100644 index 0000000000..75feeb9cb9 --- /dev/null +++ b/api/core/model_runtime/model_providers/groq/speech2text/speech2text.py @@ -0,0 +1,30 @@ +from typing import IO, Optional + +from core.model_runtime.model_providers.openai_api_compatible.speech2text.speech2text import OAICompatSpeech2TextModel + + +class GroqSpeech2TextModel(OAICompatSpeech2TextModel): + """ + Model class for Groq Speech to text model. + """ + + def _invoke(self, model: str, credentials: dict, file: IO[bytes], user: Optional[str] = None) -> str: + """ + Invoke speech2text model + + :param model: model name + :param credentials: model credentials + :param file: audio file + :param user: unique user id + :return: text for given audio file + """ + self._add_custom_parameters(credentials) + return super()._invoke(model, credentials, file) + + def validate_credentials(self, model: str, credentials: dict) -> None: + self._add_custom_parameters(credentials) + return super().validate_credentials(model, credentials) + + @classmethod + def _add_custom_parameters(cls, credentials: dict) -> None: + credentials["endpoint_url"] = "https://api.groq.com/openai/v1" diff --git a/api/core/model_runtime/model_providers/groq/speech2text/whisper-large-v3-turbo.yaml b/api/core/model_runtime/model_providers/groq/speech2text/whisper-large-v3-turbo.yaml new file mode 100644 index 0000000000..3882a3f4f2 --- /dev/null +++ b/api/core/model_runtime/model_providers/groq/speech2text/whisper-large-v3-turbo.yaml @@ -0,0 +1,5 @@ +model: whisper-large-v3-turbo +model_type: speech2text +model_properties: + file_upload_limit: 1 + supported_file_extensions: flac,mp3,mp4,mpeg,mpga,m4a,ogg,wav,webm diff --git a/api/core/model_runtime/model_providers/groq/speech2text/whisper-large-v3.yaml b/api/core/model_runtime/model_providers/groq/speech2text/whisper-large-v3.yaml new file mode 100644 index 0000000000..ed02477d70 --- /dev/null +++ 
b/api/core/model_runtime/model_providers/groq/speech2text/whisper-large-v3.yaml @@ -0,0 +1,5 @@ +model: whisper-large-v3 +model_type: speech2text +model_properties: + file_upload_limit: 1 + supported_file_extensions: flac,mp3,mp4,mpeg,mpga,m4a,ogg,wav,webm diff --git a/api/core/model_runtime/model_providers/huggingface_hub/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/huggingface_hub/text_embedding/text_embedding.py index b2e6d1b652..8278d1e64d 100644 --- a/api/core/model_runtime/model_providers/huggingface_hub/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/huggingface_hub/text_embedding/text_embedding.py @@ -6,7 +6,7 @@ import numpy as np import requests from huggingface_hub import HfApi, InferenceClient -from core.embedding.embedding_constant import EmbeddingInputType +from core.entities.embedding_type import EmbeddingInputType from core.model_runtime.entities.common_entities import I18nObject from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType, PriceType from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult diff --git a/api/core/model_runtime/model_providers/huggingface_tei/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/huggingface_tei/text_embedding/text_embedding.py index b8ff3ca549..6b43934538 100644 --- a/api/core/model_runtime/model_providers/huggingface_tei/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/huggingface_tei/text_embedding/text_embedding.py @@ -1,7 +1,7 @@ import time from typing import Optional -from core.embedding.embedding_constant import EmbeddingInputType +from core.entities.embedding_type import EmbeddingInputType from core.model_runtime.entities.common_entities import I18nObject from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType, PriceType from 
core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult diff --git a/api/core/model_runtime/model_providers/hunyuan/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/hunyuan/text_embedding/text_embedding.py index 75701ebc54..b6d857cb37 100644 --- a/api/core/model_runtime/model_providers/hunyuan/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/hunyuan/text_embedding/text_embedding.py @@ -9,7 +9,7 @@ from tencentcloud.common.profile.client_profile import ClientProfile from tencentcloud.common.profile.http_profile import HttpProfile from tencentcloud.hunyuan.v20230901 import hunyuan_client, models -from core.embedding.embedding_constant import EmbeddingInputType +from core.entities.embedding_type import EmbeddingInputType from core.model_runtime.entities.model_entities import PriceType from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult from core.model_runtime.errors.invoke import ( diff --git a/api/core/model_runtime/model_providers/jina/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/jina/text_embedding/text_embedding.py index b397129512..49c558f4a4 100644 --- a/api/core/model_runtime/model_providers/jina/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/jina/text_embedding/text_embedding.py @@ -4,7 +4,7 @@ from typing import Optional from requests import post -from core.embedding.embedding_constant import EmbeddingInputType +from core.entities.embedding_type import EmbeddingInputType from core.model_runtime.entities.common_entities import I18nObject from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType, PriceType from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult diff --git a/api/core/model_runtime/model_providers/localai/text_embedding/text_embedding.py 
b/api/core/model_runtime/model_providers/localai/text_embedding/text_embedding.py index ab8ca76c2f..b4dfc1a4de 100644 --- a/api/core/model_runtime/model_providers/localai/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/localai/text_embedding/text_embedding.py @@ -5,7 +5,7 @@ from typing import Optional from requests import post from yarl import URL -from core.embedding.embedding_constant import EmbeddingInputType +from core.entities.embedding_type import EmbeddingInputType from core.model_runtime.entities.common_entities import I18nObject from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType, PriceType from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult diff --git a/api/core/model_runtime/model_providers/minimax/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/minimax/text_embedding/text_embedding.py index d031bfa04d..29be5888af 100644 --- a/api/core/model_runtime/model_providers/minimax/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/minimax/text_embedding/text_embedding.py @@ -4,7 +4,7 @@ from typing import Optional from requests import post -from core.embedding.embedding_constant import EmbeddingInputType +from core.entities.embedding_type import EmbeddingInputType from core.model_runtime.entities.model_entities import PriceType from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult from core.model_runtime.errors.invoke import ( diff --git a/api/core/model_runtime/model_providers/mixedbread/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/mixedbread/text_embedding/text_embedding.py index 68b7b448bf..ca949cb953 100644 --- a/api/core/model_runtime/model_providers/mixedbread/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/mixedbread/text_embedding/text_embedding.py @@ -4,7 +4,7 @@ 
from typing import Optional import requests -from core.embedding.embedding_constant import EmbeddingInputType +from core.entities.embedding_type import EmbeddingInputType from core.model_runtime.entities.common_entities import I18nObject from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType, PriceType from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult diff --git a/api/core/model_runtime/model_providers/nomic/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/nomic/text_embedding/text_embedding.py index 857dfb5f41..56a707333c 100644 --- a/api/core/model_runtime/model_providers/nomic/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/nomic/text_embedding/text_embedding.py @@ -5,7 +5,7 @@ from typing import Optional from nomic import embed from nomic import login as nomic_login -from core.embedding.embedding_constant import EmbeddingInputType +from core.entities.embedding_type import EmbeddingInputType from core.model_runtime.entities.model_entities import PriceType from core.model_runtime.entities.text_embedding_entities import ( EmbeddingUsage, diff --git a/api/core/model_runtime/model_providers/nvidia/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/nvidia/text_embedding/text_embedding.py index 936ceb8dd2..04363e11be 100644 --- a/api/core/model_runtime/model_providers/nvidia/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/nvidia/text_embedding/text_embedding.py @@ -4,7 +4,7 @@ from typing import Optional from requests import post -from core.embedding.embedding_constant import EmbeddingInputType +from core.entities.embedding_type import EmbeddingInputType from core.model_runtime.entities.model_entities import PriceType from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult from core.model_runtime.errors.invoke import 
( diff --git a/api/core/model_runtime/model_providers/oci/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/oci/text_embedding/text_embedding.py index 4de9296cca..50fa63768c 100644 --- a/api/core/model_runtime/model_providers/oci/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/oci/text_embedding/text_embedding.py @@ -6,7 +6,7 @@ from typing import Optional import numpy as np import oci -from core.embedding.embedding_constant import EmbeddingInputType +from core.entities.embedding_type import EmbeddingInputType from core.model_runtime.entities.model_entities import PriceType from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult from core.model_runtime.errors.invoke import ( diff --git a/api/core/model_runtime/model_providers/ollama/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/ollama/text_embedding/text_embedding.py index 5cf3f1c6fa..a16c91cd7e 100644 --- a/api/core/model_runtime/model_providers/ollama/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/ollama/text_embedding/text_embedding.py @@ -8,7 +8,7 @@ from urllib.parse import urljoin import numpy as np import requests -from core.embedding.embedding_constant import EmbeddingInputType +from core.entities.embedding_type import EmbeddingInputType from core.model_runtime.entities.common_entities import I18nObject from core.model_runtime.entities.model_entities import ( AIModelEntity, diff --git a/api/core/model_runtime/model_providers/openai/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/openai/text_embedding/text_embedding.py index 16f1a0cfa1..bec01fe679 100644 --- a/api/core/model_runtime/model_providers/openai/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/openai/text_embedding/text_embedding.py @@ -6,7 +6,7 @@ import numpy as np import tiktoken from openai import OpenAI -from 
core.embedding.embedding_constant import EmbeddingInputType +from core.entities.embedding_type import EmbeddingInputType from core.model_runtime.entities.model_entities import PriceType from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult from core.model_runtime.errors.validate import CredentialsValidateFailedError diff --git a/api/core/model_runtime/model_providers/openai_api_compatible/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/openai_api_compatible/text_embedding/text_embedding.py index 64fa6aaa3c..c2b7297aac 100644 --- a/api/core/model_runtime/model_providers/openai_api_compatible/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/openai_api_compatible/text_embedding/text_embedding.py @@ -7,7 +7,7 @@ from urllib.parse import urljoin import numpy as np import requests -from core.embedding.embedding_constant import EmbeddingInputType +from core.entities.embedding_type import EmbeddingInputType from core.model_runtime.entities.common_entities import I18nObject from core.model_runtime.entities.model_entities import ( AIModelEntity, diff --git a/api/core/model_runtime/model_providers/openllm/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/openllm/text_embedding/text_embedding.py index c5d4330912..43a2e948e2 100644 --- a/api/core/model_runtime/model_providers/openllm/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/openllm/text_embedding/text_embedding.py @@ -5,7 +5,7 @@ from typing import Optional from requests import post from requests.exceptions import ConnectionError, InvalidSchema, MissingSchema -from core.embedding.embedding_constant import EmbeddingInputType +from core.entities.embedding_type import EmbeddingInputType from core.model_runtime.entities.model_entities import PriceType from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult from 
core.model_runtime.errors.invoke import ( diff --git a/api/core/model_runtime/model_providers/openrouter/llm/deepseek-chat.yaml b/api/core/model_runtime/model_providers/openrouter/llm/deepseek-chat.yaml index 7a1dea6950..6743bfcad6 100644 --- a/api/core/model_runtime/model_providers/openrouter/llm/deepseek-chat.yaml +++ b/api/core/model_runtime/model_providers/openrouter/llm/deepseek-chat.yaml @@ -35,6 +35,15 @@ parameter_rules: help: zh_Hans: 控制生成结果的随机性。数值越小,随机性越弱;数值越大,随机性越强。一般而言,top_p 和 temperature 两个参数选择一个进行调整即可。 en_US: Control the randomness of generated results. The smaller the value, the weaker the randomness; the larger the value, the stronger the randomness. Generally speaking, you can adjust one of the two parameters top_p and temperature. + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. + required: false - name: frequency_penalty use_template: frequency_penalty default: 0 diff --git a/api/core/model_runtime/model_providers/openrouter/llm/deepseek-coder.yaml b/api/core/model_runtime/model_providers/openrouter/llm/deepseek-coder.yaml index c05f4769b8..375a4d2d52 100644 --- a/api/core/model_runtime/model_providers/openrouter/llm/deepseek-coder.yaml +++ b/api/core/model_runtime/model_providers/openrouter/llm/deepseek-coder.yaml @@ -18,6 +18,15 @@ parameter_rules: min: 0 max: 1 default: 1 + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. 
+ required: false - name: max_tokens use_template: max_tokens min: 1 diff --git a/api/core/model_runtime/model_providers/openrouter/llm/gpt-3.5-turbo.yaml b/api/core/model_runtime/model_providers/openrouter/llm/gpt-3.5-turbo.yaml index 186c1cc663..621ecf065e 100644 --- a/api/core/model_runtime/model_providers/openrouter/llm/gpt-3.5-turbo.yaml +++ b/api/core/model_runtime/model_providers/openrouter/llm/gpt-3.5-turbo.yaml @@ -14,6 +14,15 @@ parameter_rules: use_template: temperature - name: top_p use_template: top_p + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. + required: false - name: presence_penalty use_template: presence_penalty - name: frequency_penalty diff --git a/api/core/model_runtime/model_providers/openrouter/llm/gpt-4-32k.yaml b/api/core/model_runtime/model_providers/openrouter/llm/gpt-4-32k.yaml index 8c2989b300..887e6d60f9 100644 --- a/api/core/model_runtime/model_providers/openrouter/llm/gpt-4-32k.yaml +++ b/api/core/model_runtime/model_providers/openrouter/llm/gpt-4-32k.yaml @@ -14,6 +14,15 @@ parameter_rules: use_template: temperature - name: top_p use_template: top_p + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. 
+ required: false - name: presence_penalty use_template: presence_penalty - name: frequency_penalty diff --git a/api/core/model_runtime/model_providers/openrouter/llm/gpt-4.yaml b/api/core/model_runtime/model_providers/openrouter/llm/gpt-4.yaml index ef19d4f6f0..66d1f9ae67 100644 --- a/api/core/model_runtime/model_providers/openrouter/llm/gpt-4.yaml +++ b/api/core/model_runtime/model_providers/openrouter/llm/gpt-4.yaml @@ -14,6 +14,15 @@ parameter_rules: use_template: temperature - name: top_p use_template: top_p + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. + required: false - name: presence_penalty use_template: presence_penalty - name: frequency_penalty diff --git a/api/core/model_runtime/model_providers/openrouter/llm/gpt-4o-2024-08-06.yaml b/api/core/model_runtime/model_providers/openrouter/llm/gpt-4o-2024-08-06.yaml index 0be325f55b..695cc3eedf 100644 --- a/api/core/model_runtime/model_providers/openrouter/llm/gpt-4o-2024-08-06.yaml +++ b/api/core/model_runtime/model_providers/openrouter/llm/gpt-4o-2024-08-06.yaml @@ -16,6 +16,15 @@ parameter_rules: use_template: temperature - name: top_p use_template: top_p + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. 
+ required: false - name: presence_penalty use_template: presence_penalty - name: frequency_penalty diff --git a/api/core/model_runtime/model_providers/openrouter/llm/gpt-4o-mini.yaml b/api/core/model_runtime/model_providers/openrouter/llm/gpt-4o-mini.yaml index 3b1d95643d..e1e5889085 100644 --- a/api/core/model_runtime/model_providers/openrouter/llm/gpt-4o-mini.yaml +++ b/api/core/model_runtime/model_providers/openrouter/llm/gpt-4o-mini.yaml @@ -15,6 +15,15 @@ parameter_rules: use_template: temperature - name: top_p use_template: top_p + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. + required: false - name: presence_penalty use_template: presence_penalty - name: frequency_penalty diff --git a/api/core/model_runtime/model_providers/openrouter/llm/gpt-4o.yaml b/api/core/model_runtime/model_providers/openrouter/llm/gpt-4o.yaml index a8c97efdd6..560bf9d7d0 100644 --- a/api/core/model_runtime/model_providers/openrouter/llm/gpt-4o.yaml +++ b/api/core/model_runtime/model_providers/openrouter/llm/gpt-4o.yaml @@ -15,6 +15,15 @@ parameter_rules: use_template: temperature - name: top_p use_template: top_p + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. 
+ required: false - name: presence_penalty use_template: presence_penalty - name: frequency_penalty diff --git a/api/core/model_runtime/model_providers/openrouter/llm/llama-3-70b-instruct.yaml b/api/core/model_runtime/model_providers/openrouter/llm/llama-3-70b-instruct.yaml index b91c39e729..04a4a90c6d 100644 --- a/api/core/model_runtime/model_providers/openrouter/llm/llama-3-70b-instruct.yaml +++ b/api/core/model_runtime/model_providers/openrouter/llm/llama-3-70b-instruct.yaml @@ -10,6 +10,15 @@ parameter_rules: use_template: temperature - name: top_p use_template: top_p + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. + required: false - name: max_tokens use_template: max_tokens required: true diff --git a/api/core/model_runtime/model_providers/openrouter/llm/llama-3-8b-instruct.yaml b/api/core/model_runtime/model_providers/openrouter/llm/llama-3-8b-instruct.yaml index 84b2c7fac2..066949d431 100644 --- a/api/core/model_runtime/model_providers/openrouter/llm/llama-3-8b-instruct.yaml +++ b/api/core/model_runtime/model_providers/openrouter/llm/llama-3-8b-instruct.yaml @@ -10,6 +10,15 @@ parameter_rules: use_template: temperature - name: top_p use_template: top_p + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. 
+ required: false - name: max_tokens use_template: max_tokens required: true diff --git a/api/core/model_runtime/model_providers/openrouter/llm/llama-3.1-405b-instruct.yaml b/api/core/model_runtime/model_providers/openrouter/llm/llama-3.1-405b-instruct.yaml index a489ce1b5a..0cd89dea71 100644 --- a/api/core/model_runtime/model_providers/openrouter/llm/llama-3.1-405b-instruct.yaml +++ b/api/core/model_runtime/model_providers/openrouter/llm/llama-3.1-405b-instruct.yaml @@ -10,6 +10,15 @@ parameter_rules: use_template: temperature - name: top_p use_template: top_p + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. + required: false - name: max_tokens use_template: max_tokens required: true diff --git a/api/core/model_runtime/model_providers/openrouter/llm/llama-3.1-70b-instruct.yaml b/api/core/model_runtime/model_providers/openrouter/llm/llama-3.1-70b-instruct.yaml index 12037411b1..768ab5ecbb 100644 --- a/api/core/model_runtime/model_providers/openrouter/llm/llama-3.1-70b-instruct.yaml +++ b/api/core/model_runtime/model_providers/openrouter/llm/llama-3.1-70b-instruct.yaml @@ -10,6 +10,15 @@ parameter_rules: use_template: temperature - name: top_p use_template: top_p + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. 
+ required: false - name: max_tokens use_template: max_tokens required: true diff --git a/api/core/model_runtime/model_providers/openrouter/llm/llama-3.1-8b-instruct.yaml b/api/core/model_runtime/model_providers/openrouter/llm/llama-3.1-8b-instruct.yaml index 6f06493f29..67b6b82b5d 100644 --- a/api/core/model_runtime/model_providers/openrouter/llm/llama-3.1-8b-instruct.yaml +++ b/api/core/model_runtime/model_providers/openrouter/llm/llama-3.1-8b-instruct.yaml @@ -10,6 +10,15 @@ parameter_rules: use_template: temperature - name: top_p use_template: top_p + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. + required: false - name: max_tokens use_template: max_tokens required: true diff --git a/api/core/model_runtime/model_providers/openrouter/llm/mistral-7b-instruct.yaml b/api/core/model_runtime/model_providers/openrouter/llm/mistral-7b-instruct.yaml index 012dfc55ce..d08c016e95 100644 --- a/api/core/model_runtime/model_providers/openrouter/llm/mistral-7b-instruct.yaml +++ b/api/core/model_runtime/model_providers/openrouter/llm/mistral-7b-instruct.yaml @@ -18,6 +18,15 @@ parameter_rules: default: 1 min: 0 max: 1 + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. 
+ required: false - name: max_tokens use_template: max_tokens default: 1024 diff --git a/api/core/model_runtime/model_providers/openrouter/llm/mixtral-8x22b-instruct.yaml b/api/core/model_runtime/model_providers/openrouter/llm/mixtral-8x22b-instruct.yaml index f4eb4e45d9..e3af0e64d8 100644 --- a/api/core/model_runtime/model_providers/openrouter/llm/mixtral-8x22b-instruct.yaml +++ b/api/core/model_runtime/model_providers/openrouter/llm/mixtral-8x22b-instruct.yaml @@ -18,6 +18,15 @@ parameter_rules: default: 1 min: 0 max: 1 + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. + required: false - name: max_tokens use_template: max_tokens default: 1024 diff --git a/api/core/model_runtime/model_providers/openrouter/llm/mixtral-8x7b-instruct.yaml b/api/core/model_runtime/model_providers/openrouter/llm/mixtral-8x7b-instruct.yaml index 7871e1f7a0..095ea5a858 100644 --- a/api/core/model_runtime/model_providers/openrouter/llm/mixtral-8x7b-instruct.yaml +++ b/api/core/model_runtime/model_providers/openrouter/llm/mixtral-8x7b-instruct.yaml @@ -19,6 +19,15 @@ parameter_rules: default: 1 min: 0 max: 1 + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. 
+ required: false - name: max_tokens use_template: max_tokens default: 1024 diff --git a/api/core/model_runtime/model_providers/openrouter/llm/o1-mini.yaml b/api/core/model_runtime/model_providers/openrouter/llm/o1-mini.yaml index 85a918ff5e..f4202ee814 100644 --- a/api/core/model_runtime/model_providers/openrouter/llm/o1-mini.yaml +++ b/api/core/model_runtime/model_providers/openrouter/llm/o1-mini.yaml @@ -12,6 +12,15 @@ parameter_rules: use_template: temperature - name: top_p use_template: top_p + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. + required: false - name: presence_penalty use_template: presence_penalty - name: frequency_penalty diff --git a/api/core/model_runtime/model_providers/openrouter/llm/o1-preview.yaml b/api/core/model_runtime/model_providers/openrouter/llm/o1-preview.yaml index 74b0a511be..1281b84286 100644 --- a/api/core/model_runtime/model_providers/openrouter/llm/o1-preview.yaml +++ b/api/core/model_runtime/model_providers/openrouter/llm/o1-preview.yaml @@ -12,6 +12,15 @@ parameter_rules: use_template: temperature - name: top_p use_template: top_p + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. + required: false - name: presence_penalty use_template: presence_penalty - name: frequency_penalty diff --git a/api/core/model_runtime/model_providers/openrouter/llm/qwen2-72b-instruct.yaml b/api/core/model_runtime/model_providers/openrouter/llm/qwen2-72b-instruct.yaml index 7b75fcb0c9..b6058138d3 100644 --- a/api/core/model_runtime/model_providers/openrouter/llm/qwen2-72b-instruct.yaml +++ b/api/core/model_runtime/model_providers/openrouter/llm/qwen2-72b-instruct.yaml @@ -21,6 +21,15 @@ parameter_rules: en_US: Specifies the upper limit on the length of generated results. 
If the generated results are truncated, you can increase this parameter. - name: top_p use_template: top_p + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. + required: false - name: frequency_penalty use_template: frequency_penalty pricing: diff --git a/api/core/model_runtime/model_providers/openrouter/llm/qwen2.5-72b-instruct.yaml b/api/core/model_runtime/model_providers/openrouter/llm/qwen2.5-72b-instruct.yaml index f141a40a00..5392b11168 100644 --- a/api/core/model_runtime/model_providers/openrouter/llm/qwen2.5-72b-instruct.yaml +++ b/api/core/model_runtime/model_providers/openrouter/llm/qwen2.5-72b-instruct.yaml @@ -21,6 +21,15 @@ parameter_rules: en_US: Specifies the upper limit on the length of generated results. If the generated results are truncated, you can increase this parameter. - name: top_p use_template: top_p + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. 
+ required: false - name: frequency_penalty use_template: frequency_penalty pricing: diff --git a/api/core/model_runtime/model_providers/perfxcloud/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/perfxcloud/text_embedding/text_embedding.py index 1e86f351c8..d78bdaa75e 100644 --- a/api/core/model_runtime/model_providers/perfxcloud/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/perfxcloud/text_embedding/text_embedding.py @@ -7,7 +7,7 @@ from urllib.parse import urljoin import numpy as np import requests -from core.embedding.embedding_constant import EmbeddingInputType +from core.entities.embedding_type import EmbeddingInputType from core.model_runtime.entities.common_entities import I18nObject from core.model_runtime.entities.model_entities import ( AIModelEntity, diff --git a/api/core/model_runtime/model_providers/replicate/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/replicate/text_embedding/text_embedding.py index 9f724a77ac..c4e9d0b9c6 100644 --- a/api/core/model_runtime/model_providers/replicate/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/replicate/text_embedding/text_embedding.py @@ -4,7 +4,7 @@ from typing import Optional from replicate import Client as ReplicateClient -from core.embedding.embedding_constant import EmbeddingInputType +from core.entities.embedding_type import EmbeddingInputType from core.model_runtime.entities.common_entities import I18nObject from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType, PriceType from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult diff --git a/api/core/model_runtime/model_providers/sagemaker/speech2text/speech2text.py b/api/core/model_runtime/model_providers/sagemaker/speech2text/speech2text.py index 6aa8c9995f..94bae71e53 100644 --- a/api/core/model_runtime/model_providers/sagemaker/speech2text/speech2text.py +++ 
b/api/core/model_runtime/model_providers/sagemaker/speech2text/speech2text.py @@ -14,6 +14,7 @@ from core.model_runtime.errors.invoke import ( InvokeRateLimitError, InvokeServerUnavailableError, ) +from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel from core.model_runtime.model_providers.sagemaker.sagemaker import generate_presigned_url @@ -77,7 +78,8 @@ class SageMakerSpeech2TextModel(Speech2TextModel): json_obj = json.loads(json_str) asr_text = json_obj["text"] except Exception as e: - logger.exception(f"Exception {e}, line : {line}") + logger.exception(f"failed to invoke speech2text model, {e}") + raise CredentialsValidateFailedError(str(e)) return asr_text diff --git a/api/core/model_runtime/model_providers/sagemaker/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/sagemaker/text_embedding/text_embedding.py index 8f993ce672..ae7d805b4e 100644 --- a/api/core/model_runtime/model_providers/sagemaker/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/sagemaker/text_embedding/text_embedding.py @@ -6,7 +6,7 @@ from typing import Any, Optional import boto3 -from core.embedding.embedding_constant import EmbeddingInputType +from core.entities.embedding_type import EmbeddingInputType from core.model_runtime.entities.common_entities import I18nObject from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType, PriceType from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult diff --git a/api/core/model_runtime/model_providers/siliconflow/llm/deepdeek-coder-v2-instruct.yaml b/api/core/model_runtime/model_providers/siliconflow/llm/deepdeek-coder-v2-instruct.yaml index d4431179e5..d5f23776ea 100644 --- a/api/core/model_runtime/model_providers/siliconflow/llm/deepdeek-coder-v2-instruct.yaml +++ 
b/api/core/model_runtime/model_providers/siliconflow/llm/deepdeek-coder-v2-instruct.yaml @@ -21,6 +21,15 @@ parameter_rules: en_US: Specifies the upper limit on the length of generated results. If the generated results are truncated, you can increase this parameter. - name: top_p use_template: top_p + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. + required: false - name: frequency_penalty use_template: frequency_penalty pricing: diff --git a/api/core/model_runtime/model_providers/siliconflow/llm/deepseek-v2-chat.yaml b/api/core/model_runtime/model_providers/siliconflow/llm/deepseek-v2-chat.yaml index caa6508b5e..7aa684ef38 100644 --- a/api/core/model_runtime/model_providers/siliconflow/llm/deepseek-v2-chat.yaml +++ b/api/core/model_runtime/model_providers/siliconflow/llm/deepseek-v2-chat.yaml @@ -21,6 +21,15 @@ parameter_rules: en_US: Specifies the upper limit on the length of generated results. If the generated results are truncated, you can increase this parameter. - name: top_p use_template: top_p + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. + required: false - name: frequency_penalty use_template: frequency_penalty pricing: diff --git a/api/core/model_runtime/model_providers/siliconflow/llm/deepseek-v2.5.yaml b/api/core/model_runtime/model_providers/siliconflow/llm/deepseek-v2.5.yaml index 1c8e15ae52..b30fa3e2d1 100644 --- a/api/core/model_runtime/model_providers/siliconflow/llm/deepseek-v2.5.yaml +++ b/api/core/model_runtime/model_providers/siliconflow/llm/deepseek-v2.5.yaml @@ -21,6 +21,15 @@ parameter_rules: en_US: Specifies the upper limit on the length of generated results. If the generated results are truncated, you can increase this parameter. 
- name: top_p use_template: top_p + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. + required: false - name: frequency_penalty use_template: frequency_penalty pricing: diff --git a/api/core/model_runtime/model_providers/siliconflow/llm/gemma-2-27b-it.yaml b/api/core/model_runtime/model_providers/siliconflow/llm/gemma-2-27b-it.yaml index 2840e3dcf4..f2a1f64bfb 100644 --- a/api/core/model_runtime/model_providers/siliconflow/llm/gemma-2-27b-it.yaml +++ b/api/core/model_runtime/model_providers/siliconflow/llm/gemma-2-27b-it.yaml @@ -21,6 +21,15 @@ parameter_rules: en_US: Specifies the upper limit on the length of generated results. If the generated results are truncated, you can increase this parameter. - name: top_p use_template: top_p + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. + required: false - name: frequency_penalty use_template: frequency_penalty pricing: diff --git a/api/core/model_runtime/model_providers/siliconflow/llm/gemma-2-9b-it.yaml b/api/core/model_runtime/model_providers/siliconflow/llm/gemma-2-9b-it.yaml index d7e19b46f6..b096b9b647 100644 --- a/api/core/model_runtime/model_providers/siliconflow/llm/gemma-2-9b-it.yaml +++ b/api/core/model_runtime/model_providers/siliconflow/llm/gemma-2-9b-it.yaml @@ -21,6 +21,15 @@ parameter_rules: en_US: Specifies the upper limit on the length of generated results. If the generated results are truncated, you can increase this parameter. - name: top_p use_template: top_p + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. 
+ required: false - name: frequency_penalty use_template: frequency_penalty pricing: diff --git a/api/core/model_runtime/model_providers/siliconflow/llm/glm4-9b-chat.yaml b/api/core/model_runtime/model_providers/siliconflow/llm/glm4-9b-chat.yaml index 9b32a02477..87acc557b7 100644 --- a/api/core/model_runtime/model_providers/siliconflow/llm/glm4-9b-chat.yaml +++ b/api/core/model_runtime/model_providers/siliconflow/llm/glm4-9b-chat.yaml @@ -21,6 +21,15 @@ parameter_rules: en_US: Specifies the upper limit on the length of generated results. If the generated results are truncated, you can increase this parameter. - name: top_p use_template: top_p + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. + required: false - name: frequency_penalty use_template: frequency_penalty pricing: diff --git a/api/core/model_runtime/model_providers/siliconflow/llm/internlm2_5-20b-chat.yaml b/api/core/model_runtime/model_providers/siliconflow/llm/internlm2_5-20b-chat.yaml index d9663582e5..60157c2b46 100644 --- a/api/core/model_runtime/model_providers/siliconflow/llm/internlm2_5-20b-chat.yaml +++ b/api/core/model_runtime/model_providers/siliconflow/llm/internlm2_5-20b-chat.yaml @@ -21,6 +21,15 @@ parameter_rules: en_US: Specifies the upper limit on the length of generated results. If the generated results are truncated, you can increase this parameter. - name: top_p use_template: top_p + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. 
+ required: false - name: frequency_penalty use_template: frequency_penalty pricing: diff --git a/api/core/model_runtime/model_providers/siliconflow/llm/internlm2_5-7b-chat.yaml b/api/core/model_runtime/model_providers/siliconflow/llm/internlm2_5-7b-chat.yaml index 73ad4480aa..faf4af7ea3 100644 --- a/api/core/model_runtime/model_providers/siliconflow/llm/internlm2_5-7b-chat.yaml +++ b/api/core/model_runtime/model_providers/siliconflow/llm/internlm2_5-7b-chat.yaml @@ -21,6 +21,15 @@ parameter_rules: en_US: Specifies the upper limit on the length of generated results. If the generated results are truncated, you can increase this parameter. - name: top_p use_template: top_p + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. + required: false - name: frequency_penalty use_template: frequency_penalty pricing: diff --git a/api/core/model_runtime/model_providers/siliconflow/llm/meta-mlama-3-70b-instruct.yaml b/api/core/model_runtime/model_providers/siliconflow/llm/meta-mlama-3-70b-instruct.yaml index 9993d781ac..d01770cb01 100644 --- a/api/core/model_runtime/model_providers/siliconflow/llm/meta-mlama-3-70b-instruct.yaml +++ b/api/core/model_runtime/model_providers/siliconflow/llm/meta-mlama-3-70b-instruct.yaml @@ -21,6 +21,15 @@ parameter_rules: en_US: Specifies the upper limit on the length of generated results. If the generated results are truncated, you can increase this parameter. - name: top_p use_template: top_p + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. 
+ required: false - name: frequency_penalty use_template: frequency_penalty pricing: diff --git a/api/core/model_runtime/model_providers/siliconflow/llm/meta-mlama-3-8b-instruct.yaml b/api/core/model_runtime/model_providers/siliconflow/llm/meta-mlama-3-8b-instruct.yaml index 60e3764789..3cd75d89e8 100644 --- a/api/core/model_runtime/model_providers/siliconflow/llm/meta-mlama-3-8b-instruct.yaml +++ b/api/core/model_runtime/model_providers/siliconflow/llm/meta-mlama-3-8b-instruct.yaml @@ -21,6 +21,15 @@ parameter_rules: en_US: Specifies the upper limit on the length of generated results. If the generated results are truncated, you can increase this parameter. - name: top_p use_template: top_p + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. + required: false - name: frequency_penalty use_template: frequency_penalty pricing: diff --git a/api/core/model_runtime/model_providers/siliconflow/llm/meta-mlama-3.1-405b-instruct.yaml b/api/core/model_runtime/model_providers/siliconflow/llm/meta-mlama-3.1-405b-instruct.yaml index f992660aa2..3506a70bcc 100644 --- a/api/core/model_runtime/model_providers/siliconflow/llm/meta-mlama-3.1-405b-instruct.yaml +++ b/api/core/model_runtime/model_providers/siliconflow/llm/meta-mlama-3.1-405b-instruct.yaml @@ -21,6 +21,15 @@ parameter_rules: en_US: Specifies the upper limit on the length of generated results. If the generated results are truncated, you can increase this parameter. - name: top_p use_template: top_p + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. 
+ required: false - name: frequency_penalty use_template: frequency_penalty pricing: diff --git a/api/core/model_runtime/model_providers/siliconflow/llm/meta-mlama-3.1-70b-instruct.yaml b/api/core/model_runtime/model_providers/siliconflow/llm/meta-mlama-3.1-70b-instruct.yaml index 1c69d63a40..994a754a82 100644 --- a/api/core/model_runtime/model_providers/siliconflow/llm/meta-mlama-3.1-70b-instruct.yaml +++ b/api/core/model_runtime/model_providers/siliconflow/llm/meta-mlama-3.1-70b-instruct.yaml @@ -21,6 +21,15 @@ parameter_rules: en_US: Specifies the upper limit on the length of generated results. If the generated results are truncated, you can increase this parameter. - name: top_p use_template: top_p + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. + required: false - name: frequency_penalty use_template: frequency_penalty pricing: diff --git a/api/core/model_runtime/model_providers/siliconflow/llm/meta-mlama-3.1-8b-instruct.yaml b/api/core/model_runtime/model_providers/siliconflow/llm/meta-mlama-3.1-8b-instruct.yaml index a97002a5ca..ebfa9aac9d 100644 --- a/api/core/model_runtime/model_providers/siliconflow/llm/meta-mlama-3.1-8b-instruct.yaml +++ b/api/core/model_runtime/model_providers/siliconflow/llm/meta-mlama-3.1-8b-instruct.yaml @@ -21,6 +21,15 @@ parameter_rules: en_US: Specifies the upper limit on the length of generated results. If the generated results are truncated, you can increase this parameter. - name: top_p use_template: top_p + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. 
+ required: false - name: frequency_penalty use_template: frequency_penalty pricing: diff --git a/api/core/model_runtime/model_providers/siliconflow/llm/mistral-7b-instruct-v0.2.yaml b/api/core/model_runtime/model_providers/siliconflow/llm/mistral-7b-instruct-v0.2.yaml index 89fb153ba0..a71d8688a8 100644 --- a/api/core/model_runtime/model_providers/siliconflow/llm/mistral-7b-instruct-v0.2.yaml +++ b/api/core/model_runtime/model_providers/siliconflow/llm/mistral-7b-instruct-v0.2.yaml @@ -21,6 +21,15 @@ parameter_rules: en_US: Specifies the upper limit on the length of generated results. If the generated results are truncated, you can increase this parameter. - name: top_p use_template: top_p + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. + required: false - name: frequency_penalty use_template: frequency_penalty pricing: diff --git a/api/core/model_runtime/model_providers/siliconflow/llm/mistral-8x7b-instruct-v0.1.yaml b/api/core/model_runtime/model_providers/siliconflow/llm/mistral-8x7b-instruct-v0.1.yaml index 2785e7496f..db45a75c6d 100644 --- a/api/core/model_runtime/model_providers/siliconflow/llm/mistral-8x7b-instruct-v0.1.yaml +++ b/api/core/model_runtime/model_providers/siliconflow/llm/mistral-8x7b-instruct-v0.1.yaml @@ -21,6 +21,15 @@ parameter_rules: en_US: Specifies the upper limit on the length of generated results. If the generated results are truncated, you can increase this parameter. - name: top_p use_template: top_p + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. 
+ required: false - name: frequency_penalty use_template: frequency_penalty pricing: diff --git a/api/core/model_runtime/model_providers/siliconflow/llm/qwen2-1.5b-instruct.yaml b/api/core/model_runtime/model_providers/siliconflow/llm/qwen2-1.5b-instruct.yaml index f6c976af8e..bec5d37c57 100644 --- a/api/core/model_runtime/model_providers/siliconflow/llm/qwen2-1.5b-instruct.yaml +++ b/api/core/model_runtime/model_providers/siliconflow/llm/qwen2-1.5b-instruct.yaml @@ -21,6 +21,15 @@ parameter_rules: en_US: Specifies the upper limit on the length of generated results. If the generated results are truncated, you can increase this parameter. - name: top_p use_template: top_p + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. + required: false - name: frequency_penalty use_template: frequency_penalty pricing: diff --git a/api/core/model_runtime/model_providers/siliconflow/llm/qwen2-57b-a14b-instruct.yaml b/api/core/model_runtime/model_providers/siliconflow/llm/qwen2-57b-a14b-instruct.yaml index a996e919ea..b2461335f8 100644 --- a/api/core/model_runtime/model_providers/siliconflow/llm/qwen2-57b-a14b-instruct.yaml +++ b/api/core/model_runtime/model_providers/siliconflow/llm/qwen2-57b-a14b-instruct.yaml @@ -21,6 +21,15 @@ parameter_rules: en_US: Specifies the upper limit on the length of generated results. If the generated results are truncated, you can increase this parameter. - name: top_p use_template: top_p + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. 
+ required: false - name: frequency_penalty use_template: frequency_penalty pricing: diff --git a/api/core/model_runtime/model_providers/siliconflow/llm/qwen2-72b-instruct.yaml b/api/core/model_runtime/model_providers/siliconflow/llm/qwen2-72b-instruct.yaml index a6e2c22dac..e0f23bd89e 100644 --- a/api/core/model_runtime/model_providers/siliconflow/llm/qwen2-72b-instruct.yaml +++ b/api/core/model_runtime/model_providers/siliconflow/llm/qwen2-72b-instruct.yaml @@ -21,6 +21,15 @@ parameter_rules: en_US: Specifies the upper limit on the length of generated results. If the generated results are truncated, you can increase this parameter. - name: top_p use_template: top_p + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. + required: false - name: frequency_penalty use_template: frequency_penalty pricing: diff --git a/api/core/model_runtime/model_providers/siliconflow/llm/qwen2-7b-instruct.yaml b/api/core/model_runtime/model_providers/siliconflow/llm/qwen2-7b-instruct.yaml index d8bea5e129..47a9da8119 100644 --- a/api/core/model_runtime/model_providers/siliconflow/llm/qwen2-7b-instruct.yaml +++ b/api/core/model_runtime/model_providers/siliconflow/llm/qwen2-7b-instruct.yaml @@ -21,6 +21,15 @@ parameter_rules: en_US: Specifies the upper limit on the length of generated results. If the generated results are truncated, you can increase this parameter. - name: top_p use_template: top_p + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. 
+ required: false - name: frequency_penalty use_template: frequency_penalty pricing: diff --git a/api/core/model_runtime/model_providers/siliconflow/llm/qwen2.5-14b-instruct.yaml b/api/core/model_runtime/model_providers/siliconflow/llm/qwen2.5-14b-instruct.yaml index 02a401464b..9cc5ac4c91 100644 --- a/api/core/model_runtime/model_providers/siliconflow/llm/qwen2.5-14b-instruct.yaml +++ b/api/core/model_runtime/model_providers/siliconflow/llm/qwen2.5-14b-instruct.yaml @@ -21,6 +21,15 @@ parameter_rules: en_US: Specifies the upper limit on the length of generated results. If the generated results are truncated, you can increase this parameter. - name: top_p use_template: top_p + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. + required: false - name: frequency_penalty use_template: frequency_penalty pricing: diff --git a/api/core/model_runtime/model_providers/siliconflow/llm/qwen2.5-32b-instruct.yaml b/api/core/model_runtime/model_providers/siliconflow/llm/qwen2.5-32b-instruct.yaml index d084617e7d..c7fb21e9e1 100644 --- a/api/core/model_runtime/model_providers/siliconflow/llm/qwen2.5-32b-instruct.yaml +++ b/api/core/model_runtime/model_providers/siliconflow/llm/qwen2.5-32b-instruct.yaml @@ -21,6 +21,15 @@ parameter_rules: en_US: Specifies the upper limit on the length of generated results. If the generated results are truncated, you can increase this parameter. - name: top_p use_template: top_p + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. 
+ required: false - name: frequency_penalty use_template: frequency_penalty pricing: diff --git a/api/core/model_runtime/model_providers/siliconflow/llm/qwen2.5-72b-instruct.yaml b/api/core/model_runtime/model_providers/siliconflow/llm/qwen2.5-72b-instruct.yaml index dfbad2494c..03136c88a1 100644 --- a/api/core/model_runtime/model_providers/siliconflow/llm/qwen2.5-72b-instruct.yaml +++ b/api/core/model_runtime/model_providers/siliconflow/llm/qwen2.5-72b-instruct.yaml @@ -21,6 +21,15 @@ parameter_rules: en_US: Specifies the upper limit on the length of generated results. If the generated results are truncated, you can increase this parameter. - name: top_p use_template: top_p + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. + required: false - name: frequency_penalty use_template: frequency_penalty pricing: diff --git a/api/core/model_runtime/model_providers/siliconflow/llm/qwen2.5-7b-instruct.yaml b/api/core/model_runtime/model_providers/siliconflow/llm/qwen2.5-7b-instruct.yaml index cdc8ffc4d2..99412adde7 100644 --- a/api/core/model_runtime/model_providers/siliconflow/llm/qwen2.5-7b-instruct.yaml +++ b/api/core/model_runtime/model_providers/siliconflow/llm/qwen2.5-7b-instruct.yaml @@ -21,6 +21,15 @@ parameter_rules: en_US: Specifies the upper limit on the length of generated results. If the generated results are truncated, you can increase this parameter. - name: top_p use_template: top_p + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. 
+ required: false - name: frequency_penalty use_template: frequency_penalty pricing: diff --git a/api/core/model_runtime/model_providers/siliconflow/llm/yi-1.5-34b-chat.yaml b/api/core/model_runtime/model_providers/siliconflow/llm/yi-1.5-34b-chat.yaml index 864ba46f1a..3e25f82369 100644 --- a/api/core/model_runtime/model_providers/siliconflow/llm/yi-1.5-34b-chat.yaml +++ b/api/core/model_runtime/model_providers/siliconflow/llm/yi-1.5-34b-chat.yaml @@ -21,6 +21,15 @@ parameter_rules: en_US: Specifies the upper limit on the length of generated results. If the generated results are truncated, you can increase this parameter. - name: top_p use_template: top_p + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. + required: false - name: frequency_penalty use_template: frequency_penalty pricing: diff --git a/api/core/model_runtime/model_providers/siliconflow/llm/yi-1.5-6b-chat.yaml b/api/core/model_runtime/model_providers/siliconflow/llm/yi-1.5-6b-chat.yaml index fe4c8b4b3e..827b2ce1e5 100644 --- a/api/core/model_runtime/model_providers/siliconflow/llm/yi-1.5-6b-chat.yaml +++ b/api/core/model_runtime/model_providers/siliconflow/llm/yi-1.5-6b-chat.yaml @@ -21,6 +21,15 @@ parameter_rules: en_US: Specifies the upper limit on the length of generated results. If the generated results are truncated, you can increase this parameter. - name: top_p use_template: top_p + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. 
+ required: false - name: frequency_penalty use_template: frequency_penalty pricing: diff --git a/api/core/model_runtime/model_providers/siliconflow/llm/yi-1.5-9b-chat.yaml b/api/core/model_runtime/model_providers/siliconflow/llm/yi-1.5-9b-chat.yaml index c61f0dc53f..112fcbfe97 100644 --- a/api/core/model_runtime/model_providers/siliconflow/llm/yi-1.5-9b-chat.yaml +++ b/api/core/model_runtime/model_providers/siliconflow/llm/yi-1.5-9b-chat.yaml @@ -21,6 +21,15 @@ parameter_rules: en_US: Specifies the upper limit on the length of generated results. If the generated results are truncated, you can increase this parameter. - name: top_p use_template: top_p + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. + required: false - name: frequency_penalty use_template: frequency_penalty pricing: diff --git a/api/core/model_runtime/model_providers/siliconflow/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/siliconflow/text_embedding/text_embedding.py index c5dcc12610..5e29a4827a 100644 --- a/api/core/model_runtime/model_providers/siliconflow/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/siliconflow/text_embedding/text_embedding.py @@ -1,6 +1,6 @@ from typing import Optional -from core.embedding.embedding_constant import EmbeddingInputType +from core.entities.embedding_type import EmbeddingInputType from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult from core.model_runtime.model_providers.openai_api_compatible.text_embedding.text_embedding import ( OAICompatEmbeddingModel, diff --git a/api/core/model_runtime/model_providers/tongyi/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/tongyi/text_embedding/text_embedding.py index 736cd44df8..2ef7f3f577 100644 --- a/api/core/model_runtime/model_providers/tongyi/text_embedding/text_embedding.py 
+++ b/api/core/model_runtime/model_providers/tongyi/text_embedding/text_embedding.py @@ -4,7 +4,7 @@ from typing import Optional import dashscope import numpy as np -from core.embedding.embedding_constant import EmbeddingInputType +from core.entities.embedding_type import EmbeddingInputType from core.model_runtime.entities.model_entities import PriceType from core.model_runtime.entities.text_embedding_entities import ( EmbeddingUsage, diff --git a/api/core/model_runtime/model_providers/upstage/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/upstage/text_embedding/text_embedding.py index b6509cd26c..7dd495b55e 100644 --- a/api/core/model_runtime/model_providers/upstage/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/upstage/text_embedding/text_embedding.py @@ -7,7 +7,7 @@ import numpy as np from openai import OpenAI from tokenizers import Tokenizer -from core.embedding.embedding_constant import EmbeddingInputType +from core.entities.embedding_type import EmbeddingInputType from core.model_runtime.entities.model_entities import PriceType from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult from core.model_runtime.errors.validate import CredentialsValidateFailedError diff --git a/api/core/model_runtime/model_providers/vertex_ai/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/vertex_ai/text_embedding/text_embedding.py index fce9544df0..43233e6126 100644 --- a/api/core/model_runtime/model_providers/vertex_ai/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/vertex_ai/text_embedding/text_embedding.py @@ -9,7 +9,7 @@ from google.cloud import aiplatform from google.oauth2 import service_account from vertexai.language_models import TextEmbeddingModel as VertexTextEmbeddingModel -from core.embedding.embedding_constant import EmbeddingInputType +from core.entities.embedding_type import EmbeddingInputType from 
core.model_runtime.entities.common_entities import I18nObject from core.model_runtime.entities.model_entities import ( AIModelEntity, diff --git a/api/core/model_runtime/model_providers/volcengine_maas/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/volcengine_maas/text_embedding/text_embedding.py index 0dd4037c95..4d13e4708b 100644 --- a/api/core/model_runtime/model_providers/volcengine_maas/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/volcengine_maas/text_embedding/text_embedding.py @@ -2,7 +2,7 @@ import time from decimal import Decimal from typing import Optional -from core.embedding.embedding_constant import EmbeddingInputType +from core.entities.embedding_type import EmbeddingInputType from core.model_runtime.entities.common_entities import I18nObject from core.model_runtime.entities.model_entities import ( AIModelEntity, diff --git a/api/core/model_runtime/model_providers/voyage/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/voyage/text_embedding/text_embedding.py index a8a4d3c15b..e69c9fccba 100644 --- a/api/core/model_runtime/model_providers/voyage/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/voyage/text_embedding/text_embedding.py @@ -4,7 +4,7 @@ from typing import Optional import requests -from core.embedding.embedding_constant import EmbeddingInputType +from core.entities.embedding_type import EmbeddingInputType from core.model_runtime.entities.common_entities import I18nObject from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType, PriceType from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult diff --git a/api/core/model_runtime/model_providers/wenxin/_common.py b/api/core/model_runtime/model_providers/wenxin/_common.py index d72d1bd83a..1a4cc15371 100644 --- a/api/core/model_runtime/model_providers/wenxin/_common.py +++ 
b/api/core/model_runtime/model_providers/wenxin/_common.py @@ -120,6 +120,7 @@ class _CommonWenxin: "bge-large-en": "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/bge_large_en", "bge-large-zh": "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/bge_large_zh", "tao-8k": "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/tao_8k", + "bce-reranker-base_v1": "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/reranker/bce_reranker_base", } function_calling_supports = [ diff --git a/api/core/model_runtime/model_providers/wenxin/rerank/__init__.py b/api/core/model_runtime/model_providers/wenxin/rerank/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/model_runtime/model_providers/wenxin/rerank/bce-reranker-base_v1.yaml b/api/core/model_runtime/model_providers/wenxin/rerank/bce-reranker-base_v1.yaml new file mode 100644 index 0000000000..ef4b07d767 --- /dev/null +++ b/api/core/model_runtime/model_providers/wenxin/rerank/bce-reranker-base_v1.yaml @@ -0,0 +1,8 @@ +model: bce-reranker-base_v1 +model_type: rerank +model_properties: + context_size: 4096 +pricing: + input: '0.0005' + unit: '0.001' + currency: RMB diff --git a/api/core/model_runtime/model_providers/wenxin/rerank/rerank.py b/api/core/model_runtime/model_providers/wenxin/rerank/rerank.py new file mode 100644 index 0000000000..b22aead22b --- /dev/null +++ b/api/core/model_runtime/model_providers/wenxin/rerank/rerank.py @@ -0,0 +1,147 @@ +from typing import Optional + +import httpx + +from core.model_runtime.entities.common_entities import I18nObject +from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType +from core.model_runtime.entities.rerank_entities import RerankDocument, RerankResult +from core.model_runtime.errors.invoke import ( + InvokeAuthorizationError, + InvokeBadRequestError, + InvokeConnectionError, + InvokeError, + InvokeRateLimitError, 
+ InvokeServerUnavailableError, +) +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.__base.rerank_model import RerankModel +from core.model_runtime.model_providers.wenxin._common import _CommonWenxin + + +class WenxinRerank(_CommonWenxin): + def rerank(self, model: str, query: str, docs: list[str], top_n: Optional[int] = None): + access_token = self._get_access_token() + url = f"{self.api_bases[model]}?access_token={access_token}" + + try: + response = httpx.post( + url, + json={"model": model, "query": query, "documents": docs, "top_n": top_n}, + headers={"Content-Type": "application/json"}, + ) + response.raise_for_status() + return response.json() + except httpx.HTTPStatusError as e: + raise InvokeServerUnavailableError(str(e)) + + +class WenxinRerankModel(RerankModel): + """ + Model class for wenxin rerank model. + """ + + def _invoke( + self, + model: str, + credentials: dict, + query: str, + docs: list[str], + score_threshold: Optional[float] = None, + top_n: Optional[int] = None, + user: Optional[str] = None, + ) -> RerankResult: + """ + Invoke rerank model + + :param model: model name + :param credentials: model credentials + :param query: search query + :param docs: docs for reranking + :param score_threshold: score threshold + :param top_n: top n documents to return + :param user: unique user id + :return: rerank result + """ + if len(docs) == 0: + return RerankResult(model=model, docs=[]) + + api_key = credentials["api_key"] + secret_key = credentials["secret_key"] + + wenxin_rerank: WenxinRerank = WenxinRerank(api_key, secret_key) + + try: + results = wenxin_rerank.rerank(model, query, docs, top_n) + + rerank_documents = [] + for result in results["results"]: + index = result["index"] + if "document" in result: + text = result["document"] + else: + # llama.cpp rerank maynot return original documents + text = docs[index] + + rerank_document = RerankDocument( + index=index, + 
text=text, + score=result["relevance_score"], + ) + + if score_threshold is None or result["relevance_score"] >= score_threshold: + rerank_documents.append(rerank_document) + + return RerankResult(model=model, docs=rerank_documents) + except httpx.HTTPStatusError as e: + raise InvokeServerUnavailableError(str(e)) + + def validate_credentials(self, model: str, credentials: dict) -> None: + """ + Validate model credentials + + :param model: model name + :param credentials: model credentials + :return: + """ + try: + self._invoke( + model=model, + credentials=credentials, + query="What is the capital of the United States?", + docs=[ + "Carson City is the capital city of the American state of Nevada. At the 2010 United States " + "Census, Carson City had a population of 55,274.", + "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that " + "are a political division controlled by the United States. Its capital is Saipan.", + ], + score_threshold=0.8, + ) + except Exception as ex: + raise CredentialsValidateFailedError(str(ex)) + + @property + def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: + """ + Map model invoke error to unified error + """ + return { + InvokeConnectionError: [httpx.ConnectError], + InvokeServerUnavailableError: [httpx.RemoteProtocolError], + InvokeRateLimitError: [], + InvokeAuthorizationError: [httpx.HTTPStatusError], + InvokeBadRequestError: [httpx.RequestError], + } + + def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity: + """ + generate custom model entities from credentials + """ + entity = AIModelEntity( + model=model, + label=I18nObject(en_US=model), + model_type=ModelType.RERANK, + fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, + model_properties={ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size"))}, + ) + + return entity diff --git a/api/core/model_runtime/model_providers/wenxin/text_embedding/text_embedding.py 
b/api/core/model_runtime/model_providers/wenxin/text_embedding/text_embedding.py index c21d0c0552..19135deb27 100644 --- a/api/core/model_runtime/model_providers/wenxin/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/wenxin/text_embedding/text_embedding.py @@ -7,7 +7,7 @@ from typing import Any, Optional import numpy as np from requests import Response, post -from core.embedding.embedding_constant import EmbeddingInputType +from core.entities.embedding_type import EmbeddingInputType from core.model_runtime.entities.model_entities import PriceType from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult from core.model_runtime.errors.invoke import InvokeError diff --git a/api/core/model_runtime/model_providers/wenxin/wenxin.yaml b/api/core/model_runtime/model_providers/wenxin/wenxin.yaml index 6a6b38e6a1..d8acfd8120 100644 --- a/api/core/model_runtime/model_providers/wenxin/wenxin.yaml +++ b/api/core/model_runtime/model_providers/wenxin/wenxin.yaml @@ -18,6 +18,7 @@ help: supported_model_types: - llm - text-embedding + - rerank configurate_methods: - predefined-model provider_credential_schema: diff --git a/api/core/model_runtime/model_providers/xinference/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/xinference/text_embedding/text_embedding.py index ddc21b365c..f64b9c50af 100644 --- a/api/core/model_runtime/model_providers/xinference/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/xinference/text_embedding/text_embedding.py @@ -3,7 +3,7 @@ from typing import Optional from xinference_client.client.restful.restful_client import Client, RESTfulEmbeddingModelHandle -from core.embedding.embedding_constant import EmbeddingInputType +from core.entities.embedding_type import EmbeddingInputType from core.model_runtime.entities.common_entities import I18nObject from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, 
ModelPropertyKey, ModelType, PriceType from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult diff --git a/api/core/model_runtime/model_providers/yi/llm/_position.yaml b/api/core/model_runtime/model_providers/yi/llm/_position.yaml index e876893b41..5fa098beda 100644 --- a/api/core/model_runtime/model_providers/yi/llm/_position.yaml +++ b/api/core/model_runtime/model_providers/yi/llm/_position.yaml @@ -7,3 +7,4 @@ - yi-medium-200k - yi-spark - yi-large-turbo +- yi-lightning diff --git a/api/core/model_runtime/model_providers/yi/llm/llm.py b/api/core/model_runtime/model_providers/yi/llm/llm.py index 5ab7fd126e..0642e72ed5 100644 --- a/api/core/model_runtime/model_providers/yi/llm/llm.py +++ b/api/core/model_runtime/model_providers/yi/llm/llm.py @@ -4,12 +4,22 @@ from urllib.parse import urlparse import tiktoken -from core.model_runtime.entities.llm_entities import LLMResult +from core.model_runtime.entities.common_entities import I18nObject +from core.model_runtime.entities.llm_entities import LLMMode, LLMResult from core.model_runtime.entities.message_entities import ( PromptMessage, PromptMessageTool, SystemPromptMessage, ) +from core.model_runtime.entities.model_entities import ( + AIModelEntity, + FetchFrom, + ModelFeature, + ModelPropertyKey, + ModelType, + ParameterRule, + ParameterType, +) from core.model_runtime.model_providers.openai.llm.llm import OpenAILargeLanguageModel @@ -125,3 +135,58 @@ class YiLargeLanguageModel(OpenAILargeLanguageModel): else: parsed_url = urlparse(credentials["endpoint_url"]) credentials["openai_api_base"] = f"{parsed_url.scheme}://{parsed_url.netloc}" + + def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None: + return AIModelEntity( + model=model, + label=I18nObject(en_US=model, zh_Hans=model), + model_type=ModelType.LLM, + features=[ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL, ModelFeature.STREAM_TOOL_CALL] + if 
credentials.get("function_calling_type") == "tool_call" + else [], + fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, + model_properties={ + ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size", 8000)), + ModelPropertyKey.MODE: LLMMode.CHAT.value, + }, + parameter_rules=[ + ParameterRule( + name="temperature", + use_template="temperature", + label=I18nObject(en_US="Temperature", zh_Hans="温度"), + type=ParameterType.FLOAT, + ), + ParameterRule( + name="max_tokens", + use_template="max_tokens", + default=512, + min=1, + max=int(credentials.get("max_tokens", 8192)), + label=I18nObject( + en_US="Max Tokens", zh_Hans="指定生成结果长度的上限。如果生成结果截断,可以调大该参数" + ), + type=ParameterType.INT, + ), + ParameterRule( + name="top_p", + use_template="top_p", + label=I18nObject( + en_US="Top P", + zh_Hans="控制生成结果的随机性。数值越小,随机性越弱;数值越大,随机性越强。", + ), + type=ParameterType.FLOAT, + ), + ParameterRule( + name="top_k", + use_template="top_k", + label=I18nObject(en_US="Top K", zh_Hans="取样数量"), + type=ParameterType.FLOAT, + ), + ParameterRule( + name="frequency_penalty", + use_template="frequency_penalty", + label=I18nObject(en_US="Frequency Penalty", zh_Hans="重复惩罚"), + type=ParameterType.FLOAT, + ), + ], + ) diff --git a/api/core/model_runtime/model_providers/yi/llm/yi-lightning.yaml b/api/core/model_runtime/model_providers/yi/llm/yi-lightning.yaml new file mode 100644 index 0000000000..fccf1b3a26 --- /dev/null +++ b/api/core/model_runtime/model_providers/yi/llm/yi-lightning.yaml @@ -0,0 +1,43 @@ +model: yi-lightning +label: + zh_Hans: yi-lightning + en_US: yi-lightning +model_type: llm +features: + - agent-thought +model_properties: + mode: chat + context_size: 16384 +parameter_rules: + - name: temperature + use_template: temperature + type: float + default: 0.3 + min: 0.0 + max: 2.0 + help: + zh_Hans: 控制生成结果的多样性和随机性。数值越小,越严谨;数值越大,越发散。 + en_US: Control the diversity and randomness of generated results. 
The smaller the value, the more rigorous it is; the larger the value, the more divergent it is. + - name: max_tokens + use_template: max_tokens + type: int + default: 1024 + min: 1 + max: 4000 + help: + zh_Hans: 指定生成结果长度的上限。如果生成结果截断,可以调大该参数。 + en_US: Specifies the upper limit on the length of generated results. If the generated results are truncated, you can increase this parameter. + - name: top_p + use_template: top_p + type: float + default: 0.9 + min: 0.01 + max: 1.00 + help: + zh_Hans: 控制生成结果的随机性。数值越小,随机性越弱;数值越大,随机性越强。一般而言,top_p 和 temperature 两个参数选择一个进行调整即可。 + en_US: Control the randomness of generated results. The smaller the value, the weaker the randomness; the larger the value, the stronger the randomness. Generally speaking, you can adjust one of the two parameters top_p and temperature. +pricing: + input: '0.99' + output: '0.99' + unit: '0.000001' + currency: RMB diff --git a/api/core/model_runtime/model_providers/yi/yi.yaml b/api/core/model_runtime/model_providers/yi/yi.yaml index de741afb10..393526c31e 100644 --- a/api/core/model_runtime/model_providers/yi/yi.yaml +++ b/api/core/model_runtime/model_providers/yi/yi.yaml @@ -20,6 +20,7 @@ supported_model_types: - llm configurate_methods: - predefined-model + - customizable-model provider_credential_schema: credential_form_schemas: - variable: api_key @@ -39,3 +40,57 @@ provider_credential_schema: placeholder: zh_Hans: Base URL, e.g. https://api.lingyiwanwu.com/v1 en_US: Base URL, e.g. 
https://api.lingyiwanwu.com/v1 +model_credential_schema: + model: + label: + en_US: Model Name + zh_Hans: 模型名称 + placeholder: + en_US: Enter your model name + zh_Hans: 输入模型名称 + credential_form_schemas: + - variable: api_key + label: + en_US: API Key + type: secret-input + required: true + placeholder: + zh_Hans: 在此输入您的 API Key + en_US: Enter your API Key + - variable: context_size + label: + zh_Hans: 模型上下文长度 + en_US: Model context size + required: true + type: text-input + default: '4096' + placeholder: + zh_Hans: 在此输入您的模型上下文长度 + en_US: Enter your Model context size + - variable: max_tokens + label: + zh_Hans: 最大 token 上限 + en_US: Upper bound for max tokens + default: '4096' + type: text-input + show_on: + - variable: __model_type + value: llm + - variable: function_calling_type + label: + en_US: Function calling + type: select + required: false + default: no_call + options: + - value: no_call + label: + en_US: Not Support + zh_Hans: 不支持 + - value: function_call + label: + en_US: Support + zh_Hans: 支持 + show_on: + - variable: __model_type + value: llm diff --git a/api/core/model_runtime/model_providers/zhipuai/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/zhipuai/text_embedding/text_embedding.py index 5a34a3d593..f629b62fd5 100644 --- a/api/core/model_runtime/model_providers/zhipuai/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/zhipuai/text_embedding/text_embedding.py @@ -3,7 +3,7 @@ from typing import Optional from zhipuai import ZhipuAI -from core.embedding.embedding_constant import EmbeddingInputType +from core.entities.embedding_type import EmbeddingInputType from core.model_runtime.entities.model_entities import PriceType from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult from core.model_runtime.errors.validate import CredentialsValidateFailedError diff --git a/api/core/rag/data_post_processor/data_post_processor.py 
b/api/core/rag/data_post_processor/data_post_processor.py index b1d6f93cff..992415657e 100644 --- a/api/core/rag/data_post_processor/data_post_processor.py +++ b/api/core/rag/data_post_processor/data_post_processor.py @@ -1,14 +1,14 @@ from typing import Optional -from core.model_manager import ModelManager +from core.model_manager import ModelInstance, ModelManager from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.errors.invoke import InvokeAuthorizationError from core.rag.data_post_processor.reorder import ReorderRunner from core.rag.models.document import Document -from core.rag.rerank.constants.rerank_mode import RerankMode from core.rag.rerank.entity.weight import KeywordSetting, VectorSetting, Weights -from core.rag.rerank.rerank_model import RerankModelRunner -from core.rag.rerank.weight_rerank import WeightRerankRunner +from core.rag.rerank.rerank_base import BaseRerankRunner +from core.rag.rerank.rerank_factory import RerankRunnerFactory +from core.rag.rerank.rerank_type import RerankMode class DataPostProcessor: @@ -47,11 +47,12 @@ class DataPostProcessor: tenant_id: str, reranking_model: Optional[dict] = None, weights: Optional[dict] = None, - ) -> Optional[RerankModelRunner | WeightRerankRunner]: + ) -> Optional[BaseRerankRunner]: if reranking_mode == RerankMode.WEIGHTED_SCORE.value and weights: - return WeightRerankRunner( - tenant_id, - Weights( + runner = RerankRunnerFactory.create_rerank_runner( + runner_type=reranking_mode, + tenant_id=tenant_id, + weights=Weights( vector_setting=VectorSetting( vector_weight=weights["vector_setting"]["vector_weight"], embedding_provider_name=weights["vector_setting"]["embedding_provider_name"], @@ -62,23 +63,33 @@ class DataPostProcessor: ), ), ) + return runner elif reranking_mode == RerankMode.RERANKING_MODEL.value: - if reranking_model: - try: - model_manager = ModelManager() - rerank_model_instance = model_manager.get_model_instance( - tenant_id=tenant_id, - 
provider=reranking_model["reranking_provider_name"], - model_type=ModelType.RERANK, - model=reranking_model["reranking_model_name"], - ) - except InvokeAuthorizationError: - return None - return RerankModelRunner(rerank_model_instance) - return None + rerank_model_instance = self._get_rerank_model_instance(tenant_id, reranking_model) + if rerank_model_instance is None: + return None + runner = RerankRunnerFactory.create_rerank_runner( + runner_type=reranking_mode, rerank_model_instance=rerank_model_instance + ) + return runner return None def _get_reorder_runner(self, reorder_enabled) -> Optional[ReorderRunner]: if reorder_enabled: return ReorderRunner() return None + + def _get_rerank_model_instance(self, tenant_id: str, reranking_model: Optional[dict]) -> ModelInstance | None: + if reranking_model: + try: + model_manager = ModelManager() + rerank_model_instance = model_manager.get_model_instance( + tenant_id=tenant_id, + provider=reranking_model["reranking_provider_name"], + model_type=ModelType.RERANK, + model=reranking_model["reranking_model_name"], + ) + return rerank_model_instance + except InvokeAuthorizationError: + return None + return None diff --git a/api/core/rag/datasource/retrieval_service.py b/api/core/rag/datasource/retrieval_service.py index d3fd0c672a..3affbd2d0a 100644 --- a/api/core/rag/datasource/retrieval_service.py +++ b/api/core/rag/datasource/retrieval_service.py @@ -6,7 +6,7 @@ from flask import Flask, current_app from core.rag.data_post_processor.data_post_processor import DataPostProcessor from core.rag.datasource.keyword.keyword_factory import Keyword from core.rag.datasource.vdb.vector_factory import Vector -from core.rag.rerank.constants.rerank_mode import RerankMode +from core.rag.rerank.rerank_type import RerankMode from core.rag.retrieval.retrieval_methods import RetrievalMethod from extensions.ext_database import db from models.dataset import Dataset diff --git a/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector.py 
b/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector.py index 6dcd98dcfd..c77cb87376 100644 --- a/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector.py +++ b/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector.py @@ -9,10 +9,10 @@ _import_err_msg = ( ) from configs import dify_config -from core.rag.datasource.entity.embedding import Embeddings from core.rag.datasource.vdb.vector_base import BaseVector from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory from core.rag.datasource.vdb.vector_type import VectorType +from core.rag.embedding.embedding_base import Embeddings from core.rag.models.document import Document from extensions.ext_redis import redis_client from models.dataset import Dataset diff --git a/api/core/rag/datasource/vdb/baidu/baidu_vector.py b/api/core/rag/datasource/vdb/baidu/baidu_vector.py index 543cfa67b3..1d4bfef76d 100644 --- a/api/core/rag/datasource/vdb/baidu/baidu_vector.py +++ b/api/core/rag/datasource/vdb/baidu/baidu_vector.py @@ -12,10 +12,10 @@ from pymochow.model.schema import Field, HNSWParams, Schema, VectorIndex from pymochow.model.table import AnnSearch, HNSWSearchParams, Partition, Row from configs import dify_config -from core.rag.datasource.entity.embedding import Embeddings from core.rag.datasource.vdb.vector_base import BaseVector from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory from core.rag.datasource.vdb.vector_type import VectorType +from core.rag.embedding.embedding_base import Embeddings from core.rag.models.document import Document from extensions.ext_redis import redis_client from models.dataset import Dataset diff --git a/api/core/rag/datasource/vdb/chroma/chroma_vector.py b/api/core/rag/datasource/vdb/chroma/chroma_vector.py index 610aa498ab..a9e1486edd 100644 --- a/api/core/rag/datasource/vdb/chroma/chroma_vector.py +++ b/api/core/rag/datasource/vdb/chroma/chroma_vector.py @@ -6,10 +6,10 @@ from chromadb import QueryResult, Settings from pydantic import 
BaseModel from configs import dify_config -from core.rag.datasource.entity.embedding import Embeddings from core.rag.datasource.vdb.vector_base import BaseVector from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory from core.rag.datasource.vdb.vector_type import VectorType +from core.rag.embedding.embedding_base import Embeddings from core.rag.models.document import Document from extensions.ext_redis import redis_client from models.dataset import Dataset diff --git a/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py b/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py index f420373d5b..052a187225 100644 --- a/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py +++ b/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py @@ -9,11 +9,11 @@ from elasticsearch import Elasticsearch from flask import current_app from pydantic import BaseModel, model_validator -from core.rag.datasource.entity.embedding import Embeddings from core.rag.datasource.vdb.field import Field from core.rag.datasource.vdb.vector_base import BaseVector from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory from core.rag.datasource.vdb.vector_type import VectorType +from core.rag.embedding.embedding_base import Embeddings from core.rag.models.document import Document from extensions.ext_redis import redis_client from models.dataset import Dataset diff --git a/api/core/rag/datasource/vdb/milvus/milvus_vector.py b/api/core/rag/datasource/vdb/milvus/milvus_vector.py index bdca59f869..080a1ef567 100644 --- a/api/core/rag/datasource/vdb/milvus/milvus_vector.py +++ b/api/core/rag/datasource/vdb/milvus/milvus_vector.py @@ -7,11 +7,11 @@ from pymilvus import MilvusClient, MilvusException from pymilvus.milvus_client import IndexParams from configs import dify_config -from core.rag.datasource.entity.embedding import Embeddings from core.rag.datasource.vdb.field import Field from core.rag.datasource.vdb.vector_base 
import BaseVector from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory from core.rag.datasource.vdb.vector_type import VectorType +from core.rag.embedding.embedding_base import Embeddings from core.rag.models.document import Document from extensions.ext_redis import redis_client from models.dataset import Dataset diff --git a/api/core/rag/datasource/vdb/myscale/myscale_vector.py b/api/core/rag/datasource/vdb/myscale/myscale_vector.py index b30aa7ca22..1fca926a2d 100644 --- a/api/core/rag/datasource/vdb/myscale/myscale_vector.py +++ b/api/core/rag/datasource/vdb/myscale/myscale_vector.py @@ -8,10 +8,10 @@ from clickhouse_connect import get_client from pydantic import BaseModel from configs import dify_config -from core.rag.datasource.entity.embedding import Embeddings from core.rag.datasource.vdb.vector_base import BaseVector from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory from core.rag.datasource.vdb.vector_type import VectorType +from core.rag.embedding.embedding_base import Embeddings from core.rag.models.document import Document from models.dataset import Dataset diff --git a/api/core/rag/datasource/vdb/opensearch/opensearch_vector.py b/api/core/rag/datasource/vdb/opensearch/opensearch_vector.py index 8d2e0a86ab..0e0f107268 100644 --- a/api/core/rag/datasource/vdb/opensearch/opensearch_vector.py +++ b/api/core/rag/datasource/vdb/opensearch/opensearch_vector.py @@ -9,11 +9,11 @@ from opensearchpy.helpers import BulkIndexError from pydantic import BaseModel, model_validator from configs import dify_config -from core.rag.datasource.entity.embedding import Embeddings from core.rag.datasource.vdb.field import Field from core.rag.datasource.vdb.vector_base import BaseVector from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory from core.rag.datasource.vdb.vector_type import VectorType +from core.rag.embedding.embedding_base import Embeddings from core.rag.models.document import Document from 
extensions.ext_redis import redis_client from models.dataset import Dataset diff --git a/api/core/rag/datasource/vdb/oracle/oraclevector.py b/api/core/rag/datasource/vdb/oracle/oraclevector.py index 84a4381cd1..4ced5d61e5 100644 --- a/api/core/rag/datasource/vdb/oracle/oraclevector.py +++ b/api/core/rag/datasource/vdb/oracle/oraclevector.py @@ -13,10 +13,10 @@ from nltk.corpus import stopwords from pydantic import BaseModel, model_validator from configs import dify_config -from core.rag.datasource.entity.embedding import Embeddings from core.rag.datasource.vdb.vector_base import BaseVector from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory from core.rag.datasource.vdb.vector_type import VectorType +from core.rag.embedding.embedding_base import Embeddings from core.rag.models.document import Document from extensions.ext_redis import redis_client from models.dataset import Dataset diff --git a/api/core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py b/api/core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py index a82a9b96dd..7cbbdcc81f 100644 --- a/api/core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py +++ b/api/core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py @@ -12,11 +12,11 @@ from sqlalchemy.dialects import postgresql from sqlalchemy.orm import Mapped, Session, mapped_column from configs import dify_config -from core.rag.datasource.entity.embedding import Embeddings from core.rag.datasource.vdb.pgvecto_rs.collection import CollectionORM from core.rag.datasource.vdb.vector_base import BaseVector from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory from core.rag.datasource.vdb.vector_type import VectorType +from core.rag.embedding.embedding_base import Embeddings from core.rag.models.document import Document from extensions.ext_redis import redis_client from models.dataset import Dataset @@ -216,7 +216,7 @@ class PGVectoRSFactory(AbstractVectorFactory): else: dataset_id = dataset.id collection_name = 
Dataset.gen_collection_name_by_id(dataset_id).lower() - dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.WEAVIATE, collection_name)) + dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.PGVECTO_RS, collection_name)) dim = len(embeddings.embed_query("pgvecto_rs")) return PGVectoRS( diff --git a/api/core/rag/datasource/vdb/pgvector/pgvector.py b/api/core/rag/datasource/vdb/pgvector/pgvector.py index 6f336d27e7..40a9cdd136 100644 --- a/api/core/rag/datasource/vdb/pgvector/pgvector.py +++ b/api/core/rag/datasource/vdb/pgvector/pgvector.py @@ -8,10 +8,10 @@ import psycopg2.pool from pydantic import BaseModel, model_validator from configs import dify_config -from core.rag.datasource.entity.embedding import Embeddings from core.rag.datasource.vdb.vector_base import BaseVector from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory from core.rag.datasource.vdb.vector_type import VectorType +from core.rag.embedding.embedding_base import Embeddings from core.rag.models.document import Document from extensions.ext_redis import redis_client from models.dataset import Dataset diff --git a/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py b/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py index f418e3ca05..69d2aa4f76 100644 --- a/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py +++ b/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py @@ -20,11 +20,11 @@ from qdrant_client.http.models import ( from qdrant_client.local.qdrant_local import QdrantLocal from configs import dify_config -from core.rag.datasource.entity.embedding import Embeddings from core.rag.datasource.vdb.field import Field from core.rag.datasource.vdb.vector_base import BaseVector from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory from core.rag.datasource.vdb.vector_type import VectorType +from core.rag.embedding.embedding_base import Embeddings from core.rag.models.document import Document from extensions.ext_database 
import db from extensions.ext_redis import redis_client diff --git a/api/core/rag/datasource/vdb/relyt/relyt_vector.py b/api/core/rag/datasource/vdb/relyt/relyt_vector.py index 13a63784be..f373dcfeab 100644 --- a/api/core/rag/datasource/vdb/relyt/relyt_vector.py +++ b/api/core/rag/datasource/vdb/relyt/relyt_vector.py @@ -8,9 +8,9 @@ from sqlalchemy import text as sql_text from sqlalchemy.dialects.postgresql import JSON, TEXT from sqlalchemy.orm import Session -from core.rag.datasource.entity.embedding import Embeddings from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory from core.rag.datasource.vdb.vector_type import VectorType +from core.rag.embedding.embedding_base import Embeddings from models.dataset import Dataset try: diff --git a/api/core/rag/datasource/vdb/tencent/tencent_vector.py b/api/core/rag/datasource/vdb/tencent/tencent_vector.py index 39e3a7f6cf..f971a9c5eb 100644 --- a/api/core/rag/datasource/vdb/tencent/tencent_vector.py +++ b/api/core/rag/datasource/vdb/tencent/tencent_vector.py @@ -8,10 +8,10 @@ from tcvectordb.model import index as vdb_index from tcvectordb.model.document import Filter from configs import dify_config -from core.rag.datasource.entity.embedding import Embeddings from core.rag.datasource.vdb.vector_base import BaseVector from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory from core.rag.datasource.vdb.vector_type import VectorType +from core.rag.embedding.embedding_base import Embeddings from core.rag.models.document import Document from extensions.ext_redis import redis_client from models.dataset import Dataset diff --git a/api/core/rag/datasource/vdb/tidb_vector/tidb_vector.py b/api/core/rag/datasource/vdb/tidb_vector/tidb_vector.py index 7837c5a4aa..1147e35ce8 100644 --- a/api/core/rag/datasource/vdb/tidb_vector/tidb_vector.py +++ b/api/core/rag/datasource/vdb/tidb_vector/tidb_vector.py @@ -9,10 +9,10 @@ from sqlalchemy import text as sql_text from sqlalchemy.orm import Session, 
declarative_base from configs import dify_config -from core.rag.datasource.entity.embedding import Embeddings from core.rag.datasource.vdb.vector_base import BaseVector from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory from core.rag.datasource.vdb.vector_type import VectorType +from core.rag.embedding.embedding_base import Embeddings from core.rag.models.document import Document from extensions.ext_redis import redis_client from models.dataset import Dataset diff --git a/api/core/rag/datasource/vdb/vector_factory.py b/api/core/rag/datasource/vdb/vector_factory.py index d4e131f257..20cbc69942 100644 --- a/api/core/rag/datasource/vdb/vector_factory.py +++ b/api/core/rag/datasource/vdb/vector_factory.py @@ -2,12 +2,12 @@ from abc import ABC, abstractmethod from typing import Any, Optional from configs import dify_config -from core.embedding.cached_embedding import CacheEmbedding from core.model_manager import ModelManager from core.model_runtime.entities.model_entities import ModelType -from core.rag.datasource.entity.embedding import Embeddings from core.rag.datasource.vdb.vector_base import BaseVector from core.rag.datasource.vdb.vector_type import VectorType +from core.rag.embedding.cached_embedding import CacheEmbedding +from core.rag.embedding.embedding_base import Embeddings from core.rag.models.document import Document from extensions.ext_database import db from extensions.ext_redis import redis_client diff --git a/api/core/rag/datasource/vdb/vikingdb/vikingdb_vector.py b/api/core/rag/datasource/vdb/vikingdb/vikingdb_vector.py index 5f60f10acb..4f927f2899 100644 --- a/api/core/rag/datasource/vdb/vikingdb/vikingdb_vector.py +++ b/api/core/rag/datasource/vdb/vikingdb/vikingdb_vector.py @@ -14,11 +14,11 @@ from volcengine.viking_db import ( ) from configs import dify_config -from core.rag.datasource.entity.embedding import Embeddings from core.rag.datasource.vdb.field import Field as vdb_Field from core.rag.datasource.vdb.vector_base import 
BaseVector from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory from core.rag.datasource.vdb.vector_type import VectorType +from core.rag.embedding.embedding_base import Embeddings from core.rag.models.document import Document from extensions.ext_redis import redis_client from models.dataset import Dataset diff --git a/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py b/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py index 4009efe7a7..649cfbfea8 100644 --- a/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py +++ b/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py @@ -7,11 +7,11 @@ import weaviate from pydantic import BaseModel, model_validator from configs import dify_config -from core.rag.datasource.entity.embedding import Embeddings from core.rag.datasource.vdb.field import Field from core.rag.datasource.vdb.vector_base import BaseVector from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory from core.rag.datasource.vdb.vector_type import VectorType +from core.rag.embedding.embedding_base import Embeddings from core.rag.models.document import Document from extensions.ext_redis import redis_client from models.dataset import Dataset diff --git a/api/core/rag/embedding/__init__.py b/api/core/rag/embedding/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/embedding/cached_embedding.py b/api/core/rag/embedding/cached_embedding.py similarity index 97% rename from api/core/embedding/cached_embedding.py rename to api/core/rag/embedding/cached_embedding.py index 31d2171e72..b3e93ce760 100644 --- a/api/core/embedding/cached_embedding.py +++ b/api/core/rag/embedding/cached_embedding.py @@ -6,11 +6,11 @@ import numpy as np from sqlalchemy.exc import IntegrityError from configs import dify_config -from core.embedding.embedding_constant import EmbeddingInputType +from core.entities.embedding_type import EmbeddingInputType from core.model_manager import ModelInstance from 
core.model_runtime.entities.model_entities import ModelPropertyKey from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel -from core.rag.datasource.entity.embedding import Embeddings +from core.rag.embedding.embedding_base import Embeddings from extensions.ext_database import db from extensions.ext_redis import redis_client from libs import helper diff --git a/api/core/rag/datasource/entity/embedding.py b/api/core/rag/embedding/embedding_base.py similarity index 90% rename from api/core/rag/datasource/entity/embedding.py rename to api/core/rag/embedding/embedding_base.py index 126c1a3723..9f232ab910 100644 --- a/api/core/rag/datasource/entity/embedding.py +++ b/api/core/rag/embedding/embedding_base.py @@ -7,10 +7,12 @@ class Embeddings(ABC): @abstractmethod def embed_documents(self, texts: list[str]) -> list[list[float]]: """Embed search docs.""" + raise NotImplementedError @abstractmethod def embed_query(self, text: str) -> list[float]: """Embed query text.""" + raise NotImplementedError async def aembed_documents(self, texts: list[str]) -> list[list[float]]: """Asynchronous Embed search docs.""" diff --git a/api/core/rag/rerank/rerank_base.py b/api/core/rag/rerank/rerank_base.py new file mode 100644 index 0000000000..818b04b2ff --- /dev/null +++ b/api/core/rag/rerank/rerank_base.py @@ -0,0 +1,26 @@ +from abc import ABC, abstractmethod +from typing import Optional + +from core.rag.models.document import Document + + +class BaseRerankRunner(ABC): + @abstractmethod + def run( + self, + query: str, + documents: list[Document], + score_threshold: Optional[float] = None, + top_n: Optional[int] = None, + user: Optional[str] = None, + ) -> list[Document]: + """ + Run rerank model + :param query: search query + :param documents: documents for reranking + :param score_threshold: score threshold + :param top_n: top n + :param user: unique user id if needed + :return: + """ + raise NotImplementedError diff --git 
a/api/core/rag/rerank/rerank_factory.py b/api/core/rag/rerank/rerank_factory.py new file mode 100644 index 0000000000..1a3cf85736 --- /dev/null +++ b/api/core/rag/rerank/rerank_factory.py @@ -0,0 +1,16 @@ +from core.rag.rerank.rerank_base import BaseRerankRunner +from core.rag.rerank.rerank_model import RerankModelRunner +from core.rag.rerank.rerank_type import RerankMode +from core.rag.rerank.weight_rerank import WeightRerankRunner + + +class RerankRunnerFactory: + @staticmethod + def create_rerank_runner(runner_type: str, *args, **kwargs) -> BaseRerankRunner: + match runner_type: + case RerankMode.RERANKING_MODEL.value: + return RerankModelRunner(*args, **kwargs) + case RerankMode.WEIGHTED_SCORE.value: + return WeightRerankRunner(*args, **kwargs) + case _: + raise ValueError(f"Unknown runner type: {runner_type}") diff --git a/api/core/rag/rerank/rerank_model.py b/api/core/rag/rerank/rerank_model.py index 27f86aed34..40ebf0befd 100644 --- a/api/core/rag/rerank/rerank_model.py +++ b/api/core/rag/rerank/rerank_model.py @@ -2,9 +2,10 @@ from typing import Optional from core.model_manager import ModelInstance from core.rag.models.document import Document +from core.rag.rerank.rerank_base import BaseRerankRunner -class RerankModelRunner: +class RerankModelRunner(BaseRerankRunner): def __init__(self, rerank_model_instance: ModelInstance) -> None: self.rerank_model_instance = rerank_model_instance diff --git a/api/core/rag/rerank/constants/rerank_mode.py b/api/core/rag/rerank/rerank_type.py similarity index 100% rename from api/core/rag/rerank/constants/rerank_mode.py rename to api/core/rag/rerank/rerank_type.py diff --git a/api/core/rag/rerank/weight_rerank.py b/api/core/rag/rerank/weight_rerank.py index 16d6b879a4..2e3fbe04e2 100644 --- a/api/core/rag/rerank/weight_rerank.py +++ b/api/core/rag/rerank/weight_rerank.py @@ -4,15 +4,16 @@ from typing import Optional import numpy as np -from core.embedding.cached_embedding import CacheEmbedding from core.model_manager 
import ModelManager from core.model_runtime.entities.model_entities import ModelType from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaKeywordTableHandler +from core.rag.embedding.cached_embedding import CacheEmbedding from core.rag.models.document import Document from core.rag.rerank.entity.weight import VectorSetting, Weights +from core.rag.rerank.rerank_base import BaseRerankRunner -class WeightRerankRunner: +class WeightRerankRunner(BaseRerankRunner): def __init__(self, tenant_id: str, weights: Weights) -> None: self.tenant_id = tenant_id self.weights = weights diff --git a/api/core/tools/provider/_position.yaml b/api/core/tools/provider/_position.yaml index 6bab9a09d8..336588e3e2 100644 --- a/api/core/tools/provider/_position.yaml +++ b/api/core/tools/provider/_position.yaml @@ -61,6 +61,7 @@ - vectorizer - qrcode - tianditu +- aliyuque - google_translate - hap - json_process diff --git a/api/core/tools/provider/builtin/aliyuque/_assets/icon.svg b/api/core/tools/provider/builtin/aliyuque/_assets/icon.svg new file mode 100644 index 0000000000..82b23ebbc6 --- /dev/null +++ b/api/core/tools/provider/builtin/aliyuque/_assets/icon.svg @@ -0,0 +1,32 @@ + + 绿 lgo + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/api/core/tools/provider/builtin/aliyuque/aliyuque.py b/api/core/tools/provider/builtin/aliyuque/aliyuque.py new file mode 100644 index 0000000000..56eac1a4b5 --- /dev/null +++ b/api/core/tools/provider/builtin/aliyuque/aliyuque.py @@ -0,0 +1,19 @@ +from core.tools.errors import ToolProviderCredentialValidationError +from core.tools.provider.builtin.aliyuque.tools.base import AliYuqueTool +from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController + + +class AliYuqueProvider(BuiltinToolProviderController): + def _validate_credentials(self, credentials: dict) -> None: + token = credentials.get("token") + if not token: + raise 
ToolProviderCredentialValidationError("token is required") + + try: + resp = AliYuqueTool.auth(token) + if resp and resp.get("data", {}).get("id"): + return + + raise ToolProviderCredentialValidationError(resp) + except Exception as e: + raise ToolProviderCredentialValidationError(str(e)) diff --git a/api/core/tools/provider/builtin/aliyuque/aliyuque.yaml b/api/core/tools/provider/builtin/aliyuque/aliyuque.yaml new file mode 100644 index 0000000000..73d39aa96c --- /dev/null +++ b/api/core/tools/provider/builtin/aliyuque/aliyuque.yaml @@ -0,0 +1,29 @@ +identity: + author: 佐井 + name: aliyuque + label: + en_US: yuque + zh_Hans: 语雀 + pt_BR: yuque + description: + en_US: Yuque, https://www.yuque.com. + zh_Hans: 语雀,https://www.yuque.com。 + pt_BR: Yuque, https://www.yuque.com. + icon: icon.svg + tags: + - productivity + - search +credentials_for_provider: + token: + type: secret-input + required: true + label: + en_US: Yuque Team Token + zh_Hans: 语雀团队Token + placeholder: + en_US: Please input your Yuque team token + zh_Hans: 请输入你的语雀团队Token + help: + en_US: Get Alibaba Yuque team token + zh_Hans: 先获取语雀团队Token + url: https://www.yuque.com/settings/tokens diff --git a/api/core/tools/provider/builtin/aliyuque/tools/base.py b/api/core/tools/provider/builtin/aliyuque/tools/base.py new file mode 100644 index 0000000000..fb7e219bff --- /dev/null +++ b/api/core/tools/provider/builtin/aliyuque/tools/base.py @@ -0,0 +1,50 @@ +""" +语雀客户端 +""" + +__author__ = "佐井" +__created__ = "2024-06-01 09:45:20" + +from typing import Any + +import requests + + +class AliYuqueTool: + # yuque service url + server_url = "https://www.yuque.com" + + @staticmethod + def auth(token): + session = requests.Session() + session.headers.update({"Accept": "application/json", "X-Auth-Token": token}) + login = session.request("GET", AliYuqueTool.server_url + "/api/v2/user") + login.raise_for_status() + resp = login.json() + return resp + + def request(self, method: str, token, tool_parameters: dict[str, Any], 
path: str) -> str: + if not token: + raise Exception("token is required") + session = requests.Session() + session.headers.update({"accept": "application/json", "X-Auth-Token": token}) + new_params = {**tool_parameters} + # 找出需要替换的变量 + replacements = {k: v for k, v in new_params.items() if f"{{{k}}}" in path} + + # 替换 path 中的变量 + for key, value in replacements.items(): + path = path.replace(f"{{{key}}}", str(value)) + del new_params[key] # 从 kwargs 中删除已经替换的变量 + # 请求接口 + if method.upper() in {"POST", "PUT"}: + session.headers.update( + { + "Content-Type": "application/json", + } + ) + response = session.request(method.upper(), self.server_url + path, json=new_params) + else: + response = session.request(method, self.server_url + path, params=new_params) + response.raise_for_status() + return response.text diff --git a/api/core/tools/provider/builtin/aliyuque/tools/create_document.py b/api/core/tools/provider/builtin/aliyuque/tools/create_document.py new file mode 100644 index 0000000000..feadc29258 --- /dev/null +++ b/api/core/tools/provider/builtin/aliyuque/tools/create_document.py @@ -0,0 +1,22 @@ +""" +创建文档 +""" + +__author__ = "佐井" +__created__ = "2024-06-01 10:45:20" + +from typing import Any, Union + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.provider.builtin.aliyuque.tools.base import AliYuqueTool +from core.tools.tool.builtin_tool import BuiltinTool + + +class AliYuqueCreateDocumentTool(AliYuqueTool, BuiltinTool): + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + token = self.runtime.credentials.get("token", None) + if not token: + raise Exception("token is required") + return self.create_text_message(self.request("POST", token, tool_parameters, "/api/v2/repos/{book_id}/docs")) diff --git a/api/core/tools/provider/builtin/aliyuque/tools/create_document.yaml b/api/core/tools/provider/builtin/aliyuque/tools/create_document.yaml new file mode 
100644 index 0000000000..b9d1c60327 --- /dev/null +++ b/api/core/tools/provider/builtin/aliyuque/tools/create_document.yaml @@ -0,0 +1,99 @@ +identity: + name: aliyuque_create_document + author: 佐井 + label: + en_US: Create Document + zh_Hans: 创建文档 + icon: icon.svg +description: + human: + en_US: Creates a new document within a knowledge base without automatic addition to the table of contents. Requires a subsequent call to the "knowledge base directory update API". Supports setting visibility, format, and content. # 接口英文描述 + zh_Hans: 在知识库中创建新文档,但不会自动加入目录,需额外调用“知识库目录更新接口”。允许设置公开性、格式及正文内容。 + llm: Creates docs in a KB. + +parameters: + - name: book_id + type: number + required: true + form: llm + label: + en_US: Knowledge Base ID + zh_Hans: 知识库ID + human_description: + en_US: The unique identifier of the knowledge base where the document will be created. + zh_Hans: 文档将被创建的知识库的唯一标识。 + llm_description: ID of the target knowledge base. + + - name: title + type: string + required: false + form: llm + label: + en_US: Title + zh_Hans: 标题 + human_description: + en_US: The title of the document, defaults to 'Untitled' if not provided. + zh_Hans: 文档标题,默认为'无标题'如未提供。 + llm_description: Title of the document, defaults to 'Untitled'. + + - name: public + type: select + required: false + form: llm + options: + - value: 0 + label: + en_US: Private + zh_Hans: 私密 + - value: 1 + label: + en_US: Public + zh_Hans: 公开 + - value: 2 + label: + en_US: Enterprise-only + zh_Hans: 企业内公开 + label: + en_US: Visibility + zh_Hans: 公开性 + human_description: + en_US: Document visibility (0 Private, 1 Public, 2 Enterprise-only). + zh_Hans: 文档可见性(0 私密, 1 公开, 2 企业内公开)。 + llm_description: Doc visibility options, 0-private, 1-public, 2-enterprise. 
+ + - name: format + type: select + required: false + form: llm + options: + - value: markdown + label: + en_US: markdown + zh_Hans: markdown + - value: html + label: + en_US: html + zh_Hans: html + - value: lake + label: + en_US: lake + zh_Hans: lake + label: + en_US: Content Format + zh_Hans: 内容格式 + human_description: + en_US: Format of the document content (markdown, HTML, Lake). + zh_Hans: 文档内容格式(markdown, HTML, Lake)。 + llm_description: Content format choices, markdown, HTML, Lake. + + - name: body + type: string + required: true + form: llm + label: + en_US: Body Content + zh_Hans: 正文内容 + human_description: + en_US: The actual content of the document. + zh_Hans: 文档的实际内容。 + llm_description: Content of the document. diff --git a/api/core/tools/provider/builtin/aliyuque/tools/delete_document.py b/api/core/tools/provider/builtin/aliyuque/tools/delete_document.py new file mode 100644 index 0000000000..74c731a944 --- /dev/null +++ b/api/core/tools/provider/builtin/aliyuque/tools/delete_document.py @@ -0,0 +1,25 @@ +#!/usr/bin/env python3 +""" +删除文档 +""" + +__author__ = "佐井" +__created__ = "2024-09-17 22:04" + +from typing import Any, Union + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.provider.builtin.aliyuque.tools.base import AliYuqueTool +from core.tools.tool.builtin_tool import BuiltinTool + + +class AliYuqueDeleteDocumentTool(AliYuqueTool, BuiltinTool): + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + token = self.runtime.credentials.get("token", None) + if not token: + raise Exception("token is required") + return self.create_text_message( + self.request("DELETE", token, tool_parameters, "/api/v2/repos/{book_id}/docs/{id}") + ) diff --git a/api/core/tools/provider/builtin/aliyuque/tools/delete_document.yaml b/api/core/tools/provider/builtin/aliyuque/tools/delete_document.yaml new file mode 100644 index 0000000000..87372c5350 --- /dev/null +++ 
b/api/core/tools/provider/builtin/aliyuque/tools/delete_document.yaml @@ -0,0 +1,37 @@ +identity: + name: aliyuque_delete_document + author: 佐井 + label: + en_US: Delete Document + zh_Hans: 删除文档 + icon: icon.svg +description: + human: + en_US: Delete Document + zh_Hans: 根据id删除文档 + llm: Delete document. + +parameters: + - name: book_id + type: number + required: true + form: llm + label: + en_US: Knowledge Base ID + zh_Hans: 知识库ID + human_description: + en_US: The unique identifier of the knowledge base where the document will be created. + zh_Hans: 文档将被创建的知识库的唯一标识。 + llm_description: ID of the target knowledge base. + + - name: id + type: string + required: true + form: llm + label: + en_US: Document ID or Path + zh_Hans: 文档 ID or 路径 + human_description: + en_US: Document ID or path. + zh_Hans: 文档 ID or 路径。 + llm_description: Document ID or path. diff --git a/api/core/tools/provider/builtin/aliyuque/tools/describe_book_index_page.py b/api/core/tools/provider/builtin/aliyuque/tools/describe_book_index_page.py new file mode 100644 index 0000000000..02bf603a24 --- /dev/null +++ b/api/core/tools/provider/builtin/aliyuque/tools/describe_book_index_page.py @@ -0,0 +1,24 @@ +""" +获取知识库首页 +""" + +__author__ = "佐井" +__created__ = "2024-06-01 22:57:14" + +from typing import Any, Union + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.provider.builtin.aliyuque.tools.base import AliYuqueTool +from core.tools.tool.builtin_tool import BuiltinTool + + +class AliYuqueDescribeBookIndexPageTool(AliYuqueTool, BuiltinTool): + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + token = self.runtime.credentials.get("token", None) + if not token: + raise Exception("token is required") + return self.create_text_message( + self.request("GET", token, tool_parameters, "/api/v2/repos/{group_login}/{book_slug}/index_page") + ) diff --git 
a/api/core/tools/provider/builtin/aliyuque/tools/describe_book_index_page.yaml b/api/core/tools/provider/builtin/aliyuque/tools/describe_book_index_page.yaml new file mode 100644 index 0000000000..5e490725d1 --- /dev/null +++ b/api/core/tools/provider/builtin/aliyuque/tools/describe_book_index_page.yaml @@ -0,0 +1,38 @@ +identity: + name: aliyuque_describe_book_index_page + author: 佐井 + label: + en_US: Get Repo Index Page + zh_Hans: 获取知识库首页 + icon: icon.svg + +description: + human: + en_US: Retrieves the homepage of a knowledge base within a group, supporting both book ID and group login with book slug access. + zh_Hans: 获取团队中知识库的首页信息,可通过书籍ID或团队登录名与书籍路径访问。 + llm: Fetches the knowledge base homepage using group and book identifiers with support for alternate access paths. + +parameters: + - name: group_login + type: string + required: true + form: llm + label: + en_US: Group Login + zh_Hans: 团队登录名 + human_description: + en_US: The login name of the group that owns the knowledge base. + zh_Hans: 拥有该知识库的团队登录名。 + llm_description: Team login identifier for the knowledge base owner. + + - name: book_slug + type: string + required: true + form: llm + label: + en_US: Book Slug + zh_Hans: 知识库路径 + human_description: + en_US: The unique slug representing the path of the knowledge base. + zh_Hans: 知识库的唯一路径标识。 + llm_description: Unique path identifier for the knowledge base. 
diff --git a/api/core/tools/provider/builtin/aliyuque/tools/describe_book_table_of_contents.py b/api/core/tools/provider/builtin/aliyuque/tools/describe_book_table_of_contents.py new file mode 100644 index 0000000000..fcfe449c6d --- /dev/null +++ b/api/core/tools/provider/builtin/aliyuque/tools/describe_book_table_of_contents.py @@ -0,0 +1,23 @@ +#!/usr/bin/env python3 +""" +获取知识库目录 +""" + +__author__ = "佐井" +__created__ = "2024-09-17 15:17:11" + +from typing import Any, Union + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.provider.builtin.aliyuque.tools.base import AliYuqueTool +from core.tools.tool.builtin_tool import BuiltinTool + + +class YuqueDescribeBookTableOfContentsTool(AliYuqueTool, BuiltinTool): + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> (Union)[ToolInvokeMessage, list[ToolInvokeMessage]]: + token = self.runtime.credentials.get("token", None) + if not token: + raise Exception("token is required") + return self.create_text_message(self.request("GET", token, tool_parameters, "/api/v2/repos/{book_id}/toc")) diff --git a/api/core/tools/provider/builtin/aliyuque/tools/describe_book_table_of_contents.yaml b/api/core/tools/provider/builtin/aliyuque/tools/describe_book_table_of_contents.yaml new file mode 100644 index 0000000000..0c2bd22132 --- /dev/null +++ b/api/core/tools/provider/builtin/aliyuque/tools/describe_book_table_of_contents.yaml @@ -0,0 +1,25 @@ +identity: + name: aliyuque_describe_book_table_of_contents + author: 佐井 + label: + en_US: Get Book's Table of Contents + zh_Hans: 获取知识库的目录 + icon: icon.svg +description: + human: + en_US: Get Book's Table of Contents. + zh_Hans: 获取知识库的目录。 + llm: Get Book's Table of Contents. + +parameters: + - name: book_id + type: number + required: true + form: llm + label: + en_US: Book ID + zh_Hans: 知识库 ID + human_description: + en_US: Book ID. + zh_Hans: 知识库 ID。 + llm_description: Book ID. 
diff --git a/api/core/tools/provider/builtin/aliyuque/tools/describe_document_content.py b/api/core/tools/provider/builtin/aliyuque/tools/describe_document_content.py new file mode 100644 index 0000000000..1e70593879 --- /dev/null +++ b/api/core/tools/provider/builtin/aliyuque/tools/describe_document_content.py @@ -0,0 +1,61 @@ +""" +获取文档 +""" + +__author__ = "佐井" +__created__ = "2024-06-02 07:11:45" + +import json +from typing import Any, Union +from urllib.parse import urlparse + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.provider.builtin.aliyuque.tools.base import AliYuqueTool +from core.tools.tool.builtin_tool import BuiltinTool + + +class AliYuqueDescribeDocumentContentTool(AliYuqueTool, BuiltinTool): + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + new_params = {**tool_parameters} + token = new_params.pop("token") + if not token or token.lower() == "none": + token = self.runtime.credentials.get("token", None) + if not token: + raise Exception("token is required") + new_params = {**tool_parameters} + url = new_params.pop("url") + if not url or not url.startswith("http"): + raise Exception("url is not valid") + + parsed_url = urlparse(url) + path_parts = parsed_url.path.strip("/").split("/") + if len(path_parts) < 3: + raise Exception("url is not correct") + doc_id = path_parts[-1] + book_slug = path_parts[-2] + group_id = path_parts[-3] + + # 1. 请求首页信息,获取book_id + new_params["group_login"] = group_id + new_params["book_slug"] = book_slug + index_page = json.loads( + self.request("GET", token, new_params, "/api/v2/repos/{group_login}/{book_slug}/index_page") + ) + book_id = index_page.get("data", {}).get("book", {}).get("id") + if not book_id: + raise Exception(f"can not parse book_id from {index_page}") + # 2. 
获取文档内容 + new_params["book_id"] = book_id + new_params["id"] = doc_id + data = self.request("GET", token, new_params, "/api/v2/repos/{book_id}/docs/{id}") + data = json.loads(data) + body_only = tool_parameters.get("body_only") or "" + if body_only.lower() == "true": + return self.create_text_message(data.get("data").get("body")) + else: + raw = data.get("data") + del raw["body_lake"] + del raw["body_html"] + return self.create_text_message(json.dumps(data)) diff --git a/api/core/tools/provider/builtin/aliyuque/tools/describe_document_content.yaml b/api/core/tools/provider/builtin/aliyuque/tools/describe_document_content.yaml new file mode 100644 index 0000000000..6116886a96 --- /dev/null +++ b/api/core/tools/provider/builtin/aliyuque/tools/describe_document_content.yaml @@ -0,0 +1,50 @@ +identity: + name: aliyuque_describe_document_content + author: 佐井 + label: + en_US: Fetch Document Content + zh_Hans: 获取文档内容 + icon: icon.svg + +description: + human: + en_US: Retrieves document content from Yuque based on the provided document URL, which can be a normal or shared link. + zh_Hans: 根据提供的语雀文档地址(支持正常链接或分享链接)获取文档内容。 + llm: Fetches Yuque document content given a URL. + +parameters: + - name: url + type: string + required: true + form: llm + label: + en_US: Document URL + zh_Hans: 文档地址 + human_description: + en_US: The URL of the document to retrieve content from, can be normal or shared. + zh_Hans: 需要获取内容的文档地址,可以是正常链接或分享链接。 + llm_description: URL of the Yuque document to fetch content. + + - name: body_only + type: string + required: false + form: llm + label: + en_US: return body content only + zh_Hans: 仅返回body内容 + human_description: + en_US: true:Body content only, false:Full response with metadata. + zh_Hans: true:仅返回body内容,不返回其他元数据,false:返回所有元数据。 + llm_description: true:Body content only, false:Full response with metadata. 
+ + - name: token + type: secret-input + required: false + form: llm + label: + en_US: Yuque API Token + zh_Hans: 语雀接口Token + human_description: + en_US: The token for calling the Yuque API defaults to the Yuque token bound to the current tool if not provided. + zh_Hans: 调用语雀接口的token,如果不传则默认为当前工具绑定的语雀Token。 + llm_description: If the token for calling the Yuque API is not provided, it will default to the Yuque token bound to the current tool. diff --git a/api/core/tools/provider/builtin/aliyuque/tools/describe_documents.py b/api/core/tools/provider/builtin/aliyuque/tools/describe_documents.py new file mode 100644 index 0000000000..ed1b2a8643 --- /dev/null +++ b/api/core/tools/provider/builtin/aliyuque/tools/describe_documents.py @@ -0,0 +1,24 @@ +""" +获取文档 +""" + +__author__ = "佐井" +__created__ = "2024-06-01 10:45:20" + +from typing import Any, Union + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.provider.builtin.aliyuque.tools.base import AliYuqueTool +from core.tools.tool.builtin_tool import BuiltinTool + + +class AliYuqueDescribeDocumentsTool(AliYuqueTool, BuiltinTool): + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + token = self.runtime.credentials.get("token", None) + if not token: + raise Exception("token is required") + return self.create_text_message( + self.request("GET", token, tool_parameters, "/api/v2/repos/{book_id}/docs/{id}") + ) diff --git a/api/core/tools/provider/builtin/aliyuque/tools/describe_documents.yaml b/api/core/tools/provider/builtin/aliyuque/tools/describe_documents.yaml new file mode 100644 index 0000000000..5156345d71 --- /dev/null +++ b/api/core/tools/provider/builtin/aliyuque/tools/describe_documents.yaml @@ -0,0 +1,38 @@ +identity: + name: aliyuque_describe_documents + author: 佐井 + label: + en_US: Get Doc Detail + zh_Hans: 获取文档详情 + icon: icon.svg + +description: + human: + en_US: Retrieves detailed information of a 
specific document identified by its ID or path within a knowledge base. + zh_Hans: 根据知识库ID和文档ID或路径获取文档详细信息。 + llm: Fetches detailed doc info using ID/path from a knowledge base; supports doc lookup in Yuque. + +parameters: + - name: book_id + type: number + required: true + form: llm + label: + en_US: Knowledge Base ID + zh_Hans: 知识库 ID + human_description: + en_US: Identifier for the knowledge base where the document resides. + zh_Hans: 文档所属知识库的唯一标识。 + llm_description: ID of the knowledge base holding the document. + + - name: id + type: string + required: true + form: llm + label: + en_US: Document ID or Path + zh_Hans: 文档 ID 或路径 + human_description: + en_US: The unique identifier or path of the document to retrieve. + zh_Hans: 需要获取的文档的ID或其在知识库中的路径。 + llm_description: Unique doc ID or its path for retrieval. diff --git a/api/core/tools/provider/builtin/aliyuque/tools/update_book_table_of_contents.py b/api/core/tools/provider/builtin/aliyuque/tools/update_book_table_of_contents.py new file mode 100644 index 0000000000..932559445e --- /dev/null +++ b/api/core/tools/provider/builtin/aliyuque/tools/update_book_table_of_contents.py @@ -0,0 +1,29 @@ +#!/usr/bin/env python3 +""" +获取知识库目录 +""" + +__author__ = "佐井" +__created__ = "2024-09-17 15:17:11" + +from typing import Any, Union + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.provider.builtin.aliyuque.tools.base import AliYuqueTool +from core.tools.tool.builtin_tool import BuiltinTool + + +class YuqueDescribeBookTableOfContentsTool(AliYuqueTool, BuiltinTool): + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> (Union)[ToolInvokeMessage, list[ToolInvokeMessage]]: + token = self.runtime.credentials.get("token", None) + if not token: + raise Exception("token is required") + + doc_ids = tool_parameters.get("doc_ids") + if doc_ids: + doc_ids = [int(doc_id.strip()) for doc_id in doc_ids.split(",")] + tool_parameters["doc_ids"] = doc_ids + + return 
self.create_text_message(self.request("PUT", token, tool_parameters, "/api/v2/repos/{book_id}/toc")) diff --git a/api/core/tools/provider/builtin/aliyuque/tools/update_book_table_of_contents.yaml b/api/core/tools/provider/builtin/aliyuque/tools/update_book_table_of_contents.yaml new file mode 100644 index 0000000000..f0c0024f17 --- /dev/null +++ b/api/core/tools/provider/builtin/aliyuque/tools/update_book_table_of_contents.yaml @@ -0,0 +1,222 @@ +identity: + name: aliyuque_update_book_table_of_contents + author: 佐井 + label: + en_US: Update Book's Table of Contents + zh_Hans: 更新知识库目录 + icon: icon.svg +description: + human: + en_US: Update Book's Table of Contents. + zh_Hans: 更新知识库目录。 + llm: Update Book's Table of Contents. + +parameters: + - name: book_id + type: number + required: true + form: llm + label: + en_US: Book ID + zh_Hans: 知识库 ID + human_description: + en_US: Book ID. + zh_Hans: 知识库 ID。 + llm_description: Book ID. + + - name: action + type: select + required: true + form: llm + options: + - value: appendNode + label: + en_US: appendNode + zh_Hans: appendNode + pt_BR: appendNode + - value: prependNode + label: + en_US: prependNode + zh_Hans: prependNode + pt_BR: prependNode + - value: editNode + label: + en_US: editNode + zh_Hans: editNode + pt_BR: editNode + - value: editNode + label: + en_US: removeNode + zh_Hans: removeNode + pt_BR: removeNode + label: + en_US: Action Type + zh_Hans: 操作 + human_description: + en_US: In the operation scenario, sibling node prepending is not supported, deleting a node doesn't remove associated documents, and node deletion has two modes, 'sibling' (delete current node) and 'child' (delete current node and its children). 
+ zh_Hans: 操作,创建场景下不支持同级头插 prependNode,删除节点不会删除关联文档,删除节点时action_mode=sibling (删除当前节点), action_mode=child (删除当前节点及子节点) + llm_description: In the operation scenario, sibling node prepending is not supported, deleting a node doesn't remove associated documents, and node deletion has two modes, 'sibling' (delete current node) and 'child' (delete current node and its children). + + + - name: action_mode + type: select + required: false + form: llm + options: + - value: sibling + label: + en_US: sibling + zh_Hans: 同级 + pt_BR: sibling + - value: child + label: + en_US: child + zh_Hans: 子集 + pt_BR: child + label: + en_US: Action Type + zh_Hans: 操作 + human_description: + en_US: Operation mode (sibling:same level, child:child level). + zh_Hans: 操作模式 (sibling:同级, child:子级)。 + llm_description: Operation mode (sibling:same level, child:child level). + + - name: target_uuid + type: string + required: false + form: llm + label: + en_US: Target node UUID + zh_Hans: 目标节点 UUID + human_description: + en_US: Target node UUID, defaults to root node if left empty. + zh_Hans: 目标节点 UUID, 不填默认为根节点。 + llm_description: Target node UUID, defaults to root node if left empty. + + - name: node_uuid + type: string + required: false + form: llm + label: + en_US: Node UUID + zh_Hans: 操作节点 UUID + human_description: + en_US: Operation node UUID [required for move/update/delete]. + zh_Hans: 操作节点 UUID [移动/更新/删除必填]。 + llm_description: Operation node UUID [required for move/update/delete]. + + - name: doc_ids + type: string + required: false + form: llm + label: + en_US: Document IDs + zh_Hans: 文档id列表 + human_description: + en_US: Document IDs [required for creating documents], separate multiple IDs with ','. + zh_Hans: 文档 IDs [创建文档必填],多个用','分隔。 + llm_description: Document IDs [required for creating documents], separate multiple IDs with ','. 
+ + + - name: type + type: select + required: false + form: llm + default: DOC + options: + - value: DOC + label: + en_US: DOC + zh_Hans: 文档 + pt_BR: DOC + - value: LINK + label: + en_US: LINK + zh_Hans: 链接 + pt_BR: LINK + - value: TITLE + label: + en_US: TITLE + zh_Hans: 分组 + pt_BR: TITLE + label: + en_US: Node type + zh_Hans: 操节点类型 + human_description: + en_US: Node type [required for creation] (DOC:document, LINK:external link, TITLE:group). + zh_Hans: 操节点类型 [创建必填] (DOC:文档, LINK:外链, TITLE:分组)。 + llm_description: Node type [required for creation] (DOC:document, LINK:external link, TITLE:group). + + - name: title + type: string + required: false + form: llm + label: + en_US: Node Name + zh_Hans: 节点名称 + human_description: + en_US: Node name [required for creating groups/external links]. + zh_Hans: 节点名称 [创建分组/外链必填]。 + llm_description: Node name [required for creating groups/external links]. + + - name: url + type: string + required: false + form: llm + label: + en_US: Node URL + zh_Hans: 节点URL + human_description: + en_US: Node URL [required for creating external links]. + zh_Hans: 节点 URL [创建外链必填]。 + llm_description: Node URL [required for creating external links]. + + + - name: open_window + type: select + required: false + form: llm + default: 0 + options: + - value: 0 + label: + en_US: DOC + zh_Hans: Current Page + pt_BR: DOC + - value: 1 + label: + en_US: LINK + zh_Hans: New Page + pt_BR: LINK + label: + en_US: Open in new window + zh_Hans: 是否新窗口打开 + human_description: + en_US: Open in new window [optional for external links] (0:open in current page, 1:open in new window). + zh_Hans: 是否新窗口打开 [外链选填] (0:当前页打开, 1:新窗口打开)。 + llm_description: Open in new window [optional for external links] (0:open in current page, 1:open in new window). 
+ + + - name: visible + type: select + required: false + form: llm + default: 1 + options: + - value: 0 + label: + en_US: Invisible + zh_Hans: 隐藏 + pt_BR: Invisible + - value: 1 + label: + en_US: Visible + zh_Hans: 可见 + pt_BR: Visible + label: + en_US: Visibility + zh_Hans: 是否可见 + human_description: + en_US: Visibility (0:invisible, 1:visible). + zh_Hans: 是否可见 (0:不可见, 1:可见)。 + llm_description: Visibility (0:invisible, 1:visible). diff --git a/api/core/tools/provider/builtin/aliyuque/tools/update_document.py b/api/core/tools/provider/builtin/aliyuque/tools/update_document.py new file mode 100644 index 0000000000..0c6e0205e1 --- /dev/null +++ b/api/core/tools/provider/builtin/aliyuque/tools/update_document.py @@ -0,0 +1,24 @@ +""" +更新文档 +""" + +__author__ = "佐井" +__created__ = "2024-06-19 16:50:07" + +from typing import Any, Union + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.provider.builtin.aliyuque.tools.base import AliYuqueTool +from core.tools.tool.builtin_tool import BuiltinTool + + +class AliYuqueUpdateDocumentTool(AliYuqueTool, BuiltinTool): + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + token = self.runtime.credentials.get("token", None) + if not token: + raise Exception("token is required") + return self.create_text_message( + self.request("PUT", token, tool_parameters, "/api/v2/repos/{book_id}/docs/{id}") + ) diff --git a/api/core/tools/provider/builtin/aliyuque/tools/update_document.yaml b/api/core/tools/provider/builtin/aliyuque/tools/update_document.yaml new file mode 100644 index 0000000000..87f88c9b1b --- /dev/null +++ b/api/core/tools/provider/builtin/aliyuque/tools/update_document.yaml @@ -0,0 +1,87 @@ +identity: + name: aliyuque_update_document + author: 佐井 + label: + en_US: Update Document + zh_Hans: 更新文档 + icon: icon.svg +description: + human: + en_US: Update an existing document within a specified knowledge base by providing 
the document ID or path. + zh_Hans: 通过提供文档ID或路径,更新指定知识库中的现有文档。 + llm: Update doc in a knowledge base via ID/path. +parameters: + - name: book_id + type: number + required: true + form: llm + label: + en_US: Knowledge Base ID + zh_Hans: 知识库 ID + human_description: + en_US: The unique identifier of the knowledge base where the document resides. + zh_Hans: 文档所属知识库的ID。 + llm_description: ID of the knowledge base holding the doc. + - name: id + type: string + required: true + form: llm + label: + en_US: Document ID or Path + zh_Hans: 文档 ID 或 路径 + human_description: + en_US: The unique identifier or the path of the document to be updated. + zh_Hans: 要更新的文档的唯一ID或路径。 + llm_description: Doc's ID or path for update. + + - name: title + type: string + required: false + form: llm + label: + en_US: Title + zh_Hans: 标题 + human_description: + en_US: The title of the document, defaults to 'Untitled' if not provided. + zh_Hans: 文档标题,默认为'无标题'如未提供。 + llm_description: Title of the document, defaults to 'Untitled'. + + - name: format + type: select + required: false + form: llm + options: + - value: markdown + label: + en_US: markdown + zh_Hans: markdown + pt_BR: markdown + - value: html + label: + en_US: html + zh_Hans: html + pt_BR: html + - value: lake + label: + en_US: lake + zh_Hans: lake + pt_BR: lake + label: + en_US: Content Format + zh_Hans: 内容格式 + human_description: + en_US: Format of the document content (markdown, HTML, Lake). + zh_Hans: 文档内容格式(markdown, HTML, Lake)。 + llm_description: Content format choices, markdown, HTML, Lake. + + - name: body + type: string + required: true + form: llm + label: + en_US: Body Content + zh_Hans: 正文内容 + human_description: + en_US: The actual content of the document. + zh_Hans: 文档的实际内容。 + llm_description: Content of the document. 
diff --git a/api/core/tools/provider/builtin/comfyui/comfyui.py b/api/core/tools/provider/builtin/comfyui/comfyui.py index 7013a0b93c..bab690af82 100644 --- a/api/core/tools/provider/builtin/comfyui/comfyui.py +++ b/api/core/tools/provider/builtin/comfyui/comfyui.py @@ -1,17 +1,21 @@ from typing import Any +import websocket +from yarl import URL + from core.tools.errors import ToolProviderCredentialValidationError -from core.tools.provider.builtin.comfyui.tools.comfyui_stable_diffusion import ComfyuiStableDiffusionTool from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController class ComfyUIProvider(BuiltinToolProviderController): def _validate_credentials(self, credentials: dict[str, Any]) -> None: + ws = websocket.WebSocket() + base_url = URL(credentials.get("base_url")) + ws_address = f"ws://{base_url.authority}/ws?clientId=test123" + try: - ComfyuiStableDiffusionTool().fork_tool_runtime( - runtime={ - "credentials": credentials, - } - ).validate_models() + ws.connect(ws_address) except Exception as e: - raise ToolProviderCredentialValidationError(str(e)) + raise ToolProviderCredentialValidationError(f"can not connect to {ws_address}") + finally: + ws.close() diff --git a/api/core/tools/provider/builtin/comfyui/comfyui.yaml b/api/core/tools/provider/builtin/comfyui/comfyui.yaml index 3891eebf3a..24ae43cd44 100644 --- a/api/core/tools/provider/builtin/comfyui/comfyui.yaml +++ b/api/core/tools/provider/builtin/comfyui/comfyui.yaml @@ -4,11 +4,9 @@ identity: label: en_US: ComfyUI zh_Hans: ComfyUI - pt_BR: ComfyUI description: en_US: ComfyUI is a tool for generating images which can be deployed locally. zh_Hans: ComfyUI 是一个可以在本地部署的图片生成的工具。 - pt_BR: ComfyUI is a tool for generating images which can be deployed locally. 
icon: icon.png tags: - image @@ -17,26 +15,9 @@ credentials_for_provider: type: text-input required: true label: - en_US: Base URL - zh_Hans: ComfyUI服务器的Base URL - pt_BR: Base URL + en_US: The URL of ComfyUI Server + zh_Hans: ComfyUI服务器的URL placeholder: en_US: Please input your ComfyUI server's Base URL zh_Hans: 请输入你的 ComfyUI 服务器的 Base URL - pt_BR: Please input your ComfyUI server's Base URL - model: - type: text-input - required: true - label: - en_US: Model with suffix - zh_Hans: 模型, 需要带后缀 - pt_BR: Model with suffix - placeholder: - en_US: Please input your model - zh_Hans: 请输入你的模型名称 - pt_BR: Please input your model - help: - en_US: The checkpoint name of the ComfyUI server, e.g. xxx.safetensors - zh_Hans: ComfyUI服务器的模型名称, 比如 xxx.safetensors - pt_BR: The checkpoint name of the ComfyUI server, e.g. xxx.safetensors - url: https://github.com/comfyanonymous/ComfyUI#installing + url: https://docs.dify.ai/guides/tools/tool-configuration/comfyui diff --git a/api/core/tools/provider/builtin/comfyui/tools/comfyui_client.py b/api/core/tools/provider/builtin/comfyui/tools/comfyui_client.py new file mode 100644 index 0000000000..a41d34d40f --- /dev/null +++ b/api/core/tools/provider/builtin/comfyui/tools/comfyui_client.py @@ -0,0 +1,105 @@ +import json +import random +import uuid + +import httpx +from websocket import WebSocket +from yarl import URL + + +class ComfyUiClient: + def __init__(self, base_url: str): + self.base_url = URL(base_url) + + def get_history(self, prompt_id: str): + res = httpx.get(str(self.base_url / "history"), params={"prompt_id": prompt_id}) + history = res.json()[prompt_id] + return history + + def get_image(self, filename: str, subfolder: str, folder_type: str): + response = httpx.get( + str(self.base_url / "view"), + params={"filename": filename, "subfolder": subfolder, "type": folder_type}, + ) + return response.content + + def upload_image(self, input_path: str, name: str, image_type: str = "input", overwrite: bool = False): + # plan to support 
img2img in dify 0.10.0 + with open(input_path, "rb") as file: + files = {"image": (name, file, "image/png")} + data = {"type": image_type, "overwrite": str(overwrite).lower()} + + res = httpx.post(str(self.base_url / "upload/image"), data=data, files=files) + return res + + def queue_prompt(self, client_id: str, prompt: dict): + res = httpx.post(str(self.base_url / "prompt"), json={"client_id": client_id, "prompt": prompt}) + prompt_id = res.json()["prompt_id"] + return prompt_id + + def open_websocket_connection(self): + client_id = str(uuid.uuid4()) + ws = WebSocket() + ws_address = f"ws://{self.base_url.authority}/ws?clientId={client_id}" + ws.connect(ws_address) + return ws, client_id + + def set_prompt(self, origin_prompt: dict, positive_prompt: str, negative_prompt: str = ""): + """ + find the first KSampler, then can find the prompt node through it. + """ + prompt = json.loads(json.dumps(origin_prompt)) + id_to_class_type = {id: details["class_type"] for id, details in prompt.items()} + k_sampler = [key for key, value in id_to_class_type.items() if value == "KSampler"][0] + prompt.get(k_sampler)["inputs"]["seed"] = random.randint(10**14, 10**15 - 1) + positive_input_id = prompt.get(k_sampler)["inputs"]["positive"][0] + prompt.get(positive_input_id)["inputs"]["text"] = positive_prompt + + if negative_prompt != "": + negative_input_id = prompt.get(k_sampler)["inputs"]["negative"][0] + prompt.get(negative_input_id)["inputs"]["text"] = negative_prompt + return prompt + + def track_progress(self, prompt: dict, ws: WebSocket, prompt_id: str): + node_ids = list(prompt.keys()) + finished_nodes = [] + + while True: + out = ws.recv() + if isinstance(out, str): + message = json.loads(out) + if message["type"] == "progress": + data = message["data"] + current_step = data["value"] + print("In K-Sampler -> Step: ", current_step, " of: ", data["max"]) + if message["type"] == "execution_cached": + data = message["data"] + for itm in data["nodes"]: + if itm not in finished_nodes: + 
finished_nodes.append(itm) + print("Progress: ", len(finished_nodes), "/", len(node_ids), " Tasks done") + if message["type"] == "executing": + data = message["data"] + if data["node"] not in finished_nodes: + finished_nodes.append(data["node"]) + print("Progress: ", len(finished_nodes), "/", len(node_ids), " Tasks done") + + if data["node"] is None and data["prompt_id"] == prompt_id: + break # Execution is done + else: + continue + + def generate_image_by_prompt(self, prompt: dict): + ws, client_id = self.open_websocket_connection() + try: + prompt_id = self.queue_prompt(client_id, prompt) + self.track_progress(prompt, ws, prompt_id) + history = self.get_history(prompt_id) + images = [] + for output in history["outputs"].values(): + for img in output.get("images", []): + image_data = self.get_image(img["filename"], img["subfolder"], img["type"]) + images.append(image_data) + return images + finally: + ws.close() diff --git a/api/core/tools/provider/builtin/comfyui/tools/comfyui_stable_diffusion.yaml b/api/core/tools/provider/builtin/comfyui/tools/comfyui_stable_diffusion.yaml index 4f4a6942b3..75fe746965 100644 --- a/api/core/tools/provider/builtin/comfyui/tools/comfyui_stable_diffusion.yaml +++ b/api/core/tools/provider/builtin/comfyui/tools/comfyui_stable_diffusion.yaml @@ -1,10 +1,10 @@ identity: - name: txt2img workflow + name: txt2img author: Qun label: - en_US: Txt2Img Workflow - zh_Hans: Txt2Img Workflow - pt_BR: Txt2Img Workflow + en_US: Txt2Img + zh_Hans: Txt2Img + pt_BR: Txt2Img description: human: en_US: a pre-defined comfyui workflow that can use one model and up to 3 loras to generate images. Support SD1.5, SDXL, SD3 and FLUX which contain text encoders/clip, but does not support models that requires a triple clip loader. 
diff --git a/api/core/tools/provider/builtin/comfyui/tools/comfyui_workflow.py b/api/core/tools/provider/builtin/comfyui/tools/comfyui_workflow.py new file mode 100644 index 0000000000..e4df9f8c3b --- /dev/null +++ b/api/core/tools/provider/builtin/comfyui/tools/comfyui_workflow.py @@ -0,0 +1,32 @@ +import json +from typing import Any + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.tool.builtin_tool import BuiltinTool + +from .comfyui_client import ComfyUiClient + + +class ComfyUIWorkflowTool(BuiltinTool): + def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]: + comfyui = ComfyUiClient(self.runtime.credentials["base_url"]) + + positive_prompt = tool_parameters.get("positive_prompt") + negative_prompt = tool_parameters.get("negative_prompt") + workflow = tool_parameters.get("workflow_json") + + try: + origin_prompt = json.loads(workflow) + except (TypeError, json.JSONDecodeError): + return self.create_text_message("the Workflow JSON is not correct") + + prompt = comfyui.set_prompt(origin_prompt, positive_prompt, negative_prompt) + images = comfyui.generate_image_by_prompt(prompt) + result = [] + for img in images: + result.append( + self.create_blob_message( + blob=img, meta={"mime_type": "image/png"}, save_as=self.VariableKey.IMAGE.value + ) + ) + return result diff --git a/api/core/tools/provider/builtin/comfyui/tools/comfyui_workflow.yaml b/api/core/tools/provider/builtin/comfyui/tools/comfyui_workflow.yaml new file mode 100644 index 0000000000..6342d6d468 --- /dev/null +++ b/api/core/tools/provider/builtin/comfyui/tools/comfyui_workflow.yaml @@ -0,0 +1,35 @@ +identity: + name: workflow + author: hjlarry + label: + en_US: workflow + zh_Hans: 工作流 +description: + human: + en_US: Run ComfyUI workflow. + zh_Hans: 运行ComfyUI工作流。 + llm: Run ComfyUI workflow. 
+parameters: + - name: positive_prompt + type: string + label: + en_US: Prompt + zh_Hans: 提示词 + llm_description: Image prompt, you should describe the image you want to generate as a list of words as possible as detailed, the prompt must be written in English. + form: llm + - name: negative_prompt + type: string + label: + en_US: Negative Prompt + zh_Hans: 负面提示词 + llm_description: Negative prompt, you should describe the image you don't want to generate as a list of words as possible as detailed, the prompt must be written in English. + form: llm + - name: workflow_json + type: string + required: true + label: + en_US: Workflow JSON + human_description: + en_US: exported from ComfyUI workflow + zh_Hans: 从ComfyUI的工作流中导出 + form: form diff --git a/api/tests/integration_tests/controllers/app_fixture.py b/api/tests/integration_tests/controllers/app_fixture.py new file mode 100644 index 0000000000..93065ee95c --- /dev/null +++ b/api/tests/integration_tests/controllers/app_fixture.py @@ -0,0 +1,24 @@ +import pytest + +from app_factory import create_app + +mock_user = type( + "MockUser", + (object,), + { + "is_authenticated": True, + "id": "123", + "is_editor": True, + "is_dataset_editor": True, + "status": "active", + "get_id": "123", + "current_tenant_id": "9d2074fc-6f86-45a9-b09d-6ecc63b9056b", + }, +) + + +@pytest.fixture +def app(): + app = create_app() + app.config["LOGIN_DISABLED"] = True + return app diff --git a/api/tests/integration_tests/controllers/test_controllers.py b/api/tests/integration_tests/controllers/test_controllers.py new file mode 100644 index 0000000000..6371694694 --- /dev/null +++ b/api/tests/integration_tests/controllers/test_controllers.py @@ -0,0 +1,10 @@ +from unittest.mock import patch + +from app_fixture import app, mock_user + + +def test_post_requires_login(app): + with app.test_client() as client: + with patch("flask_login.utils._get_user", mock_user): + response = client.get("/console/api/data-source/integrates") + assert 
response.status_code == 200 diff --git a/api/tests/integration_tests/model_runtime/wenxin/test_rerank.py b/api/tests/integration_tests/model_runtime/wenxin/test_rerank.py new file mode 100644 index 0000000000..33c803e8e1 --- /dev/null +++ b/api/tests/integration_tests/model_runtime/wenxin/test_rerank.py @@ -0,0 +1,21 @@ +import os +from time import sleep + +from core.model_runtime.entities.rerank_entities import RerankResult +from core.model_runtime.model_providers.wenxin.rerank.rerank import WenxinRerankModel + + +def test_invoke_bce_reranker_base_v1(): + sleep(3) + model = WenxinRerankModel() + + response = model.invoke( + model="bce-reranker-base_v1", + credentials={"api_key": os.environ.get("WENXIN_API_KEY"), "secret_key": os.environ.get("WENXIN_SECRET_KEY")}, + query="What is Deep Learning?", + docs=["Deep Learning is ...", "My Book is ..."], + user="abc-123", + ) + + assert isinstance(response, RerankResult) + assert len(response.docs) == 2 diff --git a/dev/pytest/pytest_vdb.sh b/dev/pytest/pytest_vdb.sh index ea1e7d4c80..d6797ed28e 100755 --- a/dev/pytest/pytest_vdb.sh +++ b/dev/pytest/pytest_vdb.sh @@ -9,4 +9,5 @@ pytest api/tests/integration_tests/vdb/chroma \ api/tests/integration_tests/vdb/weaviate \ api/tests/integration_tests/vdb/elasticsearch \ api/tests/integration_tests/vdb/vikingdb \ - api/tests/integration_tests/vdb/baidu + api/tests/integration_tests/vdb/baidu \ + api/tests/integration_tests/vdb/tcvectordb diff --git a/web/app/(commonLayout)/apps/Apps.tsx b/web/app/(commonLayout)/apps/Apps.tsx index c8d6750fc8..9d6345aa6c 100644 --- a/web/app/(commonLayout)/apps/Apps.tsx +++ b/web/app/(commonLayout)/apps/Apps.tsx @@ -87,15 +87,15 @@ const Apps = () => { localStorage.removeItem(NEED_REFRESH_APP_LIST_KEY) mutate() } - }, []) + }, [mutate, t]) useEffect(() => { if (isCurrentWorkspaceDatasetOperator) return router.replace('/datasets') - }, [isCurrentWorkspaceDatasetOperator]) + }, [router, isCurrentWorkspaceDatasetOperator]) - const hasMore = 
data?.at(-1)?.has_more ?? true useEffect(() => { + const hasMore = data?.at(-1)?.has_more ?? true let observer: IntersectionObserver | undefined if (anchorRef.current) { observer = new IntersectionObserver((entries) => { @@ -105,7 +105,7 @@ const Apps = () => { observer.observe(anchorRef.current) } return () => observer?.disconnect() - }, [isLoading, setSize, anchorRef, mutate, hasMore]) + }, [isLoading, setSize, anchorRef, mutate, data]) const { run: handleSearch } = useDebounceFn(() => { setSearchKeywords(keywords) diff --git a/web/app/components/base/features/feature-panel/index.tsx b/web/app/components/base/features/feature-panel/index.tsx new file mode 100644 index 0000000000..72799ef2fc --- /dev/null +++ b/web/app/components/base/features/feature-panel/index.tsx @@ -0,0 +1,119 @@ +import { + memo, + useMemo, +} from 'react' +import { useTranslation } from 'react-i18next' +import type { OnFeaturesChange } from '../types' +import { useFeatures } from '../hooks' +import FileUpload from './file-upload' +import OpeningStatement from './opening-statement' +import type { OpeningStatementProps } from './opening-statement' +import SuggestedQuestionsAfterAnswer from './suggested-questions-after-answer' +import TextToSpeech from './text-to-speech' +import SpeechToText from './speech-to-text' +import Citation from './citation' +import Moderation from './moderation' +import type { InputVar } from '@/app/components/workflow/types' + +export type FeaturePanelProps = { + onChange?: OnFeaturesChange + openingStatementProps: OpeningStatementProps + disabled?: boolean + workflowVariables: InputVar[] +} +const FeaturePanel = ({ + onChange, + openingStatementProps, + disabled, + workflowVariables, +}: FeaturePanelProps) => { + const { t } = useTranslation() + const features = useFeatures(s => s.features) + + const showAdvanceFeature = useMemo(() => { + return features.opening?.enabled || features.suggested?.enabled || features.speech2text?.enabled || features.text2speech?.enabled 
|| features.citation?.enabled + }, [features]) + + const showToolFeature = useMemo(() => { + return features.moderation?.enabled + }, [features]) + + return ( +
+ + { + showAdvanceFeature && ( +
+
+
+ {t('appDebug.feature.groupChat.title')} +
+
+
+
+ { + features.opening?.enabled && ( + + ) + } + { + features.suggested?.enabled && ( + + ) + } + { + features.text2speech?.enabled && ( + + ) + } + { + features.speech2text?.enabled && ( + + ) + } + { + features.citation?.enabled && ( + + ) + } +
+
+ ) + } + { + showToolFeature && ( +
+
+
+ {t('appDebug.feature.groupChat.title')} +
+
+
+
+ { + features.moderation?.enabled && ( + + ) + } +
+
+ ) + } +
+ ) +} +export default memo(FeaturePanel) diff --git a/web/app/components/base/features/feature-panel/opening-statement/index.tsx b/web/app/components/base/features/feature-panel/opening-statement/index.tsx new file mode 100644 index 0000000000..1f102700ad --- /dev/null +++ b/web/app/components/base/features/feature-panel/opening-statement/index.tsx @@ -0,0 +1,328 @@ +/* eslint-disable multiline-ternary */ +'use client' +import type { FC } from 'react' +import React, { useEffect, useRef, useState } from 'react' +import produce from 'immer' +import { + RiAddLine, + RiDeleteBinLine, +} from '@remixicon/react' +import { useTranslation } from 'react-i18next' +import { useBoolean } from 'ahooks' +import { ReactSortable } from 'react-sortablejs' +import { + useFeatures, + useFeaturesStore, +} from '../../hooks' +import type { OnFeaturesChange } from '../../types' +import cn from '@/utils/classnames' +import Panel from '@/app/components/app/configuration/base/feature-panel' +import Button from '@/app/components/base/button' +import OperationBtn from '@/app/components/app/configuration/base/operation-btn' +import { getInputKeys } from '@/app/components/base/block-input' +import ConfirmAddVar from '@/app/components/app/configuration/config-prompt/confirm-add-var' +import { getNewVar } from '@/utils/var' +import { varHighlightHTML } from '@/app/components/app/configuration/base/var-highlight' +import type { PromptVariable } from '@/models/debug' +import type { InputVar } from '@/app/components/workflow/types' + +const MAX_QUESTION_NUM = 5 + +export type OpeningStatementProps = { + onChange?: OnFeaturesChange + readonly?: boolean + promptVariables?: PromptVariable[] + onAutoAddPromptVariable: (variable: PromptVariable[]) => void + workflowVariables?: InputVar[] +} + +// regex to match the {{}} and replace it with a span +const regex = /\{\{([^}]+)\}\}/g + +const OpeningStatement: FC = ({ + onChange, + readonly, + promptVariables = [], + onAutoAddPromptVariable, + 
workflowVariables = [], +}) => { + const { t } = useTranslation() + const featureStore = useFeaturesStore() + const openingStatement = useFeatures(s => s.features.opening) + const value = openingStatement?.opening_statement || '' + const suggestedQuestions = openingStatement?.suggested_questions || [] + const [notIncludeKeys, setNotIncludeKeys] = useState([]) + + const hasValue = !!(value || '').trim() + const inputRef = useRef(null) + + const [isFocus, { setTrue: didSetFocus, setFalse: setBlur }] = useBoolean(false) + + const setFocus = () => { + didSetFocus() + setTimeout(() => { + const input = inputRef.current + if (input) { + input.focus() + input.setSelectionRange(input.value.length, input.value.length) + } + }, 0) + } + + const [tempValue, setTempValue] = useState(value) + useEffect(() => { + setTempValue(value || '') + }, [value]) + + const [tempSuggestedQuestions, setTempSuggestedQuestions] = useState(suggestedQuestions || []) + const notEmptyQuestions = tempSuggestedQuestions.filter(question => !!question && question.trim()) + const coloredContent = (tempValue || '') + .replace(//g, '>') + .replace(regex, varHighlightHTML({ name: '$1' })) // `{{$1}}` + .replace(/\n/g, '
') + + const handleEdit = () => { + if (readonly) + return + setFocus() + } + + const [isShowConfirmAddVar, { setTrue: showConfirmAddVar, setFalse: hideConfirmAddVar }] = useBoolean(false) + + const handleCancel = () => { + setBlur() + setTempValue(value) + setTempSuggestedQuestions(suggestedQuestions) + } + + const handleConfirm = () => { + const keys = getInputKeys(tempValue) + const promptKeys = promptVariables.map(item => item.key) + const workflowVariableKeys = workflowVariables.map(item => item.variable) + let notIncludeKeys: string[] = [] + + if (promptKeys.length === 0 && workflowVariables.length === 0) { + if (keys.length > 0) + notIncludeKeys = keys + } + else { + if (workflowVariables.length > 0) + notIncludeKeys = keys.filter(key => !workflowVariableKeys.includes(key)) + + else notIncludeKeys = keys.filter(key => !promptKeys.includes(key)) + } + + if (notIncludeKeys.length > 0) { + setNotIncludeKeys(notIncludeKeys) + showConfirmAddVar() + return + } + setBlur() + const { getState } = featureStore! + const { + features, + setFeatures, + } = getState() + + const newFeatures = produce(features, (draft) => { + if (draft.opening) { + draft.opening.opening_statement = tempValue + draft.opening.suggested_questions = tempSuggestedQuestions + } + }) + setFeatures(newFeatures) + + if (onChange) + onChange(newFeatures) + } + + const cancelAutoAddVar = () => { + const { getState } = featureStore! + const { + features, + setFeatures, + } = getState() + + const newFeatures = produce(features, (draft) => { + if (draft.opening) + draft.opening.opening_statement = tempValue + }) + setFeatures(newFeatures) + + if (onChange) + onChange(newFeatures) + hideConfirmAddVar() + setBlur() + } + + const autoAddVar = () => { + const { getState } = featureStore! 
+ const { + features, + setFeatures, + } = getState() + + const newFeatures = produce(features, (draft) => { + if (draft.opening) + draft.opening.opening_statement = tempValue + }) + setFeatures(newFeatures) + if (onChange) + onChange(newFeatures) + onAutoAddPromptVariable([...notIncludeKeys.map(key => getNewVar(key, 'string'))]) + hideConfirmAddVar() + setBlur() + } + + const headerRight = !readonly ? ( + isFocus ? ( +
+ + +
+ ) : ( + + ) + ) : null + + const renderQuestions = () => { + return isFocus ? ( +
+
+
+
{t('appDebug.openingStatement.openingQuestion')}
+
·
+
{tempSuggestedQuestions.length}/{MAX_QUESTION_NUM}
+
+
+
+ { + return { + id: index, + name, + } + })} + setList={list => setTempSuggestedQuestions(list.map(item => item.name))} + handle='.handle' + ghostClass="opacity-50" + animation={150} + > + {tempSuggestedQuestions.map((question, index) => { + return ( +
+
+ + + +
+ { + const value = e.target.value + setTempSuggestedQuestions(tempSuggestedQuestions.map((item, i) => { + if (index === i) + return value + + return item + })) + }} + className={'w-full overflow-x-auto pl-1.5 pr-8 text-sm leading-9 text-gray-900 border-0 grow h-9 bg-transparent focus:outline-none cursor-pointer rounded-lg'} + /> + +
{ + setTempSuggestedQuestions(tempSuggestedQuestions.filter((_, i) => index !== i)) + }} + > + +
+
+ ) + })}
+ {tempSuggestedQuestions.length < MAX_QUESTION_NUM && ( +
{ setTempSuggestedQuestions([...tempSuggestedQuestions, '']) }} + className='mt-1 flex items-center h-9 px-3 gap-2 rounded-lg cursor-pointer text-gray-400 bg-gray-100 hover:bg-gray-200'> + +
{t('appDebug.variableConfig.addOption')}
+
+ )} +
+ ) : ( +
+ {notEmptyQuestions.map((question, index) => { + return ( +
+ {question} +
+ ) + })} +
+ ) + } + + return ( + + + + } + headerRight={headerRight} + hasHeaderBottomBorder={!hasValue} + isFocus={isFocus} + > +
+ {(hasValue || (!hasValue && isFocus)) ? ( + <> + {isFocus + ? ( +
+ +
+ ) + : ( +
+ )} + {renderQuestions()} + ) : ( +
{t('appDebug.openingStatement.noDataPlaceHolder')}
+ )} + + {isShowConfirmAddVar && ( + + )} + +
+
+ ) +} +export default React.memo(OpeningStatement) diff --git a/web/app/components/base/features/new-feature-panel/conversation-opener/index.tsx b/web/app/components/base/features/new-feature-panel/conversation-opener/index.tsx index 1e7c141059..ab6b3ec6db 100644 --- a/web/app/components/base/features/new-feature-panel/conversation-opener/index.tsx +++ b/web/app/components/base/features/new-feature-panel/conversation-opener/index.tsx @@ -10,11 +10,13 @@ import type { OnFeaturesChange } from '@/app/components/base/features/types' import { FeatureEnum } from '@/app/components/base/features/types' import { useModalContext } from '@/context/modal-context' import type { PromptVariable } from '@/models/debug' +import type { InputVar } from '@/app/components/workflow/types' type Props = { disabled?: boolean onChange?: OnFeaturesChange promptVariables?: PromptVariable[] + workflowVariables?: InputVar[] onAutoAddPromptVariable?: (variable: PromptVariable[]) => void } @@ -22,6 +24,7 @@ const ConversationOpener = ({ disabled, onChange, promptVariables, + workflowVariables, onAutoAddPromptVariable, }: Props) => { const { t } = useTranslation() @@ -40,6 +43,7 @@ const ConversationOpener = ({ payload: { ...opening, promptVariables, + workflowVariables, onAutoAddPromptVariable, }, onSaveCallback: (newOpening) => { diff --git a/web/app/components/base/features/new-feature-panel/conversation-opener/modal.tsx b/web/app/components/base/features/new-feature-panel/conversation-opener/modal.tsx index f8636d7d2a..9f25d0fa11 100644 --- a/web/app/components/base/features/new-feature-panel/conversation-opener/modal.tsx +++ b/web/app/components/base/features/new-feature-panel/conversation-opener/modal.tsx @@ -10,6 +10,7 @@ import ConfirmAddVar from '@/app/components/app/configuration/config-prompt/conf import type { OpeningStatement } from '@/app/components/base/features/types' import { getInputKeys } from '@/app/components/base/block-input' import type { PromptVariable } from 
'@/models/debug' +import type { InputVar } from '@/app/components/workflow/types' import { getNewVar } from '@/utils/var' type OpeningSettingModalProps = { @@ -17,6 +18,7 @@ type OpeningSettingModalProps = { onSave: (newState: OpeningStatement) => void onCancel: () => void promptVariables?: PromptVariable[] + workflowVariables?: InputVar[] onAutoAddPromptVariable?: (variable: PromptVariable[]) => void } @@ -27,6 +29,7 @@ const OpeningSettingModal = ({ onSave, onCancel, promptVariables = [], + workflowVariables = [], onAutoAddPromptVariable, }: OpeningSettingModalProps) => { const { t } = useTranslation() @@ -42,14 +45,17 @@ const OpeningSettingModal = ({ if (!ignoreVariablesCheck) { const keys = getInputKeys(tempValue) const promptKeys = promptVariables.map(item => item.key) + const workflowVariableKeys = workflowVariables.map(item => item.variable) let notIncludeKeys: string[] = [] - if (promptKeys.length === 0) { + if (promptKeys.length === 0 && workflowVariables.length === 0) { if (keys.length > 0) notIncludeKeys = keys } else { - notIncludeKeys = keys.filter(key => !promptKeys.includes(key)) + if (workflowVariables.length > 0) + notIncludeKeys = keys.filter(key => !workflowVariableKeys.includes(key)) + else notIncludeKeys = keys.filter(key => !promptKeys.includes(key)) } if (notIncludeKeys.length > 0) { @@ -65,7 +71,7 @@ const OpeningSettingModal = ({ } }) onSave(newOpening) - }, [data, onSave, promptVariables, showConfirmAddVar, tempSuggestedQuestions, tempValue]) + }, [data, onSave, promptVariables, workflowVariables, showConfirmAddVar, tempSuggestedQuestions, tempValue]) const cancelAutoAddVar = useCallback(() => { hideConfirmAddVar() @@ -74,12 +80,11 @@ const OpeningSettingModal = ({ const autoAddVar = useCallback(() => { onAutoAddPromptVariable?.([ - ...promptVariables, ...notIncludeKeys.map(key => getNewVar(key, 'string')), ]) hideConfirmAddVar() handleSave(true) - }, [handleSave, hideConfirmAddVar, notIncludeKeys, onAutoAddPromptVariable, 
promptVariables]) + }, [handleSave, hideConfirmAddVar, notIncludeKeys, onAutoAddPromptVariable]) const renderQuestions = () => { return ( @@ -189,7 +194,7 @@ const OpeningSettingModal = ({ {isShowConfirmAddVar && ( diff --git a/web/app/components/base/features/new-feature-panel/index.tsx b/web/app/components/base/features/new-feature-panel/index.tsx index abeaddbc33..dcdb1baefc 100644 --- a/web/app/components/base/features/new-feature-panel/index.tsx +++ b/web/app/components/base/features/new-feature-panel/index.tsx @@ -1,6 +1,7 @@ import React from 'react' import { useTranslation } from 'react-i18next' -import { RiCloseLine } from '@remixicon/react' +import { useContext } from 'use-context-selector' +import { RiCloseLine, RiInformation2Fill } from '@remixicon/react' import DialogWrapper from '@/app/components/base/features/new-feature-panel/dialog-wrapper' import { useDefaultModel } from '@/app/components/header/account-setting/model-provider-page/hooks' import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations' @@ -17,6 +18,9 @@ import ImageUpload from '@/app/components/base/features/new-feature-panel/image- import Moderation from '@/app/components/base/features/new-feature-panel/moderation' import AnnotationReply from '@/app/components/base/features/new-feature-panel/annotation-reply' import type { PromptVariable } from '@/models/debug' +import type { InputVar } from '@/app/components/workflow/types' +import I18n from '@/context/i18n' +import { LanguagesSupported } from '@/i18n/language' type Props = { show: boolean @@ -27,6 +31,7 @@ type Props = { inWorkflow?: boolean showFileUpload?: boolean promptVariables?: PromptVariable[] + workflowVariables?: InputVar[] onAutoAddPromptVariable?: (variable: PromptVariable[]) => void } @@ -39,9 +44,11 @@ const NewFeaturePanel = ({ inWorkflow = true, showFileUpload = true, promptVariables, + workflowVariables, onAutoAddPromptVariable, }: Props) => { const { t } = useTranslation() 
+ const { locale } = useContext(I18n) const { data: speech2textDefaultModel } = useDefaultModel(ModelTypeEnum.speech2text) const { data: text2speechDefaultModel } = useDefaultModel(ModelTypeEnum.tts) @@ -62,6 +69,24 @@ const NewFeaturePanel = ({ {/* list */}
+ {showFileUpload && ( +
+
+
+
+ +
+
+ {isChatMode ? t('workflow.common.fileUploadTip') : t('workflow.common.ImageUploadLegacyTip')} + {t('workflow.common.featuresDocLink')} +
+
+
+ )} {!isChatMode && !inWorkflow && ( )} @@ -70,6 +95,7 @@ const NewFeaturePanel = ({ disabled={disabled} onChange={onChange} promptVariables={promptVariables} + workflowVariables={workflowVariables} onAutoAddPromptVariable={onAutoAddPromptVariable} /> )} diff --git a/web/app/components/workflow/features.tsx b/web/app/components/workflow/features.tsx index f9bc0a3472..b54ffdf167 100644 --- a/web/app/components/workflow/features.tsx +++ b/web/app/components/workflow/features.tsx @@ -2,12 +2,17 @@ import { memo, useCallback, } from 'react' +import { useNodes } from 'reactflow' import { useStore } from './store' import { useIsChatMode, useNodesReadOnly, useNodesSyncDraft, } from './hooks' +import { type CommonNodeType, type InputVar, InputVarType, type Node } from './types' +import useConfig from './nodes/start/use-config' +import type { StartNodeType } from './nodes/start/types' +import type { PromptVariable } from '@/models/debug' import NewFeaturePanel from '@/app/components/base/features/new-feature-panel' const Features = () => { @@ -15,6 +20,24 @@ const Features = () => { const isChatMode = useIsChatMode() const { nodesReadOnly } = useNodesReadOnly() const { handleSyncWorkflowDraft } = useNodesSyncDraft() + const nodes = useNodes() + + const startNode = nodes.find(node => node.data.type === 'start') + const { id, data } = startNode as Node + const { handleAddVariable } = useConfig(id, data) + + const handleAddOpeningStatementVariable = (variables: PromptVariable[]) => { + const newVariable = variables[0] + const startNodeVariable: InputVar = { + variable: newVariable.key, + label: newVariable.name, + type: InputVarType.textInput, + max_length: newVariable.max_length, + required: newVariable.required || false, + options: [], + } + handleAddVariable(startNodeVariable) + } const handleFeaturesChange = useCallback(() => { handleSyncWorkflowDraft() @@ -28,6 +51,8 @@ const Features = () => { disabled={nodesReadOnly} onChange={handleFeaturesChange} onClose={() => 
setShowFeaturesPanel(false)} + onAutoAddPromptVariable={handleAddOpeningStatementVariable} + workflowVariables={data.variables} /> ) } diff --git a/web/context/modal-context.tsx b/web/context/modal-context.tsx index 5fa525fcd1..2dfc08cf88 100644 --- a/web/context/modal-context.tsx +++ b/web/context/modal-context.tsx @@ -30,6 +30,7 @@ import type { ModelLoadBalancingModalProps } from '@/app/components/header/accou import ModelLoadBalancingModal from '@/app/components/header/account-setting/model-provider-page/provider-added-card/model-load-balancing-modal' import OpeningSettingModal from '@/app/components/base/features/new-feature-panel/conversation-opener/modal' import type { OpeningStatement } from '@/app/components/base/features/types' +import type { InputVar } from '@/app/components/workflow/types' export type ModalState = { payload: T @@ -64,6 +65,7 @@ export type ModalContextState = { setShowModelLoadBalancingEntryModal: Dispatch | null>> setShowOpeningModal: Dispatch void }> | null>> } @@ -105,6 +107,7 @@ export const ModalContextProvider = ({ const [showModelLoadBalancingEntryModal, setShowModelLoadBalancingEntryModal] = useState | null>(null) const [showOpeningModal, setShowOpeningModal] = useState void }> | null>(null) const searchParams = useSearchParams() @@ -332,6 +335,7 @@ export const ModalContextProvider = ({ onSave={handleSaveOpeningModal} onCancel={handleCancelOpeningModal} promptVariables={showOpeningModal.payload.promptVariables} + workflowVariables={showOpeningModal.payload.workflowVariables} onAutoAddPromptVariable={showOpeningModal.payload.onAutoAddPromptVariable} /> )} diff --git a/web/i18n/en-US/workflow.ts b/web/i18n/en-US/workflow.ts index 658c040e05..ea8355500a 100644 --- a/web/i18n/en-US/workflow.ts +++ b/web/i18n/en-US/workflow.ts @@ -20,6 +20,9 @@ const translation = { conversationLog: 'Conversation Log', features: 'Features', featuresDescription: 'Enhance web app user experience', + ImageUploadLegacyTip: 'You can now create file type 
variables in the start form. We will no longer support the image upload feature in the future. ', + fileUploadTip: 'Image upload features have been upgraded to file upload. ', + featuresDocLink: 'Learn more', debugAndPreview: 'Preview', restart: 'Restart', currentDraft: 'Current Draft', diff --git a/web/i18n/zh-Hans/workflow.ts b/web/i18n/zh-Hans/workflow.ts index c09ee13d3b..515d0fe235 100644 --- a/web/i18n/zh-Hans/workflow.ts +++ b/web/i18n/zh-Hans/workflow.ts @@ -20,6 +20,9 @@ const translation = { conversationLog: '对话记录', features: '功能', featuresDescription: '增强 web app 用户体验', + ImageUploadLegacyTip: '现在可以在 start 表单中创建文件类型变量。未来我们将不继续支持图片上传功能。', + fileUploadTip: '图片上传功能已扩展为文件上传。', + featuresDocLink: '了解更多', debugAndPreview: '预览', restart: '重新开始', currentDraft: '当前草稿',