diff --git a/.github/workflows/api-tests.yml b/.github/workflows/api-tests.yml index b1cf41a226..e9c2b7b086 100644 --- a/.github/workflows/api-tests.yml +++ b/.github/workflows/api-tests.yml @@ -27,18 +27,17 @@ jobs: - name: Checkout code uses: actions/checkout@v4 - - name: Install Poetry - uses: abatilo/actions-poetry@v3 - - name: Set up Python ${{ matrix.python-version }} uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} - cache: 'poetry' cache-dependency-path: | api/pyproject.toml api/poetry.lock + - name: Install Poetry + uses: abatilo/actions-poetry@v3 + - name: Check Poetry lockfile run: | poetry check -C api --lock diff --git a/.github/workflows/style.yml b/.github/workflows/style.yml index 719e6cfe90..246854cf0b 100644 --- a/.github/workflows/style.yml +++ b/.github/workflows/style.yml @@ -24,15 +24,15 @@ jobs: with: files: api/** - - name: Install Poetry - uses: abatilo/actions-poetry@v3 - - name: Set up Python uses: actions/setup-python@v5 if: steps.changed-files.outputs.any_changed == 'true' with: python-version: '3.10' + - name: Install Poetry + uses: abatilo/actions-poetry@v3 + - name: Python dependencies if: steps.changed-files.outputs.any_changed == 'true' run: poetry install -C api --only lint diff --git a/api/README.md b/api/README.md index bab33f9293..92cd88a6d4 100644 --- a/api/README.md +++ b/api/README.md @@ -85,3 +85,4 @@ cd ../ poetry run -C api bash dev/pytest/pytest_all_tests.sh ``` + diff --git a/api/configs/middleware/vdb/oracle_config.py b/api/configs/middleware/vdb/oracle_config.py index 44e2f13345..5d2cf67ba3 100644 --- a/api/configs/middleware/vdb/oracle_config.py +++ b/api/configs/middleware/vdb/oracle_config.py @@ -14,7 +14,7 @@ class OracleConfig(BaseSettings): default=None, ) - ORACLE_PORT: Optional[PositiveInt] = Field( + ORACLE_PORT: PositiveInt = Field( description="Port number on which the Oracle database server is listening (default is 1521)", default=1521, ) diff --git a/api/configs/middleware/vdb/pgvector_config.py b/api/configs/middleware/vdb/pgvector_config.py index 85f5dca7e2..4561a9a7ca 100644 --- a/api/configs/middleware/vdb/pgvector_config.py +++ b/api/configs/middleware/vdb/pgvector_config.py @@ -14,7 +14,7 @@ class PGVectorConfig(BaseSettings): default=None, ) - PGVECTOR_PORT: Optional[PositiveInt] = Field( + PGVECTOR_PORT: PositiveInt = Field( description="Port number on which the PostgreSQL server is listening (default is 5433)", default=5433, ) diff --git a/api/configs/middleware/vdb/pgvectors_config.py b/api/configs/middleware/vdb/pgvectors_config.py index 8d7a4b8d25..fa3bca5bb7 100644 --- a/api/configs/middleware/vdb/pgvectors_config.py +++ b/api/configs/middleware/vdb/pgvectors_config.py @@ -14,7 +14,7 @@ class PGVectoRSConfig(BaseSettings): default=None, ) - PGVECTO_RS_PORT: Optional[PositiveInt] = Field( + PGVECTO_RS_PORT: PositiveInt = Field( description="Port number on which the PostgreSQL server with PGVecto.RS is listening (default is 5431)", default=5431, ) diff --git a/api/configs/middleware/vdb/vikingdb_config.py b/api/configs/middleware/vdb/vikingdb_config.py index 5ad98d898a..3e718481dc 100644 --- a/api/configs/middleware/vdb/vikingdb_config.py +++ b/api/configs/middleware/vdb/vikingdb_config.py @@ -11,27 +11,39 @@ class VikingDBConfig(BaseModel): """ VIKINGDB_ACCESS_KEY: Optional[str] = Field( - default=None, description="The Access Key provided by Volcengine VikingDB for API authentication." + description="The Access Key provided by Volcengine VikingDB for API authentication." 
+ "Refer to the following documentation for details on obtaining credentials:" + "https://www.volcengine.com/docs/6291/65568", + default=None, ) + VIKINGDB_SECRET_KEY: Optional[str] = Field( - default=None, description="The Secret Key provided by Volcengine VikingDB for API authentication." + description="The Secret Key provided by Volcengine VikingDB for API authentication.", + default=None, ) - VIKINGDB_REGION: Optional[str] = Field( - default="cn-shanghai", + + VIKINGDB_REGION: str = Field( description="The region of the Volcengine VikingDB service.(e.g., 'cn-shanghai', 'cn-beijing').", + default="cn-shanghai", ) - VIKINGDB_HOST: Optional[str] = Field( - default="api-vikingdb.mlp.cn-shanghai.volces.com", + + VIKINGDB_HOST: str = Field( description="The host of the Volcengine VikingDB service.(e.g., 'api-vikingdb.volces.com', \ 'api-vikingdb.mlp.cn-shanghai.volces.com')", + default="api-vikingdb.mlp.cn-shanghai.volces.com", ) - VIKINGDB_SCHEME: Optional[str] = Field( - default="http", + + VIKINGDB_SCHEME: str = Field( description="The scheme of the Volcengine VikingDB service.(e.g., 'http', 'https').", + default="http", ) - VIKINGDB_CONNECTION_TIMEOUT: Optional[int] = Field( - default=30, description="The connection timeout of the Volcengine VikingDB service." + + VIKINGDB_CONNECTION_TIMEOUT: int = Field( + description="The connection timeout of the Volcengine VikingDB service.", + default=30, ) - VIKINGDB_SOCKET_TIMEOUT: Optional[int] = Field( - default=30, description="The socket timeout of the Volcengine VikingDB service." + + VIKINGDB_SOCKET_TIMEOUT: int = Field( + description="The socket timeout of the Volcengine VikingDB service.", + default=30, ) diff --git a/api/controllers/console/datasets/hit_testing.py b/api/controllers/console/datasets/hit_testing.py index 6e6d8c0bd7..5c9bcef84c 100644 --- a/api/controllers/console/datasets/hit_testing.py +++ b/api/controllers/console/datasets/hit_testing.py @@ -1,88 +1,24 @@ -import logging +from flask_restful import Resource -from flask_login import current_user -from flask_restful import Resource, marshal, reqparse -from werkzeug.exceptions import Forbidden, InternalServerError, NotFound - -import services from controllers.console import api -from controllers.console.app.error import ( - CompletionRequestError, - ProviderModelCurrentlyNotSupportError, - ProviderNotInitializeError, - ProviderQuotaExceededError, -) -from controllers.console.datasets.error import DatasetNotInitializedError +from controllers.console.datasets.hit_testing_base import DatasetsHitTestingBase from controllers.console.setup import setup_required from controllers.console.wraps import account_initialization_required -from core.errors.error import ( - LLMBadRequestError, - ModelCurrentlyNotSupportError, - ProviderTokenNotInitError, - QuotaExceededError, -) -from core.model_runtime.errors.invoke import InvokeError -from fields.hit_testing_fields import hit_testing_record_fields from libs.login import login_required -from services.dataset_service import DatasetService -from services.hit_testing_service import HitTestingService -class HitTestingApi(Resource): +class HitTestingApi(Resource, DatasetsHitTestingBase): @setup_required @login_required @account_initialization_required def post(self, dataset_id): dataset_id_str = str(dataset_id) - dataset = DatasetService.get_dataset(dataset_id_str) - if dataset is None: - raise NotFound("Dataset not found.") + dataset = self.get_and_validate_dataset(dataset_id_str) + args = self.parse_args() + self.hit_testing_args_check(args) - 
try: - DatasetService.check_dataset_permission(dataset, current_user) - except services.errors.account.NoPermissionError as e: - raise Forbidden(str(e)) - - parser = reqparse.RequestParser() - parser.add_argument("query", type=str, location="json") - parser.add_argument("retrieval_model", type=dict, required=False, location="json") - parser.add_argument("external_retrieval_model", type=dict, required=False, location="json") - args = parser.parse_args() - - HitTestingService.hit_testing_args_check(args) - - try: - response = HitTestingService.retrieve( - dataset=dataset, - query=args["query"], - account=current_user, - retrieval_model=args["retrieval_model"], - external_retrieval_model=args["external_retrieval_model"], - limit=10, - ) - - return {"query": response["query"], "records": marshal(response["records"], hit_testing_record_fields)} - except services.errors.index.IndexNotInitializedError: - raise DatasetNotInitializedError() - except ProviderTokenNotInitError as ex: - raise ProviderNotInitializeError(ex.description) - except QuotaExceededError: - raise ProviderQuotaExceededError() - except ModelCurrentlyNotSupportError: - raise ProviderModelCurrentlyNotSupportError() - except LLMBadRequestError: - raise ProviderNotInitializeError( - "No Embedding Model or Reranking Model available. Please configure a valid provider " - "in the Settings -> Model Provider." - ) - except InvokeError as e: - raise CompletionRequestError(e.description) - except ValueError as e: - raise ValueError(str(e)) - except Exception as e: - logging.exception("Hit testing failed.") - raise InternalServerError(str(e)) + return self.perform_hit_testing(dataset, args) api.add_resource(HitTestingApi, "/datasets//hit-testing") diff --git a/api/controllers/console/datasets/hit_testing_base.py b/api/controllers/console/datasets/hit_testing_base.py new file mode 100644 index 0000000000..3b4c076863 --- /dev/null +++ b/api/controllers/console/datasets/hit_testing_base.py @@ -0,0 +1,85 @@ +import logging + +from flask_login import current_user +from flask_restful import marshal, reqparse +from werkzeug.exceptions import Forbidden, InternalServerError, NotFound + +import services.dataset_service +from controllers.console.app.error import ( + CompletionRequestError, + ProviderModelCurrentlyNotSupportError, + ProviderNotInitializeError, + ProviderQuotaExceededError, +) +from controllers.console.datasets.error import DatasetNotInitializedError +from core.errors.error import ( + LLMBadRequestError, + ModelCurrentlyNotSupportError, + ProviderTokenNotInitError, + QuotaExceededError, +) +from core.model_runtime.errors.invoke import InvokeError +from fields.hit_testing_fields import hit_testing_record_fields +from services.dataset_service import DatasetService +from services.hit_testing_service import HitTestingService + + +class DatasetsHitTestingBase: + @staticmethod + def get_and_validate_dataset(dataset_id: str): + dataset = DatasetService.get_dataset(dataset_id) + if dataset is None: + raise NotFound("Dataset not found.") + + try: + DatasetService.check_dataset_permission(dataset, current_user) + except services.errors.account.NoPermissionError as e: + raise Forbidden(str(e)) + + return dataset + + @staticmethod + def hit_testing_args_check(args): + HitTestingService.hit_testing_args_check(args) + + @staticmethod + def parse_args(): + parser = reqparse.RequestParser() + + parser.add_argument("query", type=str, location="json") + parser.add_argument("retrieval_model", type=dict, required=False, location="json") + 
parser.add_argument("external_retrieval_model", type=dict, required=False, location="json") + return parser.parse_args() + + @staticmethod + def perform_hit_testing(dataset, args): + try: + response = HitTestingService.retrieve( + dataset=dataset, + query=args["query"], + account=current_user, + retrieval_model=args["retrieval_model"], + external_retrieval_model=args["external_retrieval_model"], + limit=10, + ) + return {"query": response["query"], "records": marshal(response["records"], hit_testing_record_fields)} + except services.errors.index.IndexNotInitializedError: + raise DatasetNotInitializedError() + except ProviderTokenNotInitError as ex: + raise ProviderNotInitializeError(ex.description) + except QuotaExceededError: + raise ProviderQuotaExceededError() + except ModelCurrentlyNotSupportError: + raise ProviderModelCurrentlyNotSupportError() + except LLMBadRequestError: + raise ProviderNotInitializeError( + "No Embedding Model or Reranking Model available. Please configure a valid provider " + "in the Settings -> Model Provider." + ) + except InvokeError as e: + raise CompletionRequestError(e.description) + except ValueError as e: + raise ValueError(str(e)) + except Exception as e: + logging.exception("Hit testing failed.") + raise InternalServerError(str(e)) diff --git a/api/controllers/service_api/__init__.py b/api/controllers/service_api/__init__.py index ad39c160ac..d6ab96c329 100644 --- a/api/controllers/service_api/__init__.py +++ b/api/controllers/service_api/__init__.py @@ -5,7 +5,6 @@ from libs.external_api import ExternalApi bp = Blueprint("service_api", __name__, url_prefix="/v1") api = ExternalApi(bp) - from . import index from .app import app, audio, completion, conversation, file, message, workflow -from .dataset import dataset, document, segment +from .dataset import dataset, document, hit_testing, segment diff --git a/api/controllers/service_api/app/completion.py b/api/controllers/service_api/app/completion.py index 5c3601cf23..8d8e356c4c 100644 --- a/api/controllers/service_api/app/completion.py +++ b/api/controllers/service_api/app/completion.py @@ -4,7 +4,6 @@ from flask_restful import Resource, reqparse from werkzeug.exceptions import InternalServerError, NotFound import services -from constants import UUID_NIL from controllers.service_api import api from controllers.service_api.app.error import ( AppUnavailableError, @@ -108,7 +107,6 @@ class ChatApi(Resource): parser.add_argument("conversation_id", type=uuid_value, location="json") parser.add_argument("retriever_from", type=str, required=False, default="dev", location="json") parser.add_argument("auto_generate_name", type=bool, required=False, default=True, location="json") - parser.add_argument("parent_message_id", type=uuid_value, required=False, default=UUID_NIL, location="json") args = parser.parse_args() diff --git a/api/controllers/service_api/dataset/hit_testing.py b/api/controllers/service_api/dataset/hit_testing.py new file mode 100644 index 0000000000..9c9a4302c9 --- /dev/null +++ b/api/controllers/service_api/dataset/hit_testing.py @@ -0,0 +1,17 @@ +from controllers.console.datasets.hit_testing_base import DatasetsHitTestingBase +from controllers.service_api import api +from controllers.service_api.wraps import DatasetApiResource + + +class HitTestingApi(DatasetApiResource, DatasetsHitTestingBase): + def post(self, tenant_id, dataset_id): + dataset_id_str = str(dataset_id) + + dataset = self.get_and_validate_dataset(dataset_id_str) + args = self.parse_args() + self.hit_testing_args_check(args) + + 
return self.perform_hit_testing(dataset, args) + + +api.add_resource(HitTestingApi, "/datasets//hit-testing") diff --git a/api/core/app/apps/advanced_chat/app_generator.py b/api/core/app/apps/advanced_chat/app_generator.py index 450e30b926..2dea84bc71 100644 --- a/api/core/app/apps/advanced_chat/app_generator.py +++ b/api/core/app/apps/advanced_chat/app_generator.py @@ -10,6 +10,7 @@ from flask import Flask, current_app from pydantic import ValidationError import contexts +from constants import UUID_NIL from core.app.app_config.features.file_upload.manager import FileUploadConfigManager from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager from core.app.apps.advanced_chat.app_runner import AdvancedChatAppRunner @@ -132,7 +133,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): else self._prepare_user_inputs(user_inputs=inputs, app_config=app_config, user_id=user.id, role=role), query=query, files=file_objs, - parent_message_id=args.get("parent_message_id"), + parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL, user_id=user.id, stream=stream, invoke_from=invoke_from, diff --git a/api/core/app/apps/agent_chat/app_generator.py b/api/core/app/apps/agent_chat/app_generator.py index c452c3cb46..90666fb2e9 100644 --- a/api/core/app/apps/agent_chat/app_generator.py +++ b/api/core/app/apps/agent_chat/app_generator.py @@ -8,6 +8,7 @@ from typing import Any, Literal, Union, overload from flask import Flask, current_app from pydantic import ValidationError +from constants import UUID_NIL from core.app.app_config.easy_ui_based_app.model_config.converter import ModelConfigConverter from core.app.app_config.features.file_upload.manager import FileUploadConfigManager from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfigManager @@ -140,7 +141,7 @@ class AgentChatAppGenerator(MessageBasedAppGenerator): else self._prepare_user_inputs(user_inputs=inputs, app_config=app_config, user_id=user.id, role=role), query=query, files=file_objs, - parent_message_id=args.get("parent_message_id"), + parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL, user_id=user.id, stream=stream, invoke_from=invoke_from, diff --git a/api/core/app/apps/chat/app_generator.py b/api/core/app/apps/chat/app_generator.py index dffd5eeac4..8fdfdcbe76 100644 --- a/api/core/app/apps/chat/app_generator.py +++ b/api/core/app/apps/chat/app_generator.py @@ -8,6 +8,7 @@ from typing import Any, Literal, Union, overload from flask import Flask, current_app from pydantic import ValidationError +from constants import UUID_NIL from core.app.app_config.easy_ui_based_app.model_config.converter import ModelConfigConverter from core.app.app_config.features.file_upload.manager import FileUploadConfigManager from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom @@ -138,7 +139,7 @@ class ChatAppGenerator(MessageBasedAppGenerator): else self._prepare_user_inputs(user_inputs=inputs, app_config=app_config, user_id=user.id, role=role), query=query, files=file_objs, - parent_message_id=args.get("parent_message_id"), + parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL, user_id=user.id, invoke_from=invoke_from, extras=extras, diff --git a/api/core/app/entities/app_invoke_entities.py b/api/core/app/entities/app_invoke_entities.py index dca2e13f61..f2eba29323 100644 --- 
a/api/core/app/entities/app_invoke_entities.py +++ b/api/core/app/entities/app_invoke_entities.py @@ -2,8 +2,9 @@ from collections.abc import Mapping, Sequence from enum import Enum from typing import Any, Optional -from pydantic import BaseModel, ConfigDict +from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_validator +from constants import UUID_NIL from core.app.app_config.entities import AppConfig, EasyUIBasedAppConfig, WorkflowUIBasedAppConfig from core.entities.provider_configuration import ProviderModelBundle from core.file.models import File @@ -116,13 +117,36 @@ class EasyUIBasedAppGenerateEntity(AppGenerateEntity): model_config = ConfigDict(protected_namespaces=()) -class ChatAppGenerateEntity(EasyUIBasedAppGenerateEntity): +class ConversationAppGenerateEntity(AppGenerateEntity): + """ + Base entity for conversation-based app generation. + """ + + conversation_id: Optional[str] = None + parent_message_id: Optional[str] = Field( + default=None, + description=( + "Starting from v0.9.0, parent_message_id is used to support message regeneration for internal chat API." + "For service API, we need to ensure its forward compatibility, " + "so passing in the parent_message_id as request arg is not supported for now. " + "It needs to be set to UUID_NIL so that the subsequent processing will treat it as legacy messages." + ), + ) + + @field_validator("parent_message_id") + @classmethod + def validate_parent_message_id(cls, v, info: ValidationInfo): + if info.data.get("invoke_from") == InvokeFrom.SERVICE_API and v != UUID_NIL: + raise ValueError("parent_message_id should be UUID_NIL for service API") + return v + + +class ChatAppGenerateEntity(ConversationAppGenerateEntity, EasyUIBasedAppGenerateEntity): """ Chat Application Generate Entity. """ - conversation_id: Optional[str] = None - parent_message_id: Optional[str] = None + pass class CompletionAppGenerateEntity(EasyUIBasedAppGenerateEntity): @@ -133,16 +157,15 @@ class CompletionAppGenerateEntity(EasyUIBasedAppGenerateEntity): pass -class AgentChatAppGenerateEntity(EasyUIBasedAppGenerateEntity): +class AgentChatAppGenerateEntity(ConversationAppGenerateEntity, EasyUIBasedAppGenerateEntity): """ Agent Chat Application Generate Entity. """ - conversation_id: Optional[str] = None - parent_message_id: Optional[str] = None + pass -class AdvancedChatAppGenerateEntity(AppGenerateEntity): +class AdvancedChatAppGenerateEntity(ConversationAppGenerateEntity): """ Advanced Chat Application Generate Entity. 
""" @@ -150,8 +173,6 @@ class AdvancedChatAppGenerateEntity(AppGenerateEntity): # app config app_config: WorkflowUIBasedAppConfig - conversation_id: Optional[str] = None - parent_message_id: Optional[str] = None workflow_run_id: Optional[str] = None query: str diff --git a/api/core/model_runtime/model_providers/fireworks/fireworks.yaml b/api/core/model_runtime/model_providers/fireworks/fireworks.yaml index cdb87a55e9..ddbaa54eb1 100644 --- a/api/core/model_runtime/model_providers/fireworks/fireworks.yaml +++ b/api/core/model_runtime/model_providers/fireworks/fireworks.yaml @@ -18,6 +18,7 @@ supported_model_types: - text-embedding configurate_methods: - predefined-model + - customizable-model provider_credential_schema: credential_form_schemas: - variable: fireworks_api_key @@ -28,3 +29,75 @@ provider_credential_schema: placeholder: zh_Hans: 在此输入您的 API Key en_US: Enter your API Key +model_credential_schema: + model: + label: + en_US: Model URL + zh_Hans: 模型URL + placeholder: + en_US: Enter your Model URL + zh_Hans: 输入模型URL + credential_form_schemas: + - variable: model_label_zh_Hanns + label: + zh_Hans: 模型中文名称 + en_US: The zh_Hans of Model + required: true + type: text-input + placeholder: + zh_Hans: 在此输入您的模型中文名称 + en_US: Enter your zh_Hans of Model + - variable: model_label_en_US + label: + zh_Hans: 模型英文名称 + en_US: The en_US of Model + required: true + type: text-input + placeholder: + zh_Hans: 在此输入您的模型英文名称 + en_US: Enter your en_US of Model + - variable: fireworks_api_key + label: + en_US: API Key + type: secret-input + required: true + placeholder: + zh_Hans: 在此输入您的 API Key + en_US: Enter your API Key + - variable: context_size + label: + zh_Hans: 模型上下文长度 + en_US: Model context size + required: true + type: text-input + default: '4096' + placeholder: + zh_Hans: 在此输入您的模型上下文长度 + en_US: Enter your Model context size + - variable: max_tokens + label: + zh_Hans: 最大 token 上限 + en_US: Upper bound for max tokens + default: '4096' + type: text-input + show_on: + - variable: __model_type + value: llm + - variable: function_calling_type + label: + en_US: Function calling + type: select + required: false + default: no_call + options: + - value: no_call + label: + en_US: Not Support + zh_Hans: 不支持 + - value: function_call + label: + en_US: Support + zh_Hans: 支持 + show_on: + - variable: __model_type + value: llm diff --git a/api/core/model_runtime/model_providers/fireworks/llm/llm.py b/api/core/model_runtime/model_providers/fireworks/llm/llm.py index 24aad9c4d3..ffe1ad5fcb 100644 --- a/api/core/model_runtime/model_providers/fireworks/llm/llm.py +++ b/api/core/model_runtime/model_providers/fireworks/llm/llm.py @@ -8,7 +8,8 @@ from openai.types.chat.chat_completion_chunk import ChoiceDeltaFunctionCall, Cho from openai.types.chat.chat_completion_message import FunctionCall from core.model_runtime.callbacks.base_callback import Callback -from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta +from core.model_runtime.entities.common_entities import I18nObject +from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta from core.model_runtime.entities.message_entities import ( AssistantPromptMessage, ImagePromptMessageContent, @@ -20,6 +21,15 @@ from core.model_runtime.entities.message_entities import ( ToolPromptMessage, UserPromptMessage, ) +from core.model_runtime.entities.model_entities import ( + AIModelEntity, + FetchFrom, + ModelFeature, + ModelPropertyKey, + ModelType, + ParameterRule, + ParameterType, +) 
from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.model_runtime.model_providers.fireworks._common import _CommonFireworks @@ -608,3 +618,50 @@ class FireworksLargeLanguageModel(_CommonFireworks, LargeLanguageModel): num_tokens += self._get_num_tokens_by_gpt2(required_field) return num_tokens + + def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity: + return AIModelEntity( + model=model, + label=I18nObject( + en_US=credentials.get("model_label_en_US", model), + zh_Hans=credentials.get("model_label_zh_Hanns", model), + ), + model_type=ModelType.LLM, + features=[ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL, ModelFeature.STREAM_TOOL_CALL] + if credentials.get("function_calling_type") == "function_call" + else [], + fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, + model_properties={ + ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size", 4096)), + ModelPropertyKey.MODE: LLMMode.CHAT.value, + }, + parameter_rules=[ + ParameterRule( + name="temperature", + use_template="temperature", + label=I18nObject(en_US="Temperature", zh_Hans="温度"), + type=ParameterType.FLOAT, + ), + ParameterRule( + name="max_tokens", + use_template="max_tokens", + default=512, + min=1, + max=int(credentials.get("max_tokens", 4096)), + label=I18nObject(en_US="Max Tokens", zh_Hans="最大标记"), + type=ParameterType.INT, + ), + ParameterRule( + name="top_p", + use_template="top_p", + label=I18nObject(en_US="Top P", zh_Hans="Top P"), + type=ParameterType.FLOAT, + ), + ParameterRule( + name="top_k", + use_template="top_k", + label=I18nObject(en_US="Top K", zh_Hans="Top K"), + type=ParameterType.FLOAT, + ), + ], + ) diff --git a/api/core/model_runtime/model_providers/fireworks/llm/qwen2p5-72b-instruct.yaml b/api/core/model_runtime/model_providers/fireworks/llm/qwen2p5-72b-instruct.yaml new file mode 100644 index 0000000000..9728364340 --- /dev/null +++ b/api/core/model_runtime/model_providers/fireworks/llm/qwen2p5-72b-instruct.yaml @@ -0,0 +1,46 @@ +model: accounts/fireworks/models/qwen2p5-72b-instruct +label: + zh_Hans: Qwen2.5 72B Instruct + en_US: Qwen2.5 72B Instruct +model_type: llm +features: + - agent-thought + - tool-call +model_properties: + mode: chat + context_size: 32768 +parameter_rules: + - name: temperature + use_template: temperature + - name: top_p + use_template: top_p + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. 
+ - name: max_tokens + use_template: max_tokens + - name: context_length_exceeded_behavior + default: None + label: + zh_Hans: 上下文长度超出行为 + en_US: Context Length Exceeded Behavior + help: + zh_Hans: 上下文长度超出行为 + en_US: Context Length Exceeded Behavior + type: string + options: + - None + - truncate + - error + - name: response_format + use_template: response_format +pricing: + input: '0.9' + output: '0.9' + unit: '0.000001' + currency: USD diff --git a/api/core/model_runtime/model_providers/minimax/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/minimax/text_embedding/text_embedding.py index 74d2a221d1..d031bfa04d 100644 --- a/api/core/model_runtime/model_providers/minimax/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/minimax/text_embedding/text_embedding.py @@ -61,7 +61,8 @@ class MinimaxTextEmbeddingModel(TextEmbeddingModel): url = f"{self.api_base}?GroupId={group_id}" headers = {"Authorization": "Bearer " + api_key, "Content-Type": "application/json"} - data = {"model": "embo-01", "texts": texts, "type": "db"} + embedding_type = "db" if input_type == EmbeddingInputType.DOCUMENT else "query" + data = {"model": "embo-01", "texts": texts, "type": embedding_type} try: response = post(url, headers=headers, data=dumps(data)) diff --git a/api/core/rag/datasource/keyword/keyword_factory.py b/api/core/rag/datasource/keyword/keyword_factory.py index 3c99f33be6..f1a6ade91f 100644 --- a/api/core/rag/datasource/keyword/keyword_factory.py +++ b/api/core/rag/datasource/keyword/keyword_factory.py @@ -1,8 +1,8 @@ from typing import Any from configs import dify_config -from core.rag.datasource.keyword.jieba.jieba import Jieba from core.rag.datasource.keyword.keyword_base import BaseKeyword +from core.rag.datasource.keyword.keyword_type import KeyWordType from core.rag.models.document import Document from models.dataset import Dataset @@ -13,16 +13,19 @@ class Keyword: self._keyword_processor = self._init_keyword() def _init_keyword(self) -> BaseKeyword: - config = dify_config - keyword_type = config.KEYWORD_STORE + keyword_type = dify_config.KEYWORD_STORE + keyword_factory = self.get_keyword_factory(keyword_type) + return keyword_factory(self._dataset) - if not keyword_type: - raise ValueError("Keyword store must be specified.") + @staticmethod + def get_keyword_factory(keyword_type: str) -> type[BaseKeyword]: + match keyword_type: + case KeyWordType.JIEBA: + from core.rag.datasource.keyword.jieba.jieba import Jieba - if keyword_type == "jieba": - return Jieba(dataset=self._dataset) - else: - raise ValueError(f"Keyword store {keyword_type} is not supported.") + return Jieba + case _: + raise ValueError(f"Keyword store {keyword_type} is not supported.") def create(self, texts: list[Document], **kwargs): self._keyword_processor.create(texts, **kwargs) diff --git a/api/core/rag/datasource/keyword/keyword_type.py b/api/core/rag/datasource/keyword/keyword_type.py new file mode 100644 index 0000000000..d6deba3fb0 --- /dev/null +++ b/api/core/rag/datasource/keyword/keyword_type.py @@ -0,0 +1,5 @@ +from enum import Enum + + +class KeyWordType(str, Enum): + JIEBA = "jieba" diff --git a/api/extensions/storage/aliyun_oss_storage.py b/api/extensions/storage/aliyun_oss_storage.py index 53bd399d6d..ae6911e945 100644 --- a/api/extensions/storage/aliyun_oss_storage.py +++ b/api/extensions/storage/aliyun_oss_storage.py @@ -1,5 +1,4 @@ from collections.abc import Generator -from contextlib import closing import oss2 as aliyun_s3 from flask import Flask @@ -34,15 
+33,15 @@ class AliyunOssStorage(BaseStorage): self.client.put_object(self.__wrapper_folder_filename(filename), data) def load_once(self, filename: str) -> bytes: - with closing(self.client.get_object(self.__wrapper_folder_filename(filename))) as obj: - data = obj.read() + obj = self.client.get_object(self.__wrapper_folder_filename(filename)) + data = obj.read() return data def load_stream(self, filename: str) -> Generator: def generate(filename: str = filename) -> Generator: - with closing(self.client.get_object(self.__wrapper_folder_filename(filename))) as obj: - while chunk := obj.read(4096): - yield chunk + obj = self.client.get_object(self.__wrapper_folder_filename(filename)) + while chunk := obj.read(4096): + yield chunk return generate() diff --git a/api/extensions/storage/aws_s3_storage.py b/api/extensions/storage/aws_s3_storage.py index 38f823763f..507a303223 100644 --- a/api/extensions/storage/aws_s3_storage.py +++ b/api/extensions/storage/aws_s3_storage.py @@ -1,6 +1,5 @@ import logging from collections.abc import Generator -from contextlib import closing import boto3 from botocore.client import Config @@ -55,8 +54,7 @@ class AwsS3Storage(BaseStorage): def load_once(self, filename: str) -> bytes: try: - with closing(self.client) as client: - data = client.get_object(Bucket=self.bucket_name, Key=filename)["Body"].read() + data = self.client.get_object(Bucket=self.bucket_name, Key=filename)["Body"].read() except ClientError as ex: if ex.response["Error"]["Code"] == "NoSuchKey": raise FileNotFoundError("File not found") @@ -67,9 +65,8 @@ class AwsS3Storage(BaseStorage): def load_stream(self, filename: str) -> Generator: def generate(filename: str = filename) -> Generator: try: - with closing(self.client) as client: - response = client.get_object(Bucket=self.bucket_name, Key=filename) - yield from response["Body"].iter_chunks() + response = self.client.get_object(Bucket=self.bucket_name, Key=filename) + yield from response["Body"].iter_chunks() except ClientError as ex: if ex.response["Error"]["Code"] == "NoSuchKey": raise FileNotFoundError("File not found") @@ -79,16 +76,14 @@ class AwsS3Storage(BaseStorage): return generate() def download(self, filename, target_filepath): - with closing(self.client) as client: - client.download_file(self.bucket_name, filename, target_filepath) + self.client.download_file(self.bucket_name, filename, target_filepath) def exists(self, filename): - with closing(self.client) as client: - try: - client.head_object(Bucket=self.bucket_name, Key=filename) - return True - except: - return False + try: + self.client.head_object(Bucket=self.bucket_name, Key=filename) + return True + except: + return False def delete(self, filename): self.client.delete_object(Bucket=self.bucket_name, Key=filename) diff --git a/api/extensions/storage/google_cloud_storage.py b/api/extensions/storage/google_cloud_storage.py index d9c74b8d40..2d1224fd74 100644 --- a/api/extensions/storage/google_cloud_storage.py +++ b/api/extensions/storage/google_cloud_storage.py @@ -2,7 +2,6 @@ import base64 import io import json from collections.abc import Generator -from contextlib import closing from flask import Flask from google.cloud import storage as google_cloud_storage @@ -43,7 +42,7 @@ class GoogleCloudStorage(BaseStorage): def generate(filename: str = filename) -> Generator: bucket = self.client.get_bucket(self.bucket_name) blob = bucket.get_blob(filename) - with closing(blob.open(mode="rb")) as blob_stream: + with blob.open(mode="rb") as blob_stream: while chunk := 
blob_stream.read(4096): yield chunk diff --git a/api/extensions/storage/oracle_oci_storage.py b/api/extensions/storage/oracle_oci_storage.py index 6934583567..5295dbdca2 100644 --- a/api/extensions/storage/oracle_oci_storage.py +++ b/api/extensions/storage/oracle_oci_storage.py @@ -1,5 +1,4 @@ from collections.abc import Generator -from contextlib import closing import boto3 from botocore.exceptions import ClientError @@ -28,8 +27,7 @@ class OracleOCIStorage(BaseStorage): def load_once(self, filename: str) -> bytes: try: - with closing(self.client) as client: - data = client.get_object(Bucket=self.bucket_name, Key=filename)["Body"].read() + data = self.client.get_object(Bucket=self.bucket_name, Key=filename)["Body"].read() except ClientError as ex: if ex.response["Error"]["Code"] == "NoSuchKey": raise FileNotFoundError("File not found") @@ -40,9 +38,8 @@ class OracleOCIStorage(BaseStorage): def load_stream(self, filename: str) -> Generator: def generate(filename: str = filename) -> Generator: try: - with closing(self.client) as client: - response = client.get_object(Bucket=self.bucket_name, Key=filename) - yield from response["Body"].iter_chunks() + response = self.client.get_object(Bucket=self.bucket_name, Key=filename) + yield from response["Body"].iter_chunks() except ClientError as ex: if ex.response["Error"]["Code"] == "NoSuchKey": raise FileNotFoundError("File not found") @@ -52,16 +49,14 @@ class OracleOCIStorage(BaseStorage): return generate() def download(self, filename, target_filepath): - with closing(self.client) as client: - client.download_file(self.bucket_name, filename, target_filepath) + self.client.download_file(self.bucket_name, filename, target_filepath) def exists(self, filename): - with closing(self.client) as client: - try: - client.head_object(Bucket=self.bucket_name, Key=filename) - return True - except: - return False + try: + self.client.head_object(Bucket=self.bucket_name, Key=filename) + return True + except: + return False def delete(self, filename): self.client.delete_object(Bucket=self.bucket_name, Key=filename) diff --git a/api/services/auth/api_key_auth_factory.py b/api/services/auth/api_key_auth_factory.py index 36387e9c2e..f91c448fb9 100644 --- a/api/services/auth/api_key_auth_factory.py +++ b/api/services/auth/api_key_auth_factory.py @@ -1,15 +1,25 @@ -from services.auth.firecrawl import FirecrawlAuth -from services.auth.jina import JinaAuth +from services.auth.api_key_auth_base import ApiKeyAuthBase +from services.auth.auth_type import AuthType class ApiKeyAuthFactory: def __init__(self, provider: str, credentials: dict): - if provider == "firecrawl": - self.auth = FirecrawlAuth(credentials) - elif provider == "jinareader": - self.auth = JinaAuth(credentials) - else: - raise ValueError("Invalid provider") + auth_factory = self.get_apikey_auth_factory(provider) + self.auth = auth_factory(credentials) def validate_credentials(self): return self.auth.validate_credentials() + + @staticmethod + def get_apikey_auth_factory(provider: str) -> type[ApiKeyAuthBase]: + match provider: + case AuthType.FIRECRAWL: + from services.auth.firecrawl.firecrawl import FirecrawlAuth + + return FirecrawlAuth + case AuthType.JINA: + from services.auth.jina.jina import JinaAuth + + return JinaAuth + case _: + raise ValueError("Invalid provider") diff --git a/api/services/auth/auth_type.py b/api/services/auth/auth_type.py new file mode 100644 index 0000000000..2d6e901447 --- /dev/null +++ b/api/services/auth/auth_type.py @@ -0,0 +1,6 @@ +from enum import Enum + + +class 
AuthType(str, Enum): + FIRECRAWL = "firecrawl" + JINA = "jinareader" diff --git a/api/services/auth/firecrawl/__init__.py b/api/services/auth/firecrawl/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/services/auth/firecrawl.py b/api/services/auth/firecrawl/firecrawl.py similarity index 100% rename from api/services/auth/firecrawl.py rename to api/services/auth/firecrawl/firecrawl.py diff --git a/api/services/auth/jina/__init__.py b/api/services/auth/jina/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/services/auth/jina.py b/api/services/auth/jina/jina.py similarity index 100% rename from api/services/auth/jina.py rename to api/services/auth/jina/jina.py diff --git a/web/app/(commonLayout)/datasets/template/template.en.mdx b/web/app/(commonLayout)/datasets/template/template.en.mdx index 99fab1ed8b..b846f6d9fb 100644 --- a/web/app/(commonLayout)/datasets/template/template.en.mdx +++ b/web/app/(commonLayout)/datasets/template/template.en.mdx @@ -1050,6 +1050,151 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from --- + + + + ### Path + + + Dataset ID + + + + ### Request Body + + + retrieval keywordc + + + retrieval keyword(Optional, if not filled, it will be recalled according to the default method) + - search_method (text) Search method: One of the following four keywords is required + - keyword_search Keyword search + - semantic_search Semantic search + - full_text_search Full-text search + - hybrid_search Hybrid search + - reranking_enable (bool) Whether to enable reranking, optional, required if the search mode is semantic_search or hybrid_search + - reranking_mode (object) Rerank model configuration, optional, required if reranking is enabled + - reranking_provider_name (string) Rerank model provider + - reranking_model_name (string) Rerank model name + - weights (double) Semantic search weight setting in hybrid search mode + - top_k (integer) Number of results to return, optional + - score_threshold_enabled (bool) Whether to enable score threshold + - score_threshold (double) Score threshold + + + Unused field + + + + + + ```bash {{ title: 'cURL' }} + curl --location --request POST '${props.apiBaseUrl}/datasets/{dataset_id}/hit_testing' \ + --header 'Authorization: Bearer {api_key}' \ + --header 'Content-Type: application/json' \ + --data-raw '{ + "query": "test", + "retrieval_model": { + "search_method": "keyword_search", + "reranking_enable": false, + "reranking_mode": null, + "reranking_model": { + "reranking_provider_name": "", + "reranking_model_name": "" + }, + "weights": null, + "top_k": 2, + "score_threshold_enabled": false, + "score_threshold": null + } + }' + ``` + + + ```json {{ title: 'Response' }} + { + "query": { + "content": "test" + }, + "records": [ + { + "segment": { + "id": "7fa6f24f-8679-48b3-bc9d-bdf28d73f218", + "position": 1, + "document_id": "a8c6c36f-9f5d-4d7a-8472-f5d7b75d71d2", + "content": "Operation guide", + "answer": null, + "word_count": 847, + "tokens": 280, + "keywords": [ + "install", + "java", + "base", + "scripts", + "jdk", + "manual", + "internal", + "opens", + "add", + "vmoptions" + ], + "index_node_id": "39dd8443-d960-45a8-bb46-7275ad7fbc8e", + "index_node_hash": "0189157697b3c6a418ccf8264a09699f25858975578f3467c76d6bfc94df1d73", + "hit_count": 0, + "enabled": true, + "disabled_at": null, + "disabled_by": null, + "status": "completed", + "created_by": "dbcb1ab5-90c8-41a7-8b78-73b235eb6f6f", + "created_at": 1728734540, + "indexing_at": 1728734552, + 
"completed_at": 1728734584, + "error": null, + "stopped_at": null, + "document": { + "id": "a8c6c36f-9f5d-4d7a-8472-f5d7b75d71d2", + "data_source_type": "upload_file", + "name": "readme.txt", + "doc_type": null + } + }, + "score": 3.730463140527718e-05, + "tsne_position": null + } + ] + } + ``` + + + + +--- + ### Error message diff --git a/web/app/(commonLayout)/datasets/template/template.zh.mdx b/web/app/(commonLayout)/datasets/template/template.zh.mdx index 2f8c7e99dc..ece4d3b771 100644 --- a/web/app/(commonLayout)/datasets/template/template.zh.mdx +++ b/web/app/(commonLayout)/datasets/template/template.zh.mdx @@ -1049,6 +1049,152 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from +--- + + + + + ### Path + + + 知识库 ID + + + + ### Request Body + + + 召回关键词 + + + 召回参数(选填,如不填,按照默认方式召回) + - search_method (text) 检索方法:以下三个关键字之一,必填 + - keyword_search 关键字检索 + - semantic_search 语义检索 + - full_text_search 全文检索 + - hybrid_search 混合检索 + - reranking_enable (bool) 是否启用 Reranking,非必填,如果检索模式为semantic_search模式或者hybrid_search则传值 + - reranking_mode (object) Rerank模型配置,非必填,如果启用了 reranking 则传值 + - reranking_provider_name (string) Rerank 模型提供商 + - reranking_model_name (string) Rerank 模型名称 + - weights (double) 混合检索模式下语意检索的权重设置 + - top_k (integer) 返回结果数量,非必填 + - score_threshold_enabled (bool) 是否开启Score阈值 + - score_threshold (double) Score阈值 + + + 未启用字段 + + + + + + ```bash {{ title: 'cURL' }} + curl --location --request POST '${props.apiBaseUrl}/datasets/{dataset_id}/hit_testing' \ + --header 'Authorization: Bearer {api_key}' \ + --header 'Content-Type: application/json' \ + --data-raw '{ + "query": "test", + "retrieval_model": { + "search_method": "keyword_search", + "reranking_enable": false, + "reranking_mode": null, + "reranking_model": { + "reranking_provider_name": "", + "reranking_model_name": "" + }, + "weights": null, + "top_k": 2, + "score_threshold_enabled": false, + "score_threshold": null + } + }' + ``` + + + ```json {{ title: 'Response' }} + { + "query": { + "content": "test" + }, + "records": [ + { + "segment": { + "id": "7fa6f24f-8679-48b3-bc9d-bdf28d73f218", + "position": 1, + "document_id": "a8c6c36f-9f5d-4d7a-8472-f5d7b75d71d2", + "content": "Operation guide", + "answer": null, + "word_count": 847, + "tokens": 280, + "keywords": [ + "install", + "java", + "base", + "scripts", + "jdk", + "manual", + "internal", + "opens", + "add", + "vmoptions" + ], + "index_node_id": "39dd8443-d960-45a8-bb46-7275ad7fbc8e", + "index_node_hash": "0189157697b3c6a418ccf8264a09699f25858975578f3467c76d6bfc94df1d73", + "hit_count": 0, + "enabled": true, + "disabled_at": null, + "disabled_by": null, + "status": "completed", + "created_by": "dbcb1ab5-90c8-41a7-8b78-73b235eb6f6f", + "created_at": 1728734540, + "indexing_at": 1728734552, + "completed_at": 1728734584, + "error": null, + "stopped_at": null, + "document": { + "id": "a8c6c36f-9f5d-4d7a-8472-f5d7b75d71d2", + "data_source_type": "upload_file", + "name": "readme.txt", + "doc_type": null + } + }, + "score": 3.730463140527718e-05, + "tsne_position": null + } + ] + } + ``` + + + + + --- diff --git a/web/app/components/app/configuration/dataset-config/params-config/config-content.tsx b/web/app/components/app/configuration/dataset-config/params-config/config-content.tsx index f556121518..04ae146645 100644 --- a/web/app/components/app/configuration/dataset-config/params-config/config-content.tsx +++ b/web/app/components/app/configuration/dataset-config/params-config/config-content.tsx @@ -63,7 +63,7 @@ const ConfigContent: FC = ({ } = 
useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.rerank) const { - currentModel, + currentModel: currentRerankModel, } = useCurrentProviderAndModel( rerankModelList, rerankDefaultModel @@ -74,11 +74,6 @@ const ConfigContent: FC = ({ : undefined, ) - const handleDisabledSwitchClick = useCallback(() => { - if (!currentModel) - Toast.notify({ type: 'error', message: t('workflow.errorMsg.rerankModelRequired') }) - }, [currentModel, rerankDefaultModel, t]) - const rerankModel = (() => { if (datasetConfigs.reranking_model?.reranking_provider_name) { return { @@ -164,12 +159,33 @@ const ConfigContent: FC = ({ const showWeightedScorePanel = showWeightedScore && datasetConfigs.reranking_mode === RerankingModeEnum.WeightedScore && datasetConfigs.weights const selectedRerankMode = datasetConfigs.reranking_mode || RerankingModeEnum.RerankingModel + const canManuallyToggleRerank = useMemo(() => { + return !( + (selectedDatasetsMode.allInternal && selectedDatasetsMode.allEconomic) + || selectedDatasetsMode.allExternal + ) + }, [selectedDatasetsMode.allEconomic, selectedDatasetsMode.allExternal, selectedDatasetsMode.allInternal]) + const showRerankModel = useMemo(() => { - if (datasetConfigs.reranking_enable === false && selectedDatasetsMode.allEconomic) + if (!canManuallyToggleRerank) return false - return true - }, [datasetConfigs.reranking_enable, selectedDatasetsMode.allEconomic]) + return datasetConfigs.reranking_enable + }, [canManuallyToggleRerank, datasetConfigs.reranking_enable]) + + const handleDisabledSwitchClick = useCallback(() => { + if (!currentRerankModel && !showRerankModel) + Toast.notify({ type: 'error', message: t('workflow.errorMsg.rerankModelRequired') }) + }, [currentRerankModel, showRerankModel, t]) + + useEffect(() => { + if (!canManuallyToggleRerank && showRerankModel !== datasetConfigs.reranking_enable) { + onChange({ + ...datasetConfigs, + reranking_enable: showRerankModel, + }) + } + }, [canManuallyToggleRerank, showRerankModel, datasetConfigs, onChange]) return (
@@ -256,13 +272,15 @@ const ConfigContent: FC = ({ > { - onChange({ - ...datasetConfigs, - reranking_enable: v, - }) + if (canManuallyToggleRerank) { + onChange({ + ...datasetConfigs, + reranking_enable: v, + }) + } }} />
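For reference, here is a minimal standalone sketch of the rerank-toggle rule that the config-content.tsx hunk above introduces. `canManuallyToggleRerank` mirrors the component's memo of the same name; `SelectedDatasetsMode` and `resolveRerankingEnable` are illustrative helpers (not part of the component) that summarize how the new `useEffect` keeps `reranking_enable` in sync when the switch is locked.

```ts
// Illustrative sketch, not the actual component code.
interface SelectedDatasetsMode {
  allInternal: boolean
  allEconomic: boolean
  allExternal: boolean
}

// The toggle is user-controlled only when the selection is neither
// "all internal + all economic" nor "all external".
const canManuallyToggleRerank = (mode: SelectedDatasetsMode): boolean =>
  !((mode.allInternal && mode.allEconomic) || mode.allExternal)

// Mirrors the useEffect sync above: a locked toggle forces reranking off,
// an unlocked one keeps whatever the user chose via the Switch.
const resolveRerankingEnable = (mode: SelectedDatasetsMode, userChoice: boolean): boolean =>
  canManuallyToggleRerank(mode) ? userChoice : false

// All-economic internal datasets: toggle locked, reranking forced off.
console.log(resolveRerankingEnable({ allInternal: true, allEconomic: true, allExternal: false }, true)) // false
// High-quality internal datasets in the mix: the user's choice is respected.
console.log(resolveRerankingEnable({ allInternal: true, allEconomic: false, allExternal: false }, true)) // true
```

This is the same guard applied to the Switch's onChange above: when the toggle is locked the handler is a no-op and the effect writes the forced value back into `datasetConfigs.reranking_enable`.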
diff --git a/web/app/components/app/configuration/dataset-config/params-config/index.tsx b/web/app/components/app/configuration/dataset-config/params-config/index.tsx index 2d3df0b039..91d0e4e590 100644 --- a/web/app/components/app/configuration/dataset-config/params-config/index.tsx +++ b/web/app/components/app/configuration/dataset-config/params-config/index.tsx @@ -42,6 +42,7 @@ const ParamsConfig = ({ allHighQuality, allHighQualityFullTextSearch, allHighQualityVectorSearch, + allInternal, allExternal, mixtureHighQualityAndEconomic, inconsistentEmbeddingModel, @@ -50,7 +51,7 @@ const ParamsConfig = ({ const { datasets, retrieval_model, score_threshold_enabled, ...restConfigs } = datasetConfigs let rerankEnable = restConfigs.reranking_enable - if ((allEconomic && !restConfigs.reranking_model?.reranking_provider_name && rerankEnable === undefined) || allExternal) + if (((allInternal && allEconomic) || allExternal) && !restConfigs.reranking_model?.reranking_provider_name && rerankEnable === undefined) rerankEnable = false if (allEconomic || allHighQuality || allHighQualityFullTextSearch || allHighQualityVectorSearch || (allExternal && selectedDatasets.length === 1)) diff --git a/web/app/components/workflow/hooks/use-workflow-start-run.tsx b/web/app/components/workflow/hooks/use-workflow-start-run.tsx index 77e959b573..b2b1c69975 100644 --- a/web/app/components/workflow/hooks/use-workflow-start-run.tsx +++ b/web/app/components/workflow/hooks/use-workflow-start-run.tsx @@ -1,25 +1,17 @@ import { useCallback } from 'react' import { useStoreApi } from 'reactflow' -import { useTranslation } from 'react-i18next' import { useWorkflowStore } from '../store' import { BlockEnum, WorkflowRunningStatus, } from '../types' -import type { KnowledgeRetrievalNodeType } from '../nodes/knowledge-retrieval/types' -import type { Node } from '../types' -import { useWorkflow } from './use-workflow' import { useIsChatMode, useNodesSyncDraft, useWorkflowInteractions, useWorkflowRun, } from './index' -import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations' -import { useCurrentProviderAndModel, useModelListAndDefaultModelAndCurrentProviderAndModel } from '@/app/components/header/account-setting/model-provider-page/hooks' import { useFeaturesStore } from '@/app/components/base/features/hooks' -import KnowledgeRetrievalDefault from '@/app/components/workflow/nodes/knowledge-retrieval/default' -import Toast from '@/app/components/base/toast' export const useWorkflowStartRun = () => { const store = useStoreApi() @@ -28,26 +20,7 @@ export const useWorkflowStartRun = () => { const isChatMode = useIsChatMode() const { handleCancelDebugAndPreviewPanel } = useWorkflowInteractions() const { handleRun } = useWorkflowRun() - const { isFromStartNode } = useWorkflow() const { doSyncWorkflowDraft } = useNodesSyncDraft() - const { checkValid: checkKnowledgeRetrievalValid } = KnowledgeRetrievalDefault - const { t } = useTranslation() - const { - modelList: rerankModelList, - defaultModel: rerankDefaultModel, - } = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.rerank) - - const { - currentModel, - } = useCurrentProviderAndModel( - rerankModelList, - rerankDefaultModel - ? 
{ - ...rerankDefaultModel, - provider: rerankDefaultModel.provider.provider, - } - : undefined, - ) const handleWorkflowStartRunInWorkflow = useCallback(async () => { const { @@ -60,9 +33,6 @@ export const useWorkflowStartRun = () => { const { getNodes } = store.getState() const nodes = getNodes() const startNode = nodes.find(node => node.data.type === BlockEnum.Start) - const knowledgeRetrievalNodes = nodes.filter((node: Node) => - node.data.type === BlockEnum.KnowledgeRetrieval, - ) const startVariables = startNode?.data.variables || [] const fileSettings = featuresStore!.getState().features.file const { @@ -72,31 +42,6 @@ export const useWorkflowStartRun = () => { setShowEnvPanel, } = workflowStore.getState() - if (knowledgeRetrievalNodes.length > 0) { - for (const node of knowledgeRetrievalNodes) { - if (isFromStartNode(node.id)) { - const res = checkKnowledgeRetrievalValid(node.data, t) - if (!res.isValid || !currentModel || !rerankDefaultModel) { - const errorMessage = res.errorMessage - if (errorMessage) { - Toast.notify({ - type: 'error', - message: errorMessage, - }) - return false - } - else { - Toast.notify({ - type: 'error', - message: t('appDebug.datasetConfig.rerankModelRequired'), - }) - return false - } - } - } - } - } - setShowEnvPanel(false) if (showDebugAndPreviewPanel) { diff --git a/web/app/components/workflow/nodes/knowledge-retrieval/use-config.ts b/web/app/components/workflow/nodes/knowledge-retrieval/use-config.ts index e83a5d97b5..01c1e31ccc 100644 --- a/web/app/components/workflow/nodes/knowledge-retrieval/use-config.ts +++ b/web/app/components/workflow/nodes/knowledge-retrieval/use-config.ts @@ -23,7 +23,7 @@ import type { DataSet } from '@/models/datasets' import { fetchDatasets } from '@/service/datasets' import useNodeCrud from '@/app/components/workflow/nodes/_base/hooks/use-node-crud' import useOneStepRun from '@/app/components/workflow/nodes/_base/hooks/use-one-step-run' -import { useModelListAndDefaultModelAndCurrentProviderAndModel } from '@/app/components/header/account-setting/model-provider-page/hooks' +import { useCurrentProviderAndModel, useModelListAndDefaultModelAndCurrentProviderAndModel } from '@/app/components/header/account-setting/model-provider-page/hooks' import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations' const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => { @@ -34,6 +34,8 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => { const startNodeId = startNode?.id const { inputs, setInputs: doSetInputs } = useNodeCrud(id, payload) + const inputRef = useRef(inputs) + const setInputs = useCallback((s: KnowledgeRetrievalNodeType) => { const newInputs = produce(s, (draft) => { if (s.retrieval_mode === RETRIEVE_TYPE.multiWay) @@ -43,13 +45,9 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => { }) // not work in pass to draft... 
doSetInputs(newInputs) + inputRef.current = newInputs }, [doSetInputs]) - const inputRef = useRef(inputs) - useEffect(() => { - inputRef.current = inputs - }, [inputs]) - const handleQueryVarChange = useCallback((newVar: ValueSelector | string) => { const newInputs = produce(inputs, (draft) => { draft.query_variable_selector = newVar as ValueSelector @@ -63,9 +61,22 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => { } = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.textGeneration) const { + modelList: rerankModelList, defaultModel: rerankDefaultModel, } = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.rerank) + const { + currentModel: currentRerankModel, + } = useCurrentProviderAndModel( + rerankModelList, + rerankDefaultModel + ? { + ...rerankDefaultModel, + provider: rerankDefaultModel.provider.provider, + } + : undefined, + ) + const handleModelChanged = useCallback((model: { provider: string; modelId: string; mode?: string }) => { const newInputs = produce(inputRef.current, (draft) => { if (!draft.single_retrieval_config) { @@ -110,7 +121,7 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => { // set defaults models useEffect(() => { const inputs = inputRef.current - if (inputs.retrieval_mode === RETRIEVE_TYPE.multiWay && inputs.multiple_retrieval_config?.reranking_model?.provider) + if (inputs.retrieval_mode === RETRIEVE_TYPE.multiWay && inputs.multiple_retrieval_config?.reranking_model?.provider && currentRerankModel && rerankDefaultModel) return if (inputs.retrieval_mode === RETRIEVE_TYPE.oneWay && inputs.single_retrieval_config?.model?.provider) @@ -130,7 +141,6 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => { } } } - const multipleRetrievalConfig = draft.multiple_retrieval_config draft.multiple_retrieval_config = { top_k: multipleRetrievalConfig?.top_k || DATASET_DEFAULT.top_k, @@ -138,6 +148,9 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => { reranking_model: multipleRetrievalConfig?.reranking_model, reranking_mode: multipleRetrievalConfig?.reranking_mode, weights: multipleRetrievalConfig?.weights, + reranking_enable: multipleRetrievalConfig?.reranking_enable !== undefined + ? multipleRetrievalConfig.reranking_enable + : Boolean(currentRerankModel && rerankDefaultModel), } }) setInputs(newInput) @@ -194,14 +207,14 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => { }, []) useEffect(() => { + const inputs = inputRef.current let query_variable_selector: ValueSelector = inputs.query_variable_selector if (isChatMode && inputs.query_variable_selector.length === 0 && startNodeId) query_variable_selector = [startNodeId, 'sys.query'] - setInputs({ - ...inputs, - query_variable_selector, - }) + setInputs(produce(inputs, (draft) => { + draft.query_variable_selector = query_variable_selector + })) // eslint-disable-next-line react-hooks/exhaustive-deps }, []) diff --git a/web/app/components/workflow/nodes/knowledge-retrieval/utils.ts b/web/app/components/workflow/nodes/knowledge-retrieval/utils.ts index 85ae6c4c96..e48777d948 100644 --- a/web/app/components/workflow/nodes/knowledge-retrieval/utils.ts +++ b/web/app/components/workflow/nodes/knowledge-retrieval/utils.ts @@ -113,7 +113,7 @@ export const getMultipleRetrievalConfig = (multipleRetrievalConfig: MultipleRetr reranking_mode, reranking_model, weights, - reranking_enable: allEconomic ? 
reranking_enable : true, + reranking_enable: ((allInternal && allEconomic) || allExternal) ? reranking_enable : true, } if (allEconomic || mixtureHighQualityAndEconomic || inconsistentEmbeddingModel || allExternal || mixtureInternalAndExternal)