Merge branch 'fix/rerank-validation-in-run' into deploy/dev

This commit is contained in:
Yi 2024-10-16 16:47:49 +08:00
commit 03d704ea5a
39 changed files with 777 additions and 249 deletions

View File

@ -27,18 +27,17 @@ jobs:
- name: Checkout code - name: Checkout code
uses: actions/checkout@v4 uses: actions/checkout@v4
- name: Install Poetry
uses: abatilo/actions-poetry@v3
- name: Set up Python ${{ matrix.python-version }} - name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v5 uses: actions/setup-python@v5
with: with:
python-version: ${{ matrix.python-version }} python-version: ${{ matrix.python-version }}
cache: 'poetry'
cache-dependency-path: | cache-dependency-path: |
api/pyproject.toml api/pyproject.toml
api/poetry.lock api/poetry.lock
- name: Install Poetry
uses: abatilo/actions-poetry@v3
- name: Check Poetry lockfile - name: Check Poetry lockfile
run: | run: |
poetry check -C api --lock poetry check -C api --lock

View File

@ -24,15 +24,15 @@ jobs:
with: with:
files: api/** files: api/**
- name: Install Poetry
uses: abatilo/actions-poetry@v3
- name: Set up Python - name: Set up Python
uses: actions/setup-python@v5 uses: actions/setup-python@v5
if: steps.changed-files.outputs.any_changed == 'true' if: steps.changed-files.outputs.any_changed == 'true'
with: with:
python-version: '3.10' python-version: '3.10'
- name: Install Poetry
uses: abatilo/actions-poetry@v3
- name: Python dependencies - name: Python dependencies
if: steps.changed-files.outputs.any_changed == 'true' if: steps.changed-files.outputs.any_changed == 'true'
run: poetry install -C api --only lint run: poetry install -C api --only lint

View File

@ -85,3 +85,4 @@
cd ../ cd ../
poetry run -C api bash dev/pytest/pytest_all_tests.sh poetry run -C api bash dev/pytest/pytest_all_tests.sh
``` ```

View File

@ -14,7 +14,7 @@ class OracleConfig(BaseSettings):
default=None, default=None,
) )
ORACLE_PORT: Optional[PositiveInt] = Field( ORACLE_PORT: PositiveInt = Field(
description="Port number on which the Oracle database server is listening (default is 1521)", description="Port number on which the Oracle database server is listening (default is 1521)",
default=1521, default=1521,
) )

View File

@ -14,7 +14,7 @@ class PGVectorConfig(BaseSettings):
default=None, default=None,
) )
PGVECTOR_PORT: Optional[PositiveInt] = Field( PGVECTOR_PORT: PositiveInt = Field(
description="Port number on which the PostgreSQL server is listening (default is 5433)", description="Port number on which the PostgreSQL server is listening (default is 5433)",
default=5433, default=5433,
) )

View File

@ -14,7 +14,7 @@ class PGVectoRSConfig(BaseSettings):
default=None, default=None,
) )
PGVECTO_RS_PORT: Optional[PositiveInt] = Field( PGVECTO_RS_PORT: PositiveInt = Field(
description="Port number on which the PostgreSQL server with PGVecto.RS is listening (default is 5431)", description="Port number on which the PostgreSQL server with PGVecto.RS is listening (default is 5431)",
default=5431, default=5431,
) )

View File

@ -11,27 +11,39 @@ class VikingDBConfig(BaseModel):
""" """
VIKINGDB_ACCESS_KEY: Optional[str] = Field( VIKINGDB_ACCESS_KEY: Optional[str] = Field(
default=None, description="The Access Key provided by Volcengine VikingDB for API authentication." description="The Access Key provided by Volcengine VikingDB for API authentication."
"Refer to the following documentation for details on obtaining credentials:"
"https://www.volcengine.com/docs/6291/65568",
default=None,
) )
VIKINGDB_SECRET_KEY: Optional[str] = Field( VIKINGDB_SECRET_KEY: Optional[str] = Field(
default=None, description="The Secret Key provided by Volcengine VikingDB for API authentication." description="The Secret Key provided by Volcengine VikingDB for API authentication.",
default=None,
) )
VIKINGDB_REGION: Optional[str] = Field(
default="cn-shanghai", VIKINGDB_REGION: str = Field(
description="The region of the Volcengine VikingDB service.(e.g., 'cn-shanghai', 'cn-beijing').", description="The region of the Volcengine VikingDB service.(e.g., 'cn-shanghai', 'cn-beijing').",
default="cn-shanghai",
) )
VIKINGDB_HOST: Optional[str] = Field(
default="api-vikingdb.mlp.cn-shanghai.volces.com", VIKINGDB_HOST: str = Field(
description="The host of the Volcengine VikingDB service.(e.g., 'api-vikingdb.volces.com', \ description="The host of the Volcengine VikingDB service.(e.g., 'api-vikingdb.volces.com', \
'api-vikingdb.mlp.cn-shanghai.volces.com')", 'api-vikingdb.mlp.cn-shanghai.volces.com')",
default="api-vikingdb.mlp.cn-shanghai.volces.com",
) )
VIKINGDB_SCHEME: Optional[str] = Field(
default="http", VIKINGDB_SCHEME: str = Field(
description="The scheme of the Volcengine VikingDB service.(e.g., 'http', 'https').", description="The scheme of the Volcengine VikingDB service.(e.g., 'http', 'https').",
default="http",
) )
VIKINGDB_CONNECTION_TIMEOUT: Optional[int] = Field(
default=30, description="The connection timeout of the Volcengine VikingDB service." VIKINGDB_CONNECTION_TIMEOUT: int = Field(
description="The connection timeout of the Volcengine VikingDB service.",
default=30,
) )
VIKINGDB_SOCKET_TIMEOUT: Optional[int] = Field(
default=30, description="The socket timeout of the Volcengine VikingDB service." VIKINGDB_SOCKET_TIMEOUT: int = Field(
description="The socket timeout of the Volcengine VikingDB service.",
default=30,
) )

View File

@ -1,88 +1,24 @@
import logging from flask_restful import Resource
from flask_login import current_user
from flask_restful import Resource, marshal, reqparse
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
import services
from controllers.console import api from controllers.console import api
from controllers.console.app.error import ( from controllers.console.datasets.hit_testing_base import DatasetsHitTestingBase
CompletionRequestError,
ProviderModelCurrentlyNotSupportError,
ProviderNotInitializeError,
ProviderQuotaExceededError,
)
from controllers.console.datasets.error import DatasetNotInitializedError
from controllers.console.setup import setup_required from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required from controllers.console.wraps import account_initialization_required
from core.errors.error import (
LLMBadRequestError,
ModelCurrentlyNotSupportError,
ProviderTokenNotInitError,
QuotaExceededError,
)
from core.model_runtime.errors.invoke import InvokeError
from fields.hit_testing_fields import hit_testing_record_fields
from libs.login import login_required from libs.login import login_required
from services.dataset_service import DatasetService
from services.hit_testing_service import HitTestingService
class HitTestingApi(Resource): class HitTestingApi(Resource, DatasetsHitTestingBase):
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
def post(self, dataset_id): def post(self, dataset_id):
dataset_id_str = str(dataset_id) dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str) dataset = self.get_and_validate_dataset(dataset_id_str)
if dataset is None: args = self.parse_args()
raise NotFound("Dataset not found.") self.hit_testing_args_check(args)
try: return self.perform_hit_testing(dataset, args)
DatasetService.check_dataset_permission(dataset, current_user)
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
parser = reqparse.RequestParser()
parser.add_argument("query", type=str, location="json")
parser.add_argument("retrieval_model", type=dict, required=False, location="json")
parser.add_argument("external_retrieval_model", type=dict, required=False, location="json")
args = parser.parse_args()
HitTestingService.hit_testing_args_check(args)
try:
response = HitTestingService.retrieve(
dataset=dataset,
query=args["query"],
account=current_user,
retrieval_model=args["retrieval_model"],
external_retrieval_model=args["external_retrieval_model"],
limit=10,
)
return {"query": response["query"], "records": marshal(response["records"], hit_testing_record_fields)}
except services.errors.index.IndexNotInitializedError:
raise DatasetNotInitializedError()
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
except QuotaExceededError:
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except LLMBadRequestError:
raise ProviderNotInitializeError(
"No Embedding Model or Reranking Model available. Please configure a valid provider "
"in the Settings -> Model Provider."
)
except InvokeError as e:
raise CompletionRequestError(e.description)
except ValueError as e:
raise ValueError(str(e))
except Exception as e:
logging.exception("Hit testing failed.")
raise InternalServerError(str(e))
api.add_resource(HitTestingApi, "/datasets/<uuid:dataset_id>/hit-testing") api.add_resource(HitTestingApi, "/datasets/<uuid:dataset_id>/hit-testing")

View File

@ -0,0 +1,85 @@
import logging
from flask_login import current_user
from flask_restful import marshal, reqparse
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
import services.dataset_service
from controllers.console.app.error import (
CompletionRequestError,
ProviderModelCurrentlyNotSupportError,
ProviderNotInitializeError,
ProviderQuotaExceededError,
)
from controllers.console.datasets.error import DatasetNotInitializedError
from core.errors.error import (
LLMBadRequestError,
ModelCurrentlyNotSupportError,
ProviderTokenNotInitError,
QuotaExceededError,
)
from core.model_runtime.errors.invoke import InvokeError
from fields.hit_testing_fields import hit_testing_record_fields
from services.dataset_service import DatasetService
from services.hit_testing_service import HitTestingService
class DatasetsHitTestingBase:
@staticmethod
def get_and_validate_dataset(dataset_id: str):
dataset = DatasetService.get_dataset(dataset_id)
if dataset is None:
raise NotFound("Dataset not found.")
try:
DatasetService.check_dataset_permission(dataset, current_user)
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
return dataset
@staticmethod
def hit_testing_args_check(args):
HitTestingService.hit_testing_args_check(args)
@staticmethod
def parse_args():
parser = reqparse.RequestParser()
parser.add_argument("query", type=str, location="json")
parser.add_argument("retrieval_model", type=dict, required=False, location="json")
parser.add_argument("external_retrieval_model", type=dict, required=False, location="json")
return parser.parse_args()
@staticmethod
def perform_hit_testing(dataset, args):
try:
response = HitTestingService.retrieve(
dataset=dataset,
query=args["query"],
account=current_user,
retrieval_model=args["retrieval_model"],
external_retrieval_model=args["external_retrieval_model"],
limit=10,
)
return {"query": response["query"], "records": marshal(response["records"], hit_testing_record_fields)}
except services.errors.index.IndexNotInitializedError:
raise DatasetNotInitializedError()
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
except QuotaExceededError:
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except LLMBadRequestError:
raise ProviderNotInitializeError(
"No Embedding Model or Reranking Model available. Please configure a valid provider "
"in the Settings -> Model Provider."
)
except InvokeError as e:
raise CompletionRequestError(e.description)
except ValueError as e:
raise ValueError(str(e))
except Exception as e:
logging.exception("Hit testing failed.")
raise InternalServerError(str(e))

View File

@ -5,7 +5,6 @@ from libs.external_api import ExternalApi
bp = Blueprint("service_api", __name__, url_prefix="/v1") bp = Blueprint("service_api", __name__, url_prefix="/v1")
api = ExternalApi(bp) api = ExternalApi(bp)
from . import index from . import index
from .app import app, audio, completion, conversation, file, message, workflow from .app import app, audio, completion, conversation, file, message, workflow
from .dataset import dataset, document, segment from .dataset import dataset, document, hit_testing, segment

View File

@ -4,7 +4,6 @@ from flask_restful import Resource, reqparse
from werkzeug.exceptions import InternalServerError, NotFound from werkzeug.exceptions import InternalServerError, NotFound
import services import services
from constants import UUID_NIL
from controllers.service_api import api from controllers.service_api import api
from controllers.service_api.app.error import ( from controllers.service_api.app.error import (
AppUnavailableError, AppUnavailableError,
@ -108,7 +107,6 @@ class ChatApi(Resource):
parser.add_argument("conversation_id", type=uuid_value, location="json") parser.add_argument("conversation_id", type=uuid_value, location="json")
parser.add_argument("retriever_from", type=str, required=False, default="dev", location="json") parser.add_argument("retriever_from", type=str, required=False, default="dev", location="json")
parser.add_argument("auto_generate_name", type=bool, required=False, default=True, location="json") parser.add_argument("auto_generate_name", type=bool, required=False, default=True, location="json")
parser.add_argument("parent_message_id", type=uuid_value, required=False, default=UUID_NIL, location="json")
args = parser.parse_args() args = parser.parse_args()

View File

@ -0,0 +1,17 @@
from controllers.console.datasets.hit_testing_base import DatasetsHitTestingBase
from controllers.service_api import api
from controllers.service_api.wraps import DatasetApiResource
class HitTestingApi(DatasetApiResource, DatasetsHitTestingBase):
def post(self, tenant_id, dataset_id):
dataset_id_str = str(dataset_id)
dataset = self.get_and_validate_dataset(dataset_id_str)
args = self.parse_args()
self.hit_testing_args_check(args)
return self.perform_hit_testing(dataset, args)
api.add_resource(HitTestingApi, "/datasets/<uuid:dataset_id>/hit-testing")

View File

@ -10,6 +10,7 @@ from flask import Flask, current_app
from pydantic import ValidationError from pydantic import ValidationError
import contexts import contexts
from constants import UUID_NIL
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager
from core.app.apps.advanced_chat.app_runner import AdvancedChatAppRunner from core.app.apps.advanced_chat.app_runner import AdvancedChatAppRunner
@ -132,7 +133,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
else self._prepare_user_inputs(user_inputs=inputs, app_config=app_config, user_id=user.id, role=role), else self._prepare_user_inputs(user_inputs=inputs, app_config=app_config, user_id=user.id, role=role),
query=query, query=query,
files=file_objs, files=file_objs,
parent_message_id=args.get("parent_message_id"), parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL,
user_id=user.id, user_id=user.id,
stream=stream, stream=stream,
invoke_from=invoke_from, invoke_from=invoke_from,

View File

@ -8,6 +8,7 @@ from typing import Any, Literal, Union, overload
from flask import Flask, current_app from flask import Flask, current_app
from pydantic import ValidationError from pydantic import ValidationError
from constants import UUID_NIL
from core.app.app_config.easy_ui_based_app.model_config.converter import ModelConfigConverter from core.app.app_config.easy_ui_based_app.model_config.converter import ModelConfigConverter
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfigManager from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfigManager
@ -140,7 +141,7 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
else self._prepare_user_inputs(user_inputs=inputs, app_config=app_config, user_id=user.id, role=role), else self._prepare_user_inputs(user_inputs=inputs, app_config=app_config, user_id=user.id, role=role),
query=query, query=query,
files=file_objs, files=file_objs,
parent_message_id=args.get("parent_message_id"), parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL,
user_id=user.id, user_id=user.id,
stream=stream, stream=stream,
invoke_from=invoke_from, invoke_from=invoke_from,

View File

@ -8,6 +8,7 @@ from typing import Any, Literal, Union, overload
from flask import Flask, current_app from flask import Flask, current_app
from pydantic import ValidationError from pydantic import ValidationError
from constants import UUID_NIL
from core.app.app_config.easy_ui_based_app.model_config.converter import ModelConfigConverter from core.app.app_config.easy_ui_based_app.model_config.converter import ModelConfigConverter
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom
@ -138,7 +139,7 @@ class ChatAppGenerator(MessageBasedAppGenerator):
else self._prepare_user_inputs(user_inputs=inputs, app_config=app_config, user_id=user.id, role=role), else self._prepare_user_inputs(user_inputs=inputs, app_config=app_config, user_id=user.id, role=role),
query=query, query=query,
files=file_objs, files=file_objs,
parent_message_id=args.get("parent_message_id"), parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL,
user_id=user.id, user_id=user.id,
invoke_from=invoke_from, invoke_from=invoke_from,
extras=extras, extras=extras,

View File

@ -2,8 +2,9 @@ from collections.abc import Mapping, Sequence
from enum import Enum from enum import Enum
from typing import Any, Optional from typing import Any, Optional
from pydantic import BaseModel, ConfigDict from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_validator
from constants import UUID_NIL
from core.app.app_config.entities import AppConfig, EasyUIBasedAppConfig, WorkflowUIBasedAppConfig from core.app.app_config.entities import AppConfig, EasyUIBasedAppConfig, WorkflowUIBasedAppConfig
from core.entities.provider_configuration import ProviderModelBundle from core.entities.provider_configuration import ProviderModelBundle
from core.file.models import File from core.file.models import File
@ -116,13 +117,36 @@ class EasyUIBasedAppGenerateEntity(AppGenerateEntity):
model_config = ConfigDict(protected_namespaces=()) model_config = ConfigDict(protected_namespaces=())
class ChatAppGenerateEntity(EasyUIBasedAppGenerateEntity): class ConversationAppGenerateEntity(AppGenerateEntity):
"""
Base entity for conversation-based app generation.
"""
conversation_id: Optional[str] = None
parent_message_id: Optional[str] = Field(
default=None,
description=(
"Starting from v0.9.0, parent_message_id is used to support message regeneration for internal chat API."
"For service API, we need to ensure its forward compatibility, "
"so passing in the parent_message_id as request arg is not supported for now. "
"It needs to be set to UUID_NIL so that the subsequent processing will treat it as legacy messages."
),
)
@field_validator("parent_message_id")
@classmethod
def validate_parent_message_id(cls, v, info: ValidationInfo):
if info.data.get("invoke_from") == InvokeFrom.SERVICE_API and v != UUID_NIL:
raise ValueError("parent_message_id should be UUID_NIL for service API")
return v
class ChatAppGenerateEntity(ConversationAppGenerateEntity, EasyUIBasedAppGenerateEntity):
""" """
Chat Application Generate Entity. Chat Application Generate Entity.
""" """
conversation_id: Optional[str] = None pass
parent_message_id: Optional[str] = None
class CompletionAppGenerateEntity(EasyUIBasedAppGenerateEntity): class CompletionAppGenerateEntity(EasyUIBasedAppGenerateEntity):
@ -133,16 +157,15 @@ class CompletionAppGenerateEntity(EasyUIBasedAppGenerateEntity):
pass pass
class AgentChatAppGenerateEntity(EasyUIBasedAppGenerateEntity): class AgentChatAppGenerateEntity(ConversationAppGenerateEntity, EasyUIBasedAppGenerateEntity):
""" """
Agent Chat Application Generate Entity. Agent Chat Application Generate Entity.
""" """
conversation_id: Optional[str] = None pass
parent_message_id: Optional[str] = None
class AdvancedChatAppGenerateEntity(AppGenerateEntity): class AdvancedChatAppGenerateEntity(ConversationAppGenerateEntity):
""" """
Advanced Chat Application Generate Entity. Advanced Chat Application Generate Entity.
""" """
@ -150,8 +173,6 @@ class AdvancedChatAppGenerateEntity(AppGenerateEntity):
# app config # app config
app_config: WorkflowUIBasedAppConfig app_config: WorkflowUIBasedAppConfig
conversation_id: Optional[str] = None
parent_message_id: Optional[str] = None
workflow_run_id: Optional[str] = None workflow_run_id: Optional[str] = None
query: str query: str

View File

@ -18,6 +18,7 @@ supported_model_types:
- text-embedding - text-embedding
configurate_methods: configurate_methods:
- predefined-model - predefined-model
- customizable-model
provider_credential_schema: provider_credential_schema:
credential_form_schemas: credential_form_schemas:
- variable: fireworks_api_key - variable: fireworks_api_key
@ -28,3 +29,75 @@ provider_credential_schema:
placeholder: placeholder:
zh_Hans: 在此输入您的 API Key zh_Hans: 在此输入您的 API Key
en_US: Enter your API Key en_US: Enter your API Key
model_credential_schema:
model:
label:
en_US: Model URL
zh_Hans: 模型URL
placeholder:
en_US: Enter your Model URL
zh_Hans: 输入模型URL
credential_form_schemas:
- variable: model_label_zh_Hanns
label:
zh_Hans: 模型中文名称
en_US: The zh_Hans of Model
required: true
type: text-input
placeholder:
zh_Hans: 在此输入您的模型中文名称
en_US: Enter your zh_Hans of Model
- variable: model_label_en_US
label:
zh_Hans: 模型英文名称
en_US: The en_US of Model
required: true
type: text-input
placeholder:
zh_Hans: 在此输入您的模型英文名称
en_US: Enter your en_US of Model
- variable: fireworks_api_key
label:
en_US: API Key
type: secret-input
required: true
placeholder:
zh_Hans: 在此输入您的 API Key
en_US: Enter your API Key
- variable: context_size
label:
zh_Hans: 模型上下文长度
en_US: Model context size
required: true
type: text-input
default: '4096'
placeholder:
zh_Hans: 在此输入您的模型上下文长度
en_US: Enter your Model context size
- variable: max_tokens
label:
zh_Hans: 最大 token 上限
en_US: Upper bound for max tokens
default: '4096'
type: text-input
show_on:
- variable: __model_type
value: llm
- variable: function_calling_type
label:
en_US: Function calling
type: select
required: false
default: no_call
options:
- value: no_call
label:
en_US: Not Support
zh_Hans: 不支持
- value: function_call
label:
en_US: Support
zh_Hans: 支持
show_on:
- variable: __model_type
value: llm

View File

@ -8,7 +8,8 @@ from openai.types.chat.chat_completion_chunk import ChoiceDeltaFunctionCall, Cho
from openai.types.chat.chat_completion_message import FunctionCall from openai.types.chat.chat_completion_message import FunctionCall
from core.model_runtime.callbacks.base_callback import Callback from core.model_runtime.callbacks.base_callback import Callback
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import ( from core.model_runtime.entities.message_entities import (
AssistantPromptMessage, AssistantPromptMessage,
ImagePromptMessageContent, ImagePromptMessageContent,
@ -20,6 +21,15 @@ from core.model_runtime.entities.message_entities import (
ToolPromptMessage, ToolPromptMessage,
UserPromptMessage, UserPromptMessage,
) )
from core.model_runtime.entities.model_entities import (
AIModelEntity,
FetchFrom,
ModelFeature,
ModelPropertyKey,
ModelType,
ParameterRule,
ParameterType,
)
from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.model_providers.fireworks._common import _CommonFireworks from core.model_runtime.model_providers.fireworks._common import _CommonFireworks
@ -608,3 +618,50 @@ class FireworksLargeLanguageModel(_CommonFireworks, LargeLanguageModel):
num_tokens += self._get_num_tokens_by_gpt2(required_field) num_tokens += self._get_num_tokens_by_gpt2(required_field)
return num_tokens return num_tokens
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity:
return AIModelEntity(
model=model,
label=I18nObject(
en_US=credentials.get("model_label_en_US", model),
zh_Hans=credentials.get("model_label_zh_Hanns", model),
),
model_type=ModelType.LLM,
features=[ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL, ModelFeature.STREAM_TOOL_CALL]
if credentials.get("function_calling_type") == "function_call"
else [],
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_properties={
ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size", 4096)),
ModelPropertyKey.MODE: LLMMode.CHAT.value,
},
parameter_rules=[
ParameterRule(
name="temperature",
use_template="temperature",
label=I18nObject(en_US="Temperature", zh_Hans="温度"),
type=ParameterType.FLOAT,
),
ParameterRule(
name="max_tokens",
use_template="max_tokens",
default=512,
min=1,
max=int(credentials.get("max_tokens", 4096)),
label=I18nObject(en_US="Max Tokens", zh_Hans="最大标记"),
type=ParameterType.INT,
),
ParameterRule(
name="top_p",
use_template="top_p",
label=I18nObject(en_US="Top P", zh_Hans="Top P"),
type=ParameterType.FLOAT,
),
ParameterRule(
name="top_k",
use_template="top_k",
label=I18nObject(en_US="Top K", zh_Hans="Top K"),
type=ParameterType.FLOAT,
),
],
)

View File

@ -0,0 +1,46 @@
model: accounts/fireworks/models/qwen2p5-72b-instruct
label:
zh_Hans: Qwen2.5 72B Instruct
en_US: Qwen2.5 72B Instruct
model_type: llm
features:
- agent-thought
- tool-call
model_properties:
mode: chat
context_size: 32768
parameter_rules:
- name: temperature
use_template: temperature
- name: top_p
use_template: top_p
- name: top_k
label:
zh_Hans: 取样数量
en_US: Top k
type: int
help:
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
en_US: Only sample from the top K options for each subsequent token.
- name: max_tokens
use_template: max_tokens
- name: context_length_exceeded_behavior
default: None
label:
zh_Hans: 上下文长度超出行为
en_US: Context Length Exceeded Behavior
help:
zh_Hans: 上下文长度超出行为
en_US: Context Length Exceeded Behavior
type: string
options:
- None
- truncate
- error
- name: response_format
use_template: response_format
pricing:
input: '0.9'
output: '0.9'
unit: '0.000001'
currency: USD

View File

@ -61,7 +61,8 @@ class MinimaxTextEmbeddingModel(TextEmbeddingModel):
url = f"{self.api_base}?GroupId={group_id}" url = f"{self.api_base}?GroupId={group_id}"
headers = {"Authorization": "Bearer " + api_key, "Content-Type": "application/json"} headers = {"Authorization": "Bearer " + api_key, "Content-Type": "application/json"}
data = {"model": "embo-01", "texts": texts, "type": "db"} embedding_type = "db" if input_type == EmbeddingInputType.DOCUMENT else "query"
data = {"model": "embo-01", "texts": texts, "type": embedding_type}
try: try:
response = post(url, headers=headers, data=dumps(data)) response = post(url, headers=headers, data=dumps(data))

View File

@ -1,8 +1,8 @@
from typing import Any from typing import Any
from configs import dify_config from configs import dify_config
from core.rag.datasource.keyword.jieba.jieba import Jieba
from core.rag.datasource.keyword.keyword_base import BaseKeyword from core.rag.datasource.keyword.keyword_base import BaseKeyword
from core.rag.datasource.keyword.keyword_type import KeyWordType
from core.rag.models.document import Document from core.rag.models.document import Document
from models.dataset import Dataset from models.dataset import Dataset
@ -13,16 +13,19 @@ class Keyword:
self._keyword_processor = self._init_keyword() self._keyword_processor = self._init_keyword()
def _init_keyword(self) -> BaseKeyword: def _init_keyword(self) -> BaseKeyword:
config = dify_config keyword_type = dify_config.KEYWORD_STORE
keyword_type = config.KEYWORD_STORE keyword_factory = self.get_keyword_factory(keyword_type)
return keyword_factory(self._dataset)
if not keyword_type: @staticmethod
raise ValueError("Keyword store must be specified.") def get_keyword_factory(keyword_type: str) -> type[BaseKeyword]:
match keyword_type:
case KeyWordType.JIEBA:
from core.rag.datasource.keyword.jieba.jieba import Jieba
if keyword_type == "jieba": return Jieba
return Jieba(dataset=self._dataset) case _:
else: raise ValueError(f"Keyword store {keyword_type} is not supported.")
raise ValueError(f"Keyword store {keyword_type} is not supported.")
def create(self, texts: list[Document], **kwargs): def create(self, texts: list[Document], **kwargs):
self._keyword_processor.create(texts, **kwargs) self._keyword_processor.create(texts, **kwargs)

View File

@ -0,0 +1,5 @@
from enum import Enum
class KeyWordType(str, Enum):
JIEBA = "jieba"

View File

@ -1,5 +1,4 @@
from collections.abc import Generator from collections.abc import Generator
from contextlib import closing
import oss2 as aliyun_s3 import oss2 as aliyun_s3
from flask import Flask from flask import Flask
@ -34,15 +33,15 @@ class AliyunOssStorage(BaseStorage):
self.client.put_object(self.__wrapper_folder_filename(filename), data) self.client.put_object(self.__wrapper_folder_filename(filename), data)
def load_once(self, filename: str) -> bytes: def load_once(self, filename: str) -> bytes:
with closing(self.client.get_object(self.__wrapper_folder_filename(filename))) as obj: obj = self.client.get_object(self.__wrapper_folder_filename(filename))
data = obj.read() data = obj.read()
return data return data
def load_stream(self, filename: str) -> Generator: def load_stream(self, filename: str) -> Generator:
def generate(filename: str = filename) -> Generator: def generate(filename: str = filename) -> Generator:
with closing(self.client.get_object(self.__wrapper_folder_filename(filename))) as obj: obj = self.client.get_object(self.__wrapper_folder_filename(filename))
while chunk := obj.read(4096): while chunk := obj.read(4096):
yield chunk yield chunk
return generate() return generate()

View File

@ -1,6 +1,5 @@
import logging import logging
from collections.abc import Generator from collections.abc import Generator
from contextlib import closing
import boto3 import boto3
from botocore.client import Config from botocore.client import Config
@ -55,8 +54,7 @@ class AwsS3Storage(BaseStorage):
def load_once(self, filename: str) -> bytes: def load_once(self, filename: str) -> bytes:
try: try:
with closing(self.client) as client: data = self.client.get_object(Bucket=self.bucket_name, Key=filename)["Body"].read()
data = client.get_object(Bucket=self.bucket_name, Key=filename)["Body"].read()
except ClientError as ex: except ClientError as ex:
if ex.response["Error"]["Code"] == "NoSuchKey": if ex.response["Error"]["Code"] == "NoSuchKey":
raise FileNotFoundError("File not found") raise FileNotFoundError("File not found")
@ -67,9 +65,8 @@ class AwsS3Storage(BaseStorage):
def load_stream(self, filename: str) -> Generator: def load_stream(self, filename: str) -> Generator:
def generate(filename: str = filename) -> Generator: def generate(filename: str = filename) -> Generator:
try: try:
with closing(self.client) as client: response = self.client.get_object(Bucket=self.bucket_name, Key=filename)
response = client.get_object(Bucket=self.bucket_name, Key=filename) yield from response["Body"].iter_chunks()
yield from response["Body"].iter_chunks()
except ClientError as ex: except ClientError as ex:
if ex.response["Error"]["Code"] == "NoSuchKey": if ex.response["Error"]["Code"] == "NoSuchKey":
raise FileNotFoundError("File not found") raise FileNotFoundError("File not found")
@ -79,16 +76,14 @@ class AwsS3Storage(BaseStorage):
return generate() return generate()
def download(self, filename, target_filepath): def download(self, filename, target_filepath):
with closing(self.client) as client: self.client.download_file(self.bucket_name, filename, target_filepath)
client.download_file(self.bucket_name, filename, target_filepath)
def exists(self, filename): def exists(self, filename):
with closing(self.client) as client: try:
try: self.client.head_object(Bucket=self.bucket_name, Key=filename)
client.head_object(Bucket=self.bucket_name, Key=filename) return True
return True except:
except: return False
return False
def delete(self, filename): def delete(self, filename):
self.client.delete_object(Bucket=self.bucket_name, Key=filename) self.client.delete_object(Bucket=self.bucket_name, Key=filename)

View File

@ -2,7 +2,6 @@ import base64
import io import io
import json import json
from collections.abc import Generator from collections.abc import Generator
from contextlib import closing
from flask import Flask from flask import Flask
from google.cloud import storage as google_cloud_storage from google.cloud import storage as google_cloud_storage
@ -43,7 +42,7 @@ class GoogleCloudStorage(BaseStorage):
def generate(filename: str = filename) -> Generator: def generate(filename: str = filename) -> Generator:
bucket = self.client.get_bucket(self.bucket_name) bucket = self.client.get_bucket(self.bucket_name)
blob = bucket.get_blob(filename) blob = bucket.get_blob(filename)
with closing(blob.open(mode="rb")) as blob_stream: with blob.open(mode="rb") as blob_stream:
while chunk := blob_stream.read(4096): while chunk := blob_stream.read(4096):
yield chunk yield chunk

View File

@ -1,5 +1,4 @@
from collections.abc import Generator from collections.abc import Generator
from contextlib import closing
import boto3 import boto3
from botocore.exceptions import ClientError from botocore.exceptions import ClientError
@ -28,8 +27,7 @@ class OracleOCIStorage(BaseStorage):
def load_once(self, filename: str) -> bytes: def load_once(self, filename: str) -> bytes:
try: try:
with closing(self.client) as client: data = self.client.get_object(Bucket=self.bucket_name, Key=filename)["Body"].read()
data = client.get_object(Bucket=self.bucket_name, Key=filename)["Body"].read()
except ClientError as ex: except ClientError as ex:
if ex.response["Error"]["Code"] == "NoSuchKey": if ex.response["Error"]["Code"] == "NoSuchKey":
raise FileNotFoundError("File not found") raise FileNotFoundError("File not found")
@ -40,9 +38,8 @@ class OracleOCIStorage(BaseStorage):
def load_stream(self, filename: str) -> Generator: def load_stream(self, filename: str) -> Generator:
def generate(filename: str = filename) -> Generator: def generate(filename: str = filename) -> Generator:
try: try:
with closing(self.client) as client: response = self.client.get_object(Bucket=self.bucket_name, Key=filename)
response = client.get_object(Bucket=self.bucket_name, Key=filename) yield from response["Body"].iter_chunks()
yield from response["Body"].iter_chunks()
except ClientError as ex: except ClientError as ex:
if ex.response["Error"]["Code"] == "NoSuchKey": if ex.response["Error"]["Code"] == "NoSuchKey":
raise FileNotFoundError("File not found") raise FileNotFoundError("File not found")
@ -52,16 +49,14 @@ class OracleOCIStorage(BaseStorage):
return generate() return generate()
def download(self, filename, target_filepath): def download(self, filename, target_filepath):
with closing(self.client) as client: self.client.download_file(self.bucket_name, filename, target_filepath)
client.download_file(self.bucket_name, filename, target_filepath)
def exists(self, filename): def exists(self, filename):
with closing(self.client) as client: try:
try: self.client.head_object(Bucket=self.bucket_name, Key=filename)
client.head_object(Bucket=self.bucket_name, Key=filename) return True
return True except:
except: return False
return False
def delete(self, filename): def delete(self, filename):
self.client.delete_object(Bucket=self.bucket_name, Key=filename) self.client.delete_object(Bucket=self.bucket_name, Key=filename)

View File

@ -1,15 +1,25 @@
from services.auth.firecrawl import FirecrawlAuth from services.auth.api_key_auth_base import ApiKeyAuthBase
from services.auth.jina import JinaAuth from services.auth.auth_type import AuthType
class ApiKeyAuthFactory: class ApiKeyAuthFactory:
def __init__(self, provider: str, credentials: dict): def __init__(self, provider: str, credentials: dict):
if provider == "firecrawl": auth_factory = self.get_apikey_auth_factory(provider)
self.auth = FirecrawlAuth(credentials) self.auth = auth_factory(credentials)
elif provider == "jinareader":
self.auth = JinaAuth(credentials)
else:
raise ValueError("Invalid provider")
def validate_credentials(self): def validate_credentials(self):
return self.auth.validate_credentials() return self.auth.validate_credentials()
@staticmethod
def get_apikey_auth_factory(provider: str) -> type[ApiKeyAuthBase]:
match provider:
case AuthType.FIRECRAWL:
from services.auth.firecrawl.firecrawl import FirecrawlAuth
return FirecrawlAuth
case AuthType.JINA:
from services.auth.jina.jina import JinaAuth
return JinaAuth
case _:
raise ValueError("Invalid provider")

View File

@ -0,0 +1,6 @@
from enum import Enum
class AuthType(str, Enum):
FIRECRAWL = "firecrawl"
JINA = "jinareader"

View File

View File

View File

@ -1050,6 +1050,151 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from
--- ---
<Heading
url='/datasets/{dataset_id}/hit_testing'
method='POST'
title='Dataset hit testing'
name='#dataset_hit_testing'
/>
<Row>
<Col>
### Path
<Properties>
<Property name='dataset_id' type='string' key='dataset_id'>
Dataset ID
</Property>
</Properties>
### Request Body
<Properties>
<Property name='query' type='string' key='query'>
retrieval keywordc
</Property>
<Property name='retrieval_model' type='object' key='retrieval_model'>
retrieval keyword(Optional, if not filled, it will be recalled according to the default method)
- <code>search_method</code> (text) Search method: One of the following four keywords is required
- <code>keyword_search</code> Keyword search
- <code>semantic_search</code> Semantic search
- <code>full_text_search</code> Full-text search
- <code>hybrid_search</code> Hybrid search
- <code>reranking_enable</code> (bool) Whether to enable reranking, optional, required if the search mode is semantic_search or hybrid_search
- <code>reranking_mode</code> (object) Rerank model configuration, optional, required if reranking is enabled
- <code>reranking_provider_name</code> (string) Rerank model provider
- <code>reranking_model_name</code> (string) Rerank model name
- <code>weights</code> (double) Semantic search weight setting in hybrid search mode
- <code>top_k</code> (integer) Number of results to return, optional
- <code>score_threshold_enabled</code> (bool) Whether to enable score threshold
- <code>score_threshold</code> (double) Score threshold
</Property>
<Property name='external_retrieval_model' type='object' key='external_retrieval_model'>
Unused field
</Property>
</Properties>
</Col>
<Col sticky>
<CodeGroup
title="Request"
tag="POST"
label="/datasets/{dataset_id}/hit_testing"
targetCode={`curl --location --request GET '${props.apiBaseUrl}/datasets/{dataset_id}/hit_testing' \\\n--header 'Authorization: Bearer {api_key}'\\\n--header 'Content-Type: application/json'\\\n--data-raw '{
"query": "test",
"retrieval_model": {
"search_method": "keyword_search",
"reranking_enable": false,
"reranking_mode": null,
"reranking_model": {
"reranking_provider_name": "",
"reranking_model_name": ""
},
"weights": null,
"top_k": 1,
"score_threshold_enabled": false,
"score_threshold": null
}
}'`}
>
```bash {{ title: 'cURL' }}
curl --location --request POST '${props.apiBaseUrl}/datasets/{dataset_id}/hit_testing' \
--header 'Authorization: Bearer {api_key}' \
--header 'Content-Type: application/json' \
--data-raw '{
"query": "test",
"retrieval_model": {
"search_method": "keyword_search",
"reranking_enable": false,
"reranking_mode": null,
"reranking_model": {
"reranking_provider_name": "",
"reranking_model_name": ""
},
"weights": null,
"top_k": 2,
"score_threshold_enabled": false,
"score_threshold": null
}
}'
```
</CodeGroup>
<CodeGroup title="Response">
```json {{ title: 'Response' }}
{
"query": {
"content": "test"
},
"records": [
{
"segment": {
"id": "7fa6f24f-8679-48b3-bc9d-bdf28d73f218",
"position": 1,
"document_id": "a8c6c36f-9f5d-4d7a-8472-f5d7b75d71d2",
"content": "Operation guide",
"answer": null,
"word_count": 847,
"tokens": 280,
"keywords": [
"install",
"java",
"base",
"scripts",
"jdk",
"manual",
"internal",
"opens",
"add",
"vmoptions"
],
"index_node_id": "39dd8443-d960-45a8-bb46-7275ad7fbc8e",
"index_node_hash": "0189157697b3c6a418ccf8264a09699f25858975578f3467c76d6bfc94df1d73",
"hit_count": 0,
"enabled": true,
"disabled_at": null,
"disabled_by": null,
"status": "completed",
"created_by": "dbcb1ab5-90c8-41a7-8b78-73b235eb6f6f",
"created_at": 1728734540,
"indexing_at": 1728734552,
"completed_at": 1728734584,
"error": null,
"stopped_at": null,
"document": {
"id": "a8c6c36f-9f5d-4d7a-8472-f5d7b75d71d2",
"data_source_type": "upload_file",
"name": "readme.txt",
"doc_type": null
}
},
"score": 3.730463140527718e-05,
"tsne_position": null
}
]
}
```
</CodeGroup>
</Col>
</Row>
---
<Row> <Row>
<Col> <Col>
### Error message ### Error message

View File

@ -1049,6 +1049,152 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from
</Col> </Col>
</Row> </Row>
---
<Heading
url='/datasets/{dataset_id}/hit_testing'
method='POST'
title='知识库召回测试'
name='#dataset_hit_testing'
/>
<Row>
<Col>
### Path
<Properties>
<Property name='dataset_id' type='string' key='dataset_id'>
知识库 ID
</Property>
</Properties>
### Request Body
<Properties>
<Property name='query' type='string' key='query'>
召回关键词
</Property>
<Property name='retrieval_model' type='object' key='retrieval_model'>
召回参数(选填,如不填,按照默认方式召回)
- <code>search_method</code> (text) 检索方法:以下三个关键字之一,必填
- <code>keyword_search</code> 关键字检索
- <code>semantic_search</code> 语义检索
- <code>full_text_search</code> 全文检索
- <code>hybrid_search</code> 混合检索
- <code>reranking_enable</code> (bool) 是否启用 Reranking非必填如果检索模式为semantic_search模式或者hybrid_search则传值
- <code>reranking_mode</code> (object) Rerank模型配置非必填如果启用了 reranking 则传值
- <code>reranking_provider_name</code> (string) Rerank 模型提供商
- <code>reranking_model_name</code> (string) Rerank 模型名称
- <code>weights</code> (double) 混合检索模式下语意检索的权重设置
- <code>top_k</code> (integer) 返回结果数量,非必填
- <code>score_threshold_enabled</code> (bool) 是否开启Score阈值
- <code>score_threshold</code> (double) Score阈值
</Property>
<Property name='external_retrieval_model' type='object' key='external_retrieval_model'>
未启用字段
</Property>
</Properties>
</Col>
<Col sticky>
<CodeGroup
title="Request"
tag="POST"
label="/datasets/{dataset_id}/hit_testing"
targetCode={`curl --location --request GET '${props.apiBaseUrl}/datasets/{dataset_id}/hit_testing' \\\n--header 'Authorization: Bearer {api_key}'\\\n--header 'Content-Type: application/json'\\\n--data-raw '{
"query": "test",
"retrieval_model": {
"search_method": "keyword_search",
"reranking_enable": false,
"reranking_mode": null,
"reranking_model": {
"reranking_provider_name": "",
"reranking_model_name": ""
},
"weights": null,
"top_k": 1,
"score_threshold_enabled": false,
"score_threshold": null
}
}'`}
>
```bash {{ title: 'cURL' }}
curl --location --request POST '${props.apiBaseUrl}/datasets/{dataset_id}/hit_testing' \
--header 'Authorization: Bearer {api_key}' \
--header 'Content-Type: application/json' \
--data-raw '{
"query": "test",
"retrieval_model": {
"search_method": "keyword_search",
"reranking_enable": false,
"reranking_mode": null,
"reranking_model": {
"reranking_provider_name": "",
"reranking_model_name": ""
},
"weights": null,
"top_k": 2,
"score_threshold_enabled": false,
"score_threshold": null
}
}'
```
</CodeGroup>
<CodeGroup title="Response">
```json {{ title: 'Response' }}
{
"query": {
"content": "test"
},
"records": [
{
"segment": {
"id": "7fa6f24f-8679-48b3-bc9d-bdf28d73f218",
"position": 1,
"document_id": "a8c6c36f-9f5d-4d7a-8472-f5d7b75d71d2",
"content": "Operation guide",
"answer": null,
"word_count": 847,
"tokens": 280,
"keywords": [
"install",
"java",
"base",
"scripts",
"jdk",
"manual",
"internal",
"opens",
"add",
"vmoptions"
],
"index_node_id": "39dd8443-d960-45a8-bb46-7275ad7fbc8e",
"index_node_hash": "0189157697b3c6a418ccf8264a09699f25858975578f3467c76d6bfc94df1d73",
"hit_count": 0,
"enabled": true,
"disabled_at": null,
"disabled_by": null,
"status": "completed",
"created_by": "dbcb1ab5-90c8-41a7-8b78-73b235eb6f6f",
"created_at": 1728734540,
"indexing_at": 1728734552,
"completed_at": 1728734584,
"error": null,
"stopped_at": null,
"document": {
"id": "a8c6c36f-9f5d-4d7a-8472-f5d7b75d71d2",
"data_source_type": "upload_file",
"name": "readme.txt",
"doc_type": null
}
},
"score": 3.730463140527718e-05,
"tsne_position": null
}
]
}
```
</CodeGroup>
</Col>
</Row>
--- ---
<Row> <Row>

View File

@ -63,7 +63,7 @@ const ConfigContent: FC<Props> = ({
} = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.rerank) } = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.rerank)
const { const {
currentModel, currentModel: currentRerankModel,
} = useCurrentProviderAndModel( } = useCurrentProviderAndModel(
rerankModelList, rerankModelList,
rerankDefaultModel rerankDefaultModel
@ -74,11 +74,6 @@ const ConfigContent: FC<Props> = ({
: undefined, : undefined,
) )
const handleDisabledSwitchClick = useCallback(() => {
if (!currentModel)
Toast.notify({ type: 'error', message: t('workflow.errorMsg.rerankModelRequired') })
}, [currentModel, rerankDefaultModel, t])
const rerankModel = (() => { const rerankModel = (() => {
if (datasetConfigs.reranking_model?.reranking_provider_name) { if (datasetConfigs.reranking_model?.reranking_provider_name) {
return { return {
@ -164,12 +159,33 @@ const ConfigContent: FC<Props> = ({
const showWeightedScorePanel = showWeightedScore && datasetConfigs.reranking_mode === RerankingModeEnum.WeightedScore && datasetConfigs.weights const showWeightedScorePanel = showWeightedScore && datasetConfigs.reranking_mode === RerankingModeEnum.WeightedScore && datasetConfigs.weights
const selectedRerankMode = datasetConfigs.reranking_mode || RerankingModeEnum.RerankingModel const selectedRerankMode = datasetConfigs.reranking_mode || RerankingModeEnum.RerankingModel
const canManuallyToggleRerank = useMemo(() => {
return !(
(selectedDatasetsMode.allInternal && selectedDatasetsMode.allEconomic)
|| selectedDatasetsMode.allExternal
)
}, [selectedDatasetsMode.allEconomic, selectedDatasetsMode.allExternal, selectedDatasetsMode.allInternal])
const showRerankModel = useMemo(() => { const showRerankModel = useMemo(() => {
if (datasetConfigs.reranking_enable === false && selectedDatasetsMode.allEconomic) if (!canManuallyToggleRerank)
return false return false
return true return datasetConfigs.reranking_enable
}, [datasetConfigs.reranking_enable, selectedDatasetsMode.allEconomic]) }, [canManuallyToggleRerank, datasetConfigs.reranking_enable])
const handleDisabledSwitchClick = useCallback(() => {
if (!currentRerankModel && !showRerankModel)
Toast.notify({ type: 'error', message: t('workflow.errorMsg.rerankModelRequired') })
}, [currentRerankModel, showRerankModel, t])
useEffect(() => {
if (!canManuallyToggleRerank && showRerankModel !== datasetConfigs.reranking_enable) {
onChange({
...datasetConfigs,
reranking_enable: showRerankModel,
})
}
}, [canManuallyToggleRerank, showRerankModel, datasetConfigs, onChange])
return ( return (
<div> <div>
@ -256,13 +272,15 @@ const ConfigContent: FC<Props> = ({
> >
<Switch <Switch
size='md' size='md'
defaultValue={currentModel ? showRerankModel : false} defaultValue={showRerankModel}
disabled={!currentModel} disabled={!currentRerankModel || !canManuallyToggleRerank}
onChange={(v) => { onChange={(v) => {
onChange({ if (canManuallyToggleRerank) {
...datasetConfigs, onChange({
reranking_enable: v, ...datasetConfigs,
}) reranking_enable: v,
})
}
}} }}
/> />
</div> </div>

View File

@ -42,6 +42,7 @@ const ParamsConfig = ({
allHighQuality, allHighQuality,
allHighQualityFullTextSearch, allHighQualityFullTextSearch,
allHighQualityVectorSearch, allHighQualityVectorSearch,
allInternal,
allExternal, allExternal,
mixtureHighQualityAndEconomic, mixtureHighQualityAndEconomic,
inconsistentEmbeddingModel, inconsistentEmbeddingModel,
@ -50,7 +51,7 @@ const ParamsConfig = ({
const { datasets, retrieval_model, score_threshold_enabled, ...restConfigs } = datasetConfigs const { datasets, retrieval_model, score_threshold_enabled, ...restConfigs } = datasetConfigs
let rerankEnable = restConfigs.reranking_enable let rerankEnable = restConfigs.reranking_enable
if ((allEconomic && !restConfigs.reranking_model?.reranking_provider_name && rerankEnable === undefined) || allExternal) if (((allInternal && allEconomic) || allExternal) && !restConfigs.reranking_model?.reranking_provider_name && rerankEnable === undefined)
rerankEnable = false rerankEnable = false
if (allEconomic || allHighQuality || allHighQualityFullTextSearch || allHighQualityVectorSearch || (allExternal && selectedDatasets.length === 1)) if (allEconomic || allHighQuality || allHighQualityFullTextSearch || allHighQualityVectorSearch || (allExternal && selectedDatasets.length === 1))

View File

@ -1,25 +1,17 @@
import { useCallback } from 'react' import { useCallback } from 'react'
import { useStoreApi } from 'reactflow' import { useStoreApi } from 'reactflow'
import { useTranslation } from 'react-i18next'
import { useWorkflowStore } from '../store' import { useWorkflowStore } from '../store'
import { import {
BlockEnum, BlockEnum,
WorkflowRunningStatus, WorkflowRunningStatus,
} from '../types' } from '../types'
import type { KnowledgeRetrievalNodeType } from '../nodes/knowledge-retrieval/types'
import type { Node } from '../types'
import { useWorkflow } from './use-workflow'
import { import {
useIsChatMode, useIsChatMode,
useNodesSyncDraft, useNodesSyncDraft,
useWorkflowInteractions, useWorkflowInteractions,
useWorkflowRun, useWorkflowRun,
} from './index' } from './index'
import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations'
import { useCurrentProviderAndModel, useModelListAndDefaultModelAndCurrentProviderAndModel } from '@/app/components/header/account-setting/model-provider-page/hooks'
import { useFeaturesStore } from '@/app/components/base/features/hooks' import { useFeaturesStore } from '@/app/components/base/features/hooks'
import KnowledgeRetrievalDefault from '@/app/components/workflow/nodes/knowledge-retrieval/default'
import Toast from '@/app/components/base/toast'
export const useWorkflowStartRun = () => { export const useWorkflowStartRun = () => {
const store = useStoreApi() const store = useStoreApi()
@ -28,26 +20,7 @@ export const useWorkflowStartRun = () => {
const isChatMode = useIsChatMode() const isChatMode = useIsChatMode()
const { handleCancelDebugAndPreviewPanel } = useWorkflowInteractions() const { handleCancelDebugAndPreviewPanel } = useWorkflowInteractions()
const { handleRun } = useWorkflowRun() const { handleRun } = useWorkflowRun()
const { isFromStartNode } = useWorkflow()
const { doSyncWorkflowDraft } = useNodesSyncDraft() const { doSyncWorkflowDraft } = useNodesSyncDraft()
const { checkValid: checkKnowledgeRetrievalValid } = KnowledgeRetrievalDefault
const { t } = useTranslation()
const {
modelList: rerankModelList,
defaultModel: rerankDefaultModel,
} = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.rerank)
const {
currentModel,
} = useCurrentProviderAndModel(
rerankModelList,
rerankDefaultModel
? {
...rerankDefaultModel,
provider: rerankDefaultModel.provider.provider,
}
: undefined,
)
const handleWorkflowStartRunInWorkflow = useCallback(async () => { const handleWorkflowStartRunInWorkflow = useCallback(async () => {
const { const {
@ -60,9 +33,6 @@ export const useWorkflowStartRun = () => {
const { getNodes } = store.getState() const { getNodes } = store.getState()
const nodes = getNodes() const nodes = getNodes()
const startNode = nodes.find(node => node.data.type === BlockEnum.Start) const startNode = nodes.find(node => node.data.type === BlockEnum.Start)
const knowledgeRetrievalNodes = nodes.filter((node: Node<KnowledgeRetrievalNodeType>) =>
node.data.type === BlockEnum.KnowledgeRetrieval,
)
const startVariables = startNode?.data.variables || [] const startVariables = startNode?.data.variables || []
const fileSettings = featuresStore!.getState().features.file const fileSettings = featuresStore!.getState().features.file
const { const {
@ -72,31 +42,6 @@ export const useWorkflowStartRun = () => {
setShowEnvPanel, setShowEnvPanel,
} = workflowStore.getState() } = workflowStore.getState()
if (knowledgeRetrievalNodes.length > 0) {
for (const node of knowledgeRetrievalNodes) {
if (isFromStartNode(node.id)) {
const res = checkKnowledgeRetrievalValid(node.data, t)
if (!res.isValid || !currentModel || !rerankDefaultModel) {
const errorMessage = res.errorMessage
if (errorMessage) {
Toast.notify({
type: 'error',
message: errorMessage,
})
return false
}
else {
Toast.notify({
type: 'error',
message: t('appDebug.datasetConfig.rerankModelRequired'),
})
return false
}
}
}
}
}
setShowEnvPanel(false) setShowEnvPanel(false)
if (showDebugAndPreviewPanel) { if (showDebugAndPreviewPanel) {

View File

@ -23,7 +23,7 @@ import type { DataSet } from '@/models/datasets'
import { fetchDatasets } from '@/service/datasets' import { fetchDatasets } from '@/service/datasets'
import useNodeCrud from '@/app/components/workflow/nodes/_base/hooks/use-node-crud' import useNodeCrud from '@/app/components/workflow/nodes/_base/hooks/use-node-crud'
import useOneStepRun from '@/app/components/workflow/nodes/_base/hooks/use-one-step-run' import useOneStepRun from '@/app/components/workflow/nodes/_base/hooks/use-one-step-run'
import { useModelListAndDefaultModelAndCurrentProviderAndModel } from '@/app/components/header/account-setting/model-provider-page/hooks' import { useCurrentProviderAndModel, useModelListAndDefaultModelAndCurrentProviderAndModel } from '@/app/components/header/account-setting/model-provider-page/hooks'
import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations' import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations'
const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => { const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => {
@ -34,6 +34,8 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => {
const startNodeId = startNode?.id const startNodeId = startNode?.id
const { inputs, setInputs: doSetInputs } = useNodeCrud<KnowledgeRetrievalNodeType>(id, payload) const { inputs, setInputs: doSetInputs } = useNodeCrud<KnowledgeRetrievalNodeType>(id, payload)
const inputRef = useRef(inputs)
const setInputs = useCallback((s: KnowledgeRetrievalNodeType) => { const setInputs = useCallback((s: KnowledgeRetrievalNodeType) => {
const newInputs = produce(s, (draft) => { const newInputs = produce(s, (draft) => {
if (s.retrieval_mode === RETRIEVE_TYPE.multiWay) if (s.retrieval_mode === RETRIEVE_TYPE.multiWay)
@ -43,13 +45,9 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => {
}) })
// not work in pass to draft... // not work in pass to draft...
doSetInputs(newInputs) doSetInputs(newInputs)
inputRef.current = newInputs
}, [doSetInputs]) }, [doSetInputs])
const inputRef = useRef(inputs)
useEffect(() => {
inputRef.current = inputs
}, [inputs])
const handleQueryVarChange = useCallback((newVar: ValueSelector | string) => { const handleQueryVarChange = useCallback((newVar: ValueSelector | string) => {
const newInputs = produce(inputs, (draft) => { const newInputs = produce(inputs, (draft) => {
draft.query_variable_selector = newVar as ValueSelector draft.query_variable_selector = newVar as ValueSelector
@ -63,9 +61,22 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => {
} = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.textGeneration) } = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.textGeneration)
const { const {
modelList: rerankModelList,
defaultModel: rerankDefaultModel, defaultModel: rerankDefaultModel,
} = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.rerank) } = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.rerank)
const {
currentModel: currentRerankModel,
} = useCurrentProviderAndModel(
rerankModelList,
rerankDefaultModel
? {
...rerankDefaultModel,
provider: rerankDefaultModel.provider.provider,
}
: undefined,
)
const handleModelChanged = useCallback((model: { provider: string; modelId: string; mode?: string }) => { const handleModelChanged = useCallback((model: { provider: string; modelId: string; mode?: string }) => {
const newInputs = produce(inputRef.current, (draft) => { const newInputs = produce(inputRef.current, (draft) => {
if (!draft.single_retrieval_config) { if (!draft.single_retrieval_config) {
@ -110,7 +121,7 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => {
// set defaults models // set defaults models
useEffect(() => { useEffect(() => {
const inputs = inputRef.current const inputs = inputRef.current
if (inputs.retrieval_mode === RETRIEVE_TYPE.multiWay && inputs.multiple_retrieval_config?.reranking_model?.provider) if (inputs.retrieval_mode === RETRIEVE_TYPE.multiWay && inputs.multiple_retrieval_config?.reranking_model?.provider && currentRerankModel && rerankDefaultModel)
return return
if (inputs.retrieval_mode === RETRIEVE_TYPE.oneWay && inputs.single_retrieval_config?.model?.provider) if (inputs.retrieval_mode === RETRIEVE_TYPE.oneWay && inputs.single_retrieval_config?.model?.provider)
@ -130,7 +141,6 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => {
} }
} }
} }
const multipleRetrievalConfig = draft.multiple_retrieval_config const multipleRetrievalConfig = draft.multiple_retrieval_config
draft.multiple_retrieval_config = { draft.multiple_retrieval_config = {
top_k: multipleRetrievalConfig?.top_k || DATASET_DEFAULT.top_k, top_k: multipleRetrievalConfig?.top_k || DATASET_DEFAULT.top_k,
@ -138,6 +148,9 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => {
reranking_model: multipleRetrievalConfig?.reranking_model, reranking_model: multipleRetrievalConfig?.reranking_model,
reranking_mode: multipleRetrievalConfig?.reranking_mode, reranking_mode: multipleRetrievalConfig?.reranking_mode,
weights: multipleRetrievalConfig?.weights, weights: multipleRetrievalConfig?.weights,
reranking_enable: multipleRetrievalConfig?.reranking_enable !== undefined
? multipleRetrievalConfig.reranking_enable
: Boolean(currentRerankModel && rerankDefaultModel),
} }
}) })
setInputs(newInput) setInputs(newInput)
@ -194,14 +207,14 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => {
}, []) }, [])
useEffect(() => { useEffect(() => {
const inputs = inputRef.current
let query_variable_selector: ValueSelector = inputs.query_variable_selector let query_variable_selector: ValueSelector = inputs.query_variable_selector
if (isChatMode && inputs.query_variable_selector.length === 0 && startNodeId) if (isChatMode && inputs.query_variable_selector.length === 0 && startNodeId)
query_variable_selector = [startNodeId, 'sys.query'] query_variable_selector = [startNodeId, 'sys.query']
setInputs({ setInputs(produce(inputs, (draft) => {
...inputs, draft.query_variable_selector = query_variable_selector
query_variable_selector, }))
})
// eslint-disable-next-line react-hooks/exhaustive-deps // eslint-disable-next-line react-hooks/exhaustive-deps
}, []) }, [])

View File

@ -113,7 +113,7 @@ export const getMultipleRetrievalConfig = (multipleRetrievalConfig: MultipleRetr
reranking_mode, reranking_mode,
reranking_model, reranking_model,
weights, weights,
reranking_enable: allEconomic ? reranking_enable : true, reranking_enable: ((allInternal && allEconomic) || allExternal) ? reranking_enable : true,
} }
if (allEconomic || mixtureHighQualityAndEconomic || inconsistentEmbeddingModel || allExternal || mixtureInternalAndExternal) if (allEconomic || mixtureHighQualityAndEconomic || inconsistentEmbeddingModel || allExternal || mixtureInternalAndExternal)