mirror of
https://github.com/langgenius/dify.git
synced 2024-11-16 11:42:29 +08:00
Merge branch 'fix/rerank-validation-in-run' into deploy/dev
This commit is contained in:
commit
03d704ea5a
7
.github/workflows/api-tests.yml
vendored
7
.github/workflows/api-tests.yml
vendored
|
@ -27,18 +27,17 @@ jobs:
|
||||||
- name: Checkout code
|
- name: Checkout code
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
- name: Install Poetry
|
|
||||||
uses: abatilo/actions-poetry@v3
|
|
||||||
|
|
||||||
- name: Set up Python ${{ matrix.python-version }}
|
- name: Set up Python ${{ matrix.python-version }}
|
||||||
uses: actions/setup-python@v5
|
uses: actions/setup-python@v5
|
||||||
with:
|
with:
|
||||||
python-version: ${{ matrix.python-version }}
|
python-version: ${{ matrix.python-version }}
|
||||||
cache: 'poetry'
|
|
||||||
cache-dependency-path: |
|
cache-dependency-path: |
|
||||||
api/pyproject.toml
|
api/pyproject.toml
|
||||||
api/poetry.lock
|
api/poetry.lock
|
||||||
|
|
||||||
|
- name: Install Poetry
|
||||||
|
uses: abatilo/actions-poetry@v3
|
||||||
|
|
||||||
- name: Check Poetry lockfile
|
- name: Check Poetry lockfile
|
||||||
run: |
|
run: |
|
||||||
poetry check -C api --lock
|
poetry check -C api --lock
|
||||||
|
|
6
.github/workflows/style.yml
vendored
6
.github/workflows/style.yml
vendored
|
@ -24,15 +24,15 @@ jobs:
|
||||||
with:
|
with:
|
||||||
files: api/**
|
files: api/**
|
||||||
|
|
||||||
- name: Install Poetry
|
|
||||||
uses: abatilo/actions-poetry@v3
|
|
||||||
|
|
||||||
- name: Set up Python
|
- name: Set up Python
|
||||||
uses: actions/setup-python@v5
|
uses: actions/setup-python@v5
|
||||||
if: steps.changed-files.outputs.any_changed == 'true'
|
if: steps.changed-files.outputs.any_changed == 'true'
|
||||||
with:
|
with:
|
||||||
python-version: '3.10'
|
python-version: '3.10'
|
||||||
|
|
||||||
|
- name: Install Poetry
|
||||||
|
uses: abatilo/actions-poetry@v3
|
||||||
|
|
||||||
- name: Python dependencies
|
- name: Python dependencies
|
||||||
if: steps.changed-files.outputs.any_changed == 'true'
|
if: steps.changed-files.outputs.any_changed == 'true'
|
||||||
run: poetry install -C api --only lint
|
run: poetry install -C api --only lint
|
||||||
|
|
|
@ -85,3 +85,4 @@
|
||||||
cd ../
|
cd ../
|
||||||
poetry run -C api bash dev/pytest/pytest_all_tests.sh
|
poetry run -C api bash dev/pytest/pytest_all_tests.sh
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|
|
@ -14,7 +14,7 @@ class OracleConfig(BaseSettings):
|
||||||
default=None,
|
default=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
ORACLE_PORT: Optional[PositiveInt] = Field(
|
ORACLE_PORT: PositiveInt = Field(
|
||||||
description="Port number on which the Oracle database server is listening (default is 1521)",
|
description="Port number on which the Oracle database server is listening (default is 1521)",
|
||||||
default=1521,
|
default=1521,
|
||||||
)
|
)
|
||||||
|
|
|
@ -14,7 +14,7 @@ class PGVectorConfig(BaseSettings):
|
||||||
default=None,
|
default=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
PGVECTOR_PORT: Optional[PositiveInt] = Field(
|
PGVECTOR_PORT: PositiveInt = Field(
|
||||||
description="Port number on which the PostgreSQL server is listening (default is 5433)",
|
description="Port number on which the PostgreSQL server is listening (default is 5433)",
|
||||||
default=5433,
|
default=5433,
|
||||||
)
|
)
|
||||||
|
|
|
@ -14,7 +14,7 @@ class PGVectoRSConfig(BaseSettings):
|
||||||
default=None,
|
default=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
PGVECTO_RS_PORT: Optional[PositiveInt] = Field(
|
PGVECTO_RS_PORT: PositiveInt = Field(
|
||||||
description="Port number on which the PostgreSQL server with PGVecto.RS is listening (default is 5431)",
|
description="Port number on which the PostgreSQL server with PGVecto.RS is listening (default is 5431)",
|
||||||
default=5431,
|
default=5431,
|
||||||
)
|
)
|
||||||
|
|
|
@ -11,27 +11,39 @@ class VikingDBConfig(BaseModel):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
VIKINGDB_ACCESS_KEY: Optional[str] = Field(
|
VIKINGDB_ACCESS_KEY: Optional[str] = Field(
|
||||||
default=None, description="The Access Key provided by Volcengine VikingDB for API authentication."
|
description="The Access Key provided by Volcengine VikingDB for API authentication."
|
||||||
|
"Refer to the following documentation for details on obtaining credentials:"
|
||||||
|
"https://www.volcengine.com/docs/6291/65568",
|
||||||
|
default=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
VIKINGDB_SECRET_KEY: Optional[str] = Field(
|
VIKINGDB_SECRET_KEY: Optional[str] = Field(
|
||||||
default=None, description="The Secret Key provided by Volcengine VikingDB for API authentication."
|
description="The Secret Key provided by Volcengine VikingDB for API authentication.",
|
||||||
|
default=None,
|
||||||
)
|
)
|
||||||
VIKINGDB_REGION: Optional[str] = Field(
|
|
||||||
default="cn-shanghai",
|
VIKINGDB_REGION: str = Field(
|
||||||
description="The region of the Volcengine VikingDB service.(e.g., 'cn-shanghai', 'cn-beijing').",
|
description="The region of the Volcengine VikingDB service.(e.g., 'cn-shanghai', 'cn-beijing').",
|
||||||
|
default="cn-shanghai",
|
||||||
)
|
)
|
||||||
VIKINGDB_HOST: Optional[str] = Field(
|
|
||||||
default="api-vikingdb.mlp.cn-shanghai.volces.com",
|
VIKINGDB_HOST: str = Field(
|
||||||
description="The host of the Volcengine VikingDB service.(e.g., 'api-vikingdb.volces.com', \
|
description="The host of the Volcengine VikingDB service.(e.g., 'api-vikingdb.volces.com', \
|
||||||
'api-vikingdb.mlp.cn-shanghai.volces.com')",
|
'api-vikingdb.mlp.cn-shanghai.volces.com')",
|
||||||
|
default="api-vikingdb.mlp.cn-shanghai.volces.com",
|
||||||
)
|
)
|
||||||
VIKINGDB_SCHEME: Optional[str] = Field(
|
|
||||||
default="http",
|
VIKINGDB_SCHEME: str = Field(
|
||||||
description="The scheme of the Volcengine VikingDB service.(e.g., 'http', 'https').",
|
description="The scheme of the Volcengine VikingDB service.(e.g., 'http', 'https').",
|
||||||
|
default="http",
|
||||||
)
|
)
|
||||||
VIKINGDB_CONNECTION_TIMEOUT: Optional[int] = Field(
|
|
||||||
default=30, description="The connection timeout of the Volcengine VikingDB service."
|
VIKINGDB_CONNECTION_TIMEOUT: int = Field(
|
||||||
|
description="The connection timeout of the Volcengine VikingDB service.",
|
||||||
|
default=30,
|
||||||
)
|
)
|
||||||
VIKINGDB_SOCKET_TIMEOUT: Optional[int] = Field(
|
|
||||||
default=30, description="The socket timeout of the Volcengine VikingDB service."
|
VIKINGDB_SOCKET_TIMEOUT: int = Field(
|
||||||
|
description="The socket timeout of the Volcengine VikingDB service.",
|
||||||
|
default=30,
|
||||||
)
|
)
|
||||||
|
|
|
@ -1,88 +1,24 @@
|
||||||
import logging
|
from flask_restful import Resource
|
||||||
|
|
||||||
from flask_login import current_user
|
|
||||||
from flask_restful import Resource, marshal, reqparse
|
|
||||||
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
|
|
||||||
|
|
||||||
import services
|
|
||||||
from controllers.console import api
|
from controllers.console import api
|
||||||
from controllers.console.app.error import (
|
from controllers.console.datasets.hit_testing_base import DatasetsHitTestingBase
|
||||||
CompletionRequestError,
|
|
||||||
ProviderModelCurrentlyNotSupportError,
|
|
||||||
ProviderNotInitializeError,
|
|
||||||
ProviderQuotaExceededError,
|
|
||||||
)
|
|
||||||
from controllers.console.datasets.error import DatasetNotInitializedError
|
|
||||||
from controllers.console.setup import setup_required
|
from controllers.console.setup import setup_required
|
||||||
from controllers.console.wraps import account_initialization_required
|
from controllers.console.wraps import account_initialization_required
|
||||||
from core.errors.error import (
|
|
||||||
LLMBadRequestError,
|
|
||||||
ModelCurrentlyNotSupportError,
|
|
||||||
ProviderTokenNotInitError,
|
|
||||||
QuotaExceededError,
|
|
||||||
)
|
|
||||||
from core.model_runtime.errors.invoke import InvokeError
|
|
||||||
from fields.hit_testing_fields import hit_testing_record_fields
|
|
||||||
from libs.login import login_required
|
from libs.login import login_required
|
||||||
from services.dataset_service import DatasetService
|
|
||||||
from services.hit_testing_service import HitTestingService
|
|
||||||
|
|
||||||
|
|
||||||
class HitTestingApi(Resource):
|
class HitTestingApi(Resource, DatasetsHitTestingBase):
|
||||||
@setup_required
|
@setup_required
|
||||||
@login_required
|
@login_required
|
||||||
@account_initialization_required
|
@account_initialization_required
|
||||||
def post(self, dataset_id):
|
def post(self, dataset_id):
|
||||||
dataset_id_str = str(dataset_id)
|
dataset_id_str = str(dataset_id)
|
||||||
|
|
||||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
dataset = self.get_and_validate_dataset(dataset_id_str)
|
||||||
if dataset is None:
|
args = self.parse_args()
|
||||||
raise NotFound("Dataset not found.")
|
self.hit_testing_args_check(args)
|
||||||
|
|
||||||
try:
|
return self.perform_hit_testing(dataset, args)
|
||||||
DatasetService.check_dataset_permission(dataset, current_user)
|
|
||||||
except services.errors.account.NoPermissionError as e:
|
|
||||||
raise Forbidden(str(e))
|
|
||||||
|
|
||||||
parser = reqparse.RequestParser()
|
|
||||||
parser.add_argument("query", type=str, location="json")
|
|
||||||
parser.add_argument("retrieval_model", type=dict, required=False, location="json")
|
|
||||||
parser.add_argument("external_retrieval_model", type=dict, required=False, location="json")
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
HitTestingService.hit_testing_args_check(args)
|
|
||||||
|
|
||||||
try:
|
|
||||||
response = HitTestingService.retrieve(
|
|
||||||
dataset=dataset,
|
|
||||||
query=args["query"],
|
|
||||||
account=current_user,
|
|
||||||
retrieval_model=args["retrieval_model"],
|
|
||||||
external_retrieval_model=args["external_retrieval_model"],
|
|
||||||
limit=10,
|
|
||||||
)
|
|
||||||
|
|
||||||
return {"query": response["query"], "records": marshal(response["records"], hit_testing_record_fields)}
|
|
||||||
except services.errors.index.IndexNotInitializedError:
|
|
||||||
raise DatasetNotInitializedError()
|
|
||||||
except ProviderTokenNotInitError as ex:
|
|
||||||
raise ProviderNotInitializeError(ex.description)
|
|
||||||
except QuotaExceededError:
|
|
||||||
raise ProviderQuotaExceededError()
|
|
||||||
except ModelCurrentlyNotSupportError:
|
|
||||||
raise ProviderModelCurrentlyNotSupportError()
|
|
||||||
except LLMBadRequestError:
|
|
||||||
raise ProviderNotInitializeError(
|
|
||||||
"No Embedding Model or Reranking Model available. Please configure a valid provider "
|
|
||||||
"in the Settings -> Model Provider."
|
|
||||||
)
|
|
||||||
except InvokeError as e:
|
|
||||||
raise CompletionRequestError(e.description)
|
|
||||||
except ValueError as e:
|
|
||||||
raise ValueError(str(e))
|
|
||||||
except Exception as e:
|
|
||||||
logging.exception("Hit testing failed.")
|
|
||||||
raise InternalServerError(str(e))
|
|
||||||
|
|
||||||
|
|
||||||
api.add_resource(HitTestingApi, "/datasets/<uuid:dataset_id>/hit-testing")
|
api.add_resource(HitTestingApi, "/datasets/<uuid:dataset_id>/hit-testing")
|
||||||
|
|
85
api/controllers/console/datasets/hit_testing_base.py
Normal file
85
api/controllers/console/datasets/hit_testing_base.py
Normal file
|
@ -0,0 +1,85 @@
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from flask_login import current_user
|
||||||
|
from flask_restful import marshal, reqparse
|
||||||
|
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
|
||||||
|
|
||||||
|
import services.dataset_service
|
||||||
|
from controllers.console.app.error import (
|
||||||
|
CompletionRequestError,
|
||||||
|
ProviderModelCurrentlyNotSupportError,
|
||||||
|
ProviderNotInitializeError,
|
||||||
|
ProviderQuotaExceededError,
|
||||||
|
)
|
||||||
|
from controllers.console.datasets.error import DatasetNotInitializedError
|
||||||
|
from core.errors.error import (
|
||||||
|
LLMBadRequestError,
|
||||||
|
ModelCurrentlyNotSupportError,
|
||||||
|
ProviderTokenNotInitError,
|
||||||
|
QuotaExceededError,
|
||||||
|
)
|
||||||
|
from core.model_runtime.errors.invoke import InvokeError
|
||||||
|
from fields.hit_testing_fields import hit_testing_record_fields
|
||||||
|
from services.dataset_service import DatasetService
|
||||||
|
from services.hit_testing_service import HitTestingService
|
||||||
|
|
||||||
|
|
||||||
|
class DatasetsHitTestingBase:
|
||||||
|
@staticmethod
|
||||||
|
def get_and_validate_dataset(dataset_id: str):
|
||||||
|
dataset = DatasetService.get_dataset(dataset_id)
|
||||||
|
if dataset is None:
|
||||||
|
raise NotFound("Dataset not found.")
|
||||||
|
|
||||||
|
try:
|
||||||
|
DatasetService.check_dataset_permission(dataset, current_user)
|
||||||
|
except services.errors.account.NoPermissionError as e:
|
||||||
|
raise Forbidden(str(e))
|
||||||
|
|
||||||
|
return dataset
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def hit_testing_args_check(args):
|
||||||
|
HitTestingService.hit_testing_args_check(args)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def parse_args():
|
||||||
|
parser = reqparse.RequestParser()
|
||||||
|
|
||||||
|
parser.add_argument("query", type=str, location="json")
|
||||||
|
parser.add_argument("retrieval_model", type=dict, required=False, location="json")
|
||||||
|
parser.add_argument("external_retrieval_model", type=dict, required=False, location="json")
|
||||||
|
return parser.parse_args()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def perform_hit_testing(dataset, args):
|
||||||
|
try:
|
||||||
|
response = HitTestingService.retrieve(
|
||||||
|
dataset=dataset,
|
||||||
|
query=args["query"],
|
||||||
|
account=current_user,
|
||||||
|
retrieval_model=args["retrieval_model"],
|
||||||
|
external_retrieval_model=args["external_retrieval_model"],
|
||||||
|
limit=10,
|
||||||
|
)
|
||||||
|
return {"query": response["query"], "records": marshal(response["records"], hit_testing_record_fields)}
|
||||||
|
except services.errors.index.IndexNotInitializedError:
|
||||||
|
raise DatasetNotInitializedError()
|
||||||
|
except ProviderTokenNotInitError as ex:
|
||||||
|
raise ProviderNotInitializeError(ex.description)
|
||||||
|
except QuotaExceededError:
|
||||||
|
raise ProviderQuotaExceededError()
|
||||||
|
except ModelCurrentlyNotSupportError:
|
||||||
|
raise ProviderModelCurrentlyNotSupportError()
|
||||||
|
except LLMBadRequestError:
|
||||||
|
raise ProviderNotInitializeError(
|
||||||
|
"No Embedding Model or Reranking Model available. Please configure a valid provider "
|
||||||
|
"in the Settings -> Model Provider."
|
||||||
|
)
|
||||||
|
except InvokeError as e:
|
||||||
|
raise CompletionRequestError(e.description)
|
||||||
|
except ValueError as e:
|
||||||
|
raise ValueError(str(e))
|
||||||
|
except Exception as e:
|
||||||
|
logging.exception("Hit testing failed.")
|
||||||
|
raise InternalServerError(str(e))
|
|
@ -5,7 +5,6 @@ from libs.external_api import ExternalApi
|
||||||
bp = Blueprint("service_api", __name__, url_prefix="/v1")
|
bp = Blueprint("service_api", __name__, url_prefix="/v1")
|
||||||
api = ExternalApi(bp)
|
api = ExternalApi(bp)
|
||||||
|
|
||||||
|
|
||||||
from . import index
|
from . import index
|
||||||
from .app import app, audio, completion, conversation, file, message, workflow
|
from .app import app, audio, completion, conversation, file, message, workflow
|
||||||
from .dataset import dataset, document, segment
|
from .dataset import dataset, document, hit_testing, segment
|
||||||
|
|
|
@ -4,7 +4,6 @@ from flask_restful import Resource, reqparse
|
||||||
from werkzeug.exceptions import InternalServerError, NotFound
|
from werkzeug.exceptions import InternalServerError, NotFound
|
||||||
|
|
||||||
import services
|
import services
|
||||||
from constants import UUID_NIL
|
|
||||||
from controllers.service_api import api
|
from controllers.service_api import api
|
||||||
from controllers.service_api.app.error import (
|
from controllers.service_api.app.error import (
|
||||||
AppUnavailableError,
|
AppUnavailableError,
|
||||||
|
@ -108,7 +107,6 @@ class ChatApi(Resource):
|
||||||
parser.add_argument("conversation_id", type=uuid_value, location="json")
|
parser.add_argument("conversation_id", type=uuid_value, location="json")
|
||||||
parser.add_argument("retriever_from", type=str, required=False, default="dev", location="json")
|
parser.add_argument("retriever_from", type=str, required=False, default="dev", location="json")
|
||||||
parser.add_argument("auto_generate_name", type=bool, required=False, default=True, location="json")
|
parser.add_argument("auto_generate_name", type=bool, required=False, default=True, location="json")
|
||||||
parser.add_argument("parent_message_id", type=uuid_value, required=False, default=UUID_NIL, location="json")
|
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
|
17
api/controllers/service_api/dataset/hit_testing.py
Normal file
17
api/controllers/service_api/dataset/hit_testing.py
Normal file
|
@ -0,0 +1,17 @@
|
||||||
|
from controllers.console.datasets.hit_testing_base import DatasetsHitTestingBase
|
||||||
|
from controllers.service_api import api
|
||||||
|
from controllers.service_api.wraps import DatasetApiResource
|
||||||
|
|
||||||
|
|
||||||
|
class HitTestingApi(DatasetApiResource, DatasetsHitTestingBase):
|
||||||
|
def post(self, tenant_id, dataset_id):
|
||||||
|
dataset_id_str = str(dataset_id)
|
||||||
|
|
||||||
|
dataset = self.get_and_validate_dataset(dataset_id_str)
|
||||||
|
args = self.parse_args()
|
||||||
|
self.hit_testing_args_check(args)
|
||||||
|
|
||||||
|
return self.perform_hit_testing(dataset, args)
|
||||||
|
|
||||||
|
|
||||||
|
api.add_resource(HitTestingApi, "/datasets/<uuid:dataset_id>/hit-testing")
|
|
@ -10,6 +10,7 @@ from flask import Flask, current_app
|
||||||
from pydantic import ValidationError
|
from pydantic import ValidationError
|
||||||
|
|
||||||
import contexts
|
import contexts
|
||||||
|
from constants import UUID_NIL
|
||||||
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
|
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
|
||||||
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager
|
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager
|
||||||
from core.app.apps.advanced_chat.app_runner import AdvancedChatAppRunner
|
from core.app.apps.advanced_chat.app_runner import AdvancedChatAppRunner
|
||||||
|
@ -132,7 +133,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||||
else self._prepare_user_inputs(user_inputs=inputs, app_config=app_config, user_id=user.id, role=role),
|
else self._prepare_user_inputs(user_inputs=inputs, app_config=app_config, user_id=user.id, role=role),
|
||||||
query=query,
|
query=query,
|
||||||
files=file_objs,
|
files=file_objs,
|
||||||
parent_message_id=args.get("parent_message_id"),
|
parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL,
|
||||||
user_id=user.id,
|
user_id=user.id,
|
||||||
stream=stream,
|
stream=stream,
|
||||||
invoke_from=invoke_from,
|
invoke_from=invoke_from,
|
||||||
|
|
|
@ -8,6 +8,7 @@ from typing import Any, Literal, Union, overload
|
||||||
from flask import Flask, current_app
|
from flask import Flask, current_app
|
||||||
from pydantic import ValidationError
|
from pydantic import ValidationError
|
||||||
|
|
||||||
|
from constants import UUID_NIL
|
||||||
from core.app.app_config.easy_ui_based_app.model_config.converter import ModelConfigConverter
|
from core.app.app_config.easy_ui_based_app.model_config.converter import ModelConfigConverter
|
||||||
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
|
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
|
||||||
from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfigManager
|
from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfigManager
|
||||||
|
@ -140,7 +141,7 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
|
||||||
else self._prepare_user_inputs(user_inputs=inputs, app_config=app_config, user_id=user.id, role=role),
|
else self._prepare_user_inputs(user_inputs=inputs, app_config=app_config, user_id=user.id, role=role),
|
||||||
query=query,
|
query=query,
|
||||||
files=file_objs,
|
files=file_objs,
|
||||||
parent_message_id=args.get("parent_message_id"),
|
parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL,
|
||||||
user_id=user.id,
|
user_id=user.id,
|
||||||
stream=stream,
|
stream=stream,
|
||||||
invoke_from=invoke_from,
|
invoke_from=invoke_from,
|
||||||
|
|
|
@ -8,6 +8,7 @@ from typing import Any, Literal, Union, overload
|
||||||
from flask import Flask, current_app
|
from flask import Flask, current_app
|
||||||
from pydantic import ValidationError
|
from pydantic import ValidationError
|
||||||
|
|
||||||
|
from constants import UUID_NIL
|
||||||
from core.app.app_config.easy_ui_based_app.model_config.converter import ModelConfigConverter
|
from core.app.app_config.easy_ui_based_app.model_config.converter import ModelConfigConverter
|
||||||
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
|
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
|
||||||
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom
|
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom
|
||||||
|
@ -138,7 +139,7 @@ class ChatAppGenerator(MessageBasedAppGenerator):
|
||||||
else self._prepare_user_inputs(user_inputs=inputs, app_config=app_config, user_id=user.id, role=role),
|
else self._prepare_user_inputs(user_inputs=inputs, app_config=app_config, user_id=user.id, role=role),
|
||||||
query=query,
|
query=query,
|
||||||
files=file_objs,
|
files=file_objs,
|
||||||
parent_message_id=args.get("parent_message_id"),
|
parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL,
|
||||||
user_id=user.id,
|
user_id=user.id,
|
||||||
invoke_from=invoke_from,
|
invoke_from=invoke_from,
|
||||||
extras=extras,
|
extras=extras,
|
||||||
|
|
|
@ -2,8 +2,9 @@ from collections.abc import Mapping, Sequence
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Any, Optional
|
from typing import Any, Optional
|
||||||
|
|
||||||
from pydantic import BaseModel, ConfigDict
|
from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_validator
|
||||||
|
|
||||||
|
from constants import UUID_NIL
|
||||||
from core.app.app_config.entities import AppConfig, EasyUIBasedAppConfig, WorkflowUIBasedAppConfig
|
from core.app.app_config.entities import AppConfig, EasyUIBasedAppConfig, WorkflowUIBasedAppConfig
|
||||||
from core.entities.provider_configuration import ProviderModelBundle
|
from core.entities.provider_configuration import ProviderModelBundle
|
||||||
from core.file.models import File
|
from core.file.models import File
|
||||||
|
@ -116,13 +117,36 @@ class EasyUIBasedAppGenerateEntity(AppGenerateEntity):
|
||||||
model_config = ConfigDict(protected_namespaces=())
|
model_config = ConfigDict(protected_namespaces=())
|
||||||
|
|
||||||
|
|
||||||
class ChatAppGenerateEntity(EasyUIBasedAppGenerateEntity):
|
class ConversationAppGenerateEntity(AppGenerateEntity):
|
||||||
|
"""
|
||||||
|
Base entity for conversation-based app generation.
|
||||||
|
"""
|
||||||
|
|
||||||
|
conversation_id: Optional[str] = None
|
||||||
|
parent_message_id: Optional[str] = Field(
|
||||||
|
default=None,
|
||||||
|
description=(
|
||||||
|
"Starting from v0.9.0, parent_message_id is used to support message regeneration for internal chat API."
|
||||||
|
"For service API, we need to ensure its forward compatibility, "
|
||||||
|
"so passing in the parent_message_id as request arg is not supported for now. "
|
||||||
|
"It needs to be set to UUID_NIL so that the subsequent processing will treat it as legacy messages."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
@field_validator("parent_message_id")
|
||||||
|
@classmethod
|
||||||
|
def validate_parent_message_id(cls, v, info: ValidationInfo):
|
||||||
|
if info.data.get("invoke_from") == InvokeFrom.SERVICE_API and v != UUID_NIL:
|
||||||
|
raise ValueError("parent_message_id should be UUID_NIL for service API")
|
||||||
|
return v
|
||||||
|
|
||||||
|
|
||||||
|
class ChatAppGenerateEntity(ConversationAppGenerateEntity, EasyUIBasedAppGenerateEntity):
|
||||||
"""
|
"""
|
||||||
Chat Application Generate Entity.
|
Chat Application Generate Entity.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
conversation_id: Optional[str] = None
|
pass
|
||||||
parent_message_id: Optional[str] = None
|
|
||||||
|
|
||||||
|
|
||||||
class CompletionAppGenerateEntity(EasyUIBasedAppGenerateEntity):
|
class CompletionAppGenerateEntity(EasyUIBasedAppGenerateEntity):
|
||||||
|
@ -133,16 +157,15 @@ class CompletionAppGenerateEntity(EasyUIBasedAppGenerateEntity):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class AgentChatAppGenerateEntity(EasyUIBasedAppGenerateEntity):
|
class AgentChatAppGenerateEntity(ConversationAppGenerateEntity, EasyUIBasedAppGenerateEntity):
|
||||||
"""
|
"""
|
||||||
Agent Chat Application Generate Entity.
|
Agent Chat Application Generate Entity.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
conversation_id: Optional[str] = None
|
pass
|
||||||
parent_message_id: Optional[str] = None
|
|
||||||
|
|
||||||
|
|
||||||
class AdvancedChatAppGenerateEntity(AppGenerateEntity):
|
class AdvancedChatAppGenerateEntity(ConversationAppGenerateEntity):
|
||||||
"""
|
"""
|
||||||
Advanced Chat Application Generate Entity.
|
Advanced Chat Application Generate Entity.
|
||||||
"""
|
"""
|
||||||
|
@ -150,8 +173,6 @@ class AdvancedChatAppGenerateEntity(AppGenerateEntity):
|
||||||
# app config
|
# app config
|
||||||
app_config: WorkflowUIBasedAppConfig
|
app_config: WorkflowUIBasedAppConfig
|
||||||
|
|
||||||
conversation_id: Optional[str] = None
|
|
||||||
parent_message_id: Optional[str] = None
|
|
||||||
workflow_run_id: Optional[str] = None
|
workflow_run_id: Optional[str] = None
|
||||||
query: str
|
query: str
|
||||||
|
|
||||||
|
|
|
@ -18,6 +18,7 @@ supported_model_types:
|
||||||
- text-embedding
|
- text-embedding
|
||||||
configurate_methods:
|
configurate_methods:
|
||||||
- predefined-model
|
- predefined-model
|
||||||
|
- customizable-model
|
||||||
provider_credential_schema:
|
provider_credential_schema:
|
||||||
credential_form_schemas:
|
credential_form_schemas:
|
||||||
- variable: fireworks_api_key
|
- variable: fireworks_api_key
|
||||||
|
@ -28,3 +29,75 @@ provider_credential_schema:
|
||||||
placeholder:
|
placeholder:
|
||||||
zh_Hans: 在此输入您的 API Key
|
zh_Hans: 在此输入您的 API Key
|
||||||
en_US: Enter your API Key
|
en_US: Enter your API Key
|
||||||
|
model_credential_schema:
|
||||||
|
model:
|
||||||
|
label:
|
||||||
|
en_US: Model URL
|
||||||
|
zh_Hans: 模型URL
|
||||||
|
placeholder:
|
||||||
|
en_US: Enter your Model URL
|
||||||
|
zh_Hans: 输入模型URL
|
||||||
|
credential_form_schemas:
|
||||||
|
- variable: model_label_zh_Hanns
|
||||||
|
label:
|
||||||
|
zh_Hans: 模型中文名称
|
||||||
|
en_US: The zh_Hans of Model
|
||||||
|
required: true
|
||||||
|
type: text-input
|
||||||
|
placeholder:
|
||||||
|
zh_Hans: 在此输入您的模型中文名称
|
||||||
|
en_US: Enter your zh_Hans of Model
|
||||||
|
- variable: model_label_en_US
|
||||||
|
label:
|
||||||
|
zh_Hans: 模型英文名称
|
||||||
|
en_US: The en_US of Model
|
||||||
|
required: true
|
||||||
|
type: text-input
|
||||||
|
placeholder:
|
||||||
|
zh_Hans: 在此输入您的模型英文名称
|
||||||
|
en_US: Enter your en_US of Model
|
||||||
|
- variable: fireworks_api_key
|
||||||
|
label:
|
||||||
|
en_US: API Key
|
||||||
|
type: secret-input
|
||||||
|
required: true
|
||||||
|
placeholder:
|
||||||
|
zh_Hans: 在此输入您的 API Key
|
||||||
|
en_US: Enter your API Key
|
||||||
|
- variable: context_size
|
||||||
|
label:
|
||||||
|
zh_Hans: 模型上下文长度
|
||||||
|
en_US: Model context size
|
||||||
|
required: true
|
||||||
|
type: text-input
|
||||||
|
default: '4096'
|
||||||
|
placeholder:
|
||||||
|
zh_Hans: 在此输入您的模型上下文长度
|
||||||
|
en_US: Enter your Model context size
|
||||||
|
- variable: max_tokens
|
||||||
|
label:
|
||||||
|
zh_Hans: 最大 token 上限
|
||||||
|
en_US: Upper bound for max tokens
|
||||||
|
default: '4096'
|
||||||
|
type: text-input
|
||||||
|
show_on:
|
||||||
|
- variable: __model_type
|
||||||
|
value: llm
|
||||||
|
- variable: function_calling_type
|
||||||
|
label:
|
||||||
|
en_US: Function calling
|
||||||
|
type: select
|
||||||
|
required: false
|
||||||
|
default: no_call
|
||||||
|
options:
|
||||||
|
- value: no_call
|
||||||
|
label:
|
||||||
|
en_US: Not Support
|
||||||
|
zh_Hans: 不支持
|
||||||
|
- value: function_call
|
||||||
|
label:
|
||||||
|
en_US: Support
|
||||||
|
zh_Hans: 支持
|
||||||
|
show_on:
|
||||||
|
- variable: __model_type
|
||||||
|
value: llm
|
||||||
|
|
|
@ -8,7 +8,8 @@ from openai.types.chat.chat_completion_chunk import ChoiceDeltaFunctionCall, Cho
|
||||||
from openai.types.chat.chat_completion_message import FunctionCall
|
from openai.types.chat.chat_completion_message import FunctionCall
|
||||||
|
|
||||||
from core.model_runtime.callbacks.base_callback import Callback
|
from core.model_runtime.callbacks.base_callback import Callback
|
||||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
|
from core.model_runtime.entities.common_entities import I18nObject
|
||||||
|
from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta
|
||||||
from core.model_runtime.entities.message_entities import (
|
from core.model_runtime.entities.message_entities import (
|
||||||
AssistantPromptMessage,
|
AssistantPromptMessage,
|
||||||
ImagePromptMessageContent,
|
ImagePromptMessageContent,
|
||||||
|
@ -20,6 +21,15 @@ from core.model_runtime.entities.message_entities import (
|
||||||
ToolPromptMessage,
|
ToolPromptMessage,
|
||||||
UserPromptMessage,
|
UserPromptMessage,
|
||||||
)
|
)
|
||||||
|
from core.model_runtime.entities.model_entities import (
|
||||||
|
AIModelEntity,
|
||||||
|
FetchFrom,
|
||||||
|
ModelFeature,
|
||||||
|
ModelPropertyKey,
|
||||||
|
ModelType,
|
||||||
|
ParameterRule,
|
||||||
|
ParameterType,
|
||||||
|
)
|
||||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||||
from core.model_runtime.model_providers.fireworks._common import _CommonFireworks
|
from core.model_runtime.model_providers.fireworks._common import _CommonFireworks
|
||||||
|
@ -608,3 +618,50 @@ class FireworksLargeLanguageModel(_CommonFireworks, LargeLanguageModel):
|
||||||
num_tokens += self._get_num_tokens_by_gpt2(required_field)
|
num_tokens += self._get_num_tokens_by_gpt2(required_field)
|
||||||
|
|
||||||
return num_tokens
|
return num_tokens
|
||||||
|
|
||||||
|
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity:
|
||||||
|
return AIModelEntity(
|
||||||
|
model=model,
|
||||||
|
label=I18nObject(
|
||||||
|
en_US=credentials.get("model_label_en_US", model),
|
||||||
|
zh_Hans=credentials.get("model_label_zh_Hanns", model),
|
||||||
|
),
|
||||||
|
model_type=ModelType.LLM,
|
||||||
|
features=[ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL, ModelFeature.STREAM_TOOL_CALL]
|
||||||
|
if credentials.get("function_calling_type") == "function_call"
|
||||||
|
else [],
|
||||||
|
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
|
||||||
|
model_properties={
|
||||||
|
ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size", 4096)),
|
||||||
|
ModelPropertyKey.MODE: LLMMode.CHAT.value,
|
||||||
|
},
|
||||||
|
parameter_rules=[
|
||||||
|
ParameterRule(
|
||||||
|
name="temperature",
|
||||||
|
use_template="temperature",
|
||||||
|
label=I18nObject(en_US="Temperature", zh_Hans="温度"),
|
||||||
|
type=ParameterType.FLOAT,
|
||||||
|
),
|
||||||
|
ParameterRule(
|
||||||
|
name="max_tokens",
|
||||||
|
use_template="max_tokens",
|
||||||
|
default=512,
|
||||||
|
min=1,
|
||||||
|
max=int(credentials.get("max_tokens", 4096)),
|
||||||
|
label=I18nObject(en_US="Max Tokens", zh_Hans="最大标记"),
|
||||||
|
type=ParameterType.INT,
|
||||||
|
),
|
||||||
|
ParameterRule(
|
||||||
|
name="top_p",
|
||||||
|
use_template="top_p",
|
||||||
|
label=I18nObject(en_US="Top P", zh_Hans="Top P"),
|
||||||
|
type=ParameterType.FLOAT,
|
||||||
|
),
|
||||||
|
ParameterRule(
|
||||||
|
name="top_k",
|
||||||
|
use_template="top_k",
|
||||||
|
label=I18nObject(en_US="Top K", zh_Hans="Top K"),
|
||||||
|
type=ParameterType.FLOAT,
|
||||||
|
),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
|
@ -0,0 +1,46 @@
|
||||||
|
model: accounts/fireworks/models/qwen2p5-72b-instruct
|
||||||
|
label:
|
||||||
|
zh_Hans: Qwen2.5 72B Instruct
|
||||||
|
en_US: Qwen2.5 72B Instruct
|
||||||
|
model_type: llm
|
||||||
|
features:
|
||||||
|
- agent-thought
|
||||||
|
- tool-call
|
||||||
|
model_properties:
|
||||||
|
mode: chat
|
||||||
|
context_size: 32768
|
||||||
|
parameter_rules:
|
||||||
|
- name: temperature
|
||||||
|
use_template: temperature
|
||||||
|
- name: top_p
|
||||||
|
use_template: top_p
|
||||||
|
- name: top_k
|
||||||
|
label:
|
||||||
|
zh_Hans: 取样数量
|
||||||
|
en_US: Top k
|
||||||
|
type: int
|
||||||
|
help:
|
||||||
|
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
|
||||||
|
en_US: Only sample from the top K options for each subsequent token.
|
||||||
|
- name: max_tokens
|
||||||
|
use_template: max_tokens
|
||||||
|
- name: context_length_exceeded_behavior
|
||||||
|
default: None
|
||||||
|
label:
|
||||||
|
zh_Hans: 上下文长度超出行为
|
||||||
|
en_US: Context Length Exceeded Behavior
|
||||||
|
help:
|
||||||
|
zh_Hans: 上下文长度超出行为
|
||||||
|
en_US: Context Length Exceeded Behavior
|
||||||
|
type: string
|
||||||
|
options:
|
||||||
|
- None
|
||||||
|
- truncate
|
||||||
|
- error
|
||||||
|
- name: response_format
|
||||||
|
use_template: response_format
|
||||||
|
pricing:
|
||||||
|
input: '0.9'
|
||||||
|
output: '0.9'
|
||||||
|
unit: '0.000001'
|
||||||
|
currency: USD
|
|
@ -61,7 +61,8 @@ class MinimaxTextEmbeddingModel(TextEmbeddingModel):
|
||||||
url = f"{self.api_base}?GroupId={group_id}"
|
url = f"{self.api_base}?GroupId={group_id}"
|
||||||
headers = {"Authorization": "Bearer " + api_key, "Content-Type": "application/json"}
|
headers = {"Authorization": "Bearer " + api_key, "Content-Type": "application/json"}
|
||||||
|
|
||||||
data = {"model": "embo-01", "texts": texts, "type": "db"}
|
embedding_type = "db" if input_type == EmbeddingInputType.DOCUMENT else "query"
|
||||||
|
data = {"model": "embo-01", "texts": texts, "type": embedding_type}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response = post(url, headers=headers, data=dumps(data))
|
response = post(url, headers=headers, data=dumps(data))
|
||||||
|
|
|
@ -1,8 +1,8 @@
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
from core.rag.datasource.keyword.jieba.jieba import Jieba
|
|
||||||
from core.rag.datasource.keyword.keyword_base import BaseKeyword
|
from core.rag.datasource.keyword.keyword_base import BaseKeyword
|
||||||
|
from core.rag.datasource.keyword.keyword_type import KeyWordType
|
||||||
from core.rag.models.document import Document
|
from core.rag.models.document import Document
|
||||||
from models.dataset import Dataset
|
from models.dataset import Dataset
|
||||||
|
|
||||||
|
@ -13,16 +13,19 @@ class Keyword:
|
||||||
self._keyword_processor = self._init_keyword()
|
self._keyword_processor = self._init_keyword()
|
||||||
|
|
||||||
def _init_keyword(self) -> BaseKeyword:
|
def _init_keyword(self) -> BaseKeyword:
|
||||||
config = dify_config
|
keyword_type = dify_config.KEYWORD_STORE
|
||||||
keyword_type = config.KEYWORD_STORE
|
keyword_factory = self.get_keyword_factory(keyword_type)
|
||||||
|
return keyword_factory(self._dataset)
|
||||||
|
|
||||||
if not keyword_type:
|
@staticmethod
|
||||||
raise ValueError("Keyword store must be specified.")
|
def get_keyword_factory(keyword_type: str) -> type[BaseKeyword]:
|
||||||
|
match keyword_type:
|
||||||
|
case KeyWordType.JIEBA:
|
||||||
|
from core.rag.datasource.keyword.jieba.jieba import Jieba
|
||||||
|
|
||||||
if keyword_type == "jieba":
|
return Jieba
|
||||||
return Jieba(dataset=self._dataset)
|
case _:
|
||||||
else:
|
raise ValueError(f"Keyword store {keyword_type} is not supported.")
|
||||||
raise ValueError(f"Keyword store {keyword_type} is not supported.")
|
|
||||||
|
|
||||||
def create(self, texts: list[Document], **kwargs):
|
def create(self, texts: list[Document], **kwargs):
|
||||||
self._keyword_processor.create(texts, **kwargs)
|
self._keyword_processor.create(texts, **kwargs)
|
||||||
|
|
5
api/core/rag/datasource/keyword/keyword_type.py
Normal file
5
api/core/rag/datasource/keyword/keyword_type.py
Normal file
|
@ -0,0 +1,5 @@
|
||||||
|
from enum import Enum
|
||||||
|
|
||||||
|
|
||||||
|
class KeyWordType(str, Enum):
|
||||||
|
JIEBA = "jieba"
|
|
@ -1,5 +1,4 @@
|
||||||
from collections.abc import Generator
|
from collections.abc import Generator
|
||||||
from contextlib import closing
|
|
||||||
|
|
||||||
import oss2 as aliyun_s3
|
import oss2 as aliyun_s3
|
||||||
from flask import Flask
|
from flask import Flask
|
||||||
|
@ -34,15 +33,15 @@ class AliyunOssStorage(BaseStorage):
|
||||||
self.client.put_object(self.__wrapper_folder_filename(filename), data)
|
self.client.put_object(self.__wrapper_folder_filename(filename), data)
|
||||||
|
|
||||||
def load_once(self, filename: str) -> bytes:
|
def load_once(self, filename: str) -> bytes:
|
||||||
with closing(self.client.get_object(self.__wrapper_folder_filename(filename))) as obj:
|
obj = self.client.get_object(self.__wrapper_folder_filename(filename))
|
||||||
data = obj.read()
|
data = obj.read()
|
||||||
return data
|
return data
|
||||||
|
|
||||||
def load_stream(self, filename: str) -> Generator:
|
def load_stream(self, filename: str) -> Generator:
|
||||||
def generate(filename: str = filename) -> Generator:
|
def generate(filename: str = filename) -> Generator:
|
||||||
with closing(self.client.get_object(self.__wrapper_folder_filename(filename))) as obj:
|
obj = self.client.get_object(self.__wrapper_folder_filename(filename))
|
||||||
while chunk := obj.read(4096):
|
while chunk := obj.read(4096):
|
||||||
yield chunk
|
yield chunk
|
||||||
|
|
||||||
return generate()
|
return generate()
|
||||||
|
|
||||||
|
|
|
@ -1,6 +1,5 @@
|
||||||
import logging
|
import logging
|
||||||
from collections.abc import Generator
|
from collections.abc import Generator
|
||||||
from contextlib import closing
|
|
||||||
|
|
||||||
import boto3
|
import boto3
|
||||||
from botocore.client import Config
|
from botocore.client import Config
|
||||||
|
@ -55,8 +54,7 @@ class AwsS3Storage(BaseStorage):
|
||||||
|
|
||||||
def load_once(self, filename: str) -> bytes:
|
def load_once(self, filename: str) -> bytes:
|
||||||
try:
|
try:
|
||||||
with closing(self.client) as client:
|
data = self.client.get_object(Bucket=self.bucket_name, Key=filename)["Body"].read()
|
||||||
data = client.get_object(Bucket=self.bucket_name, Key=filename)["Body"].read()
|
|
||||||
except ClientError as ex:
|
except ClientError as ex:
|
||||||
if ex.response["Error"]["Code"] == "NoSuchKey":
|
if ex.response["Error"]["Code"] == "NoSuchKey":
|
||||||
raise FileNotFoundError("File not found")
|
raise FileNotFoundError("File not found")
|
||||||
|
@ -67,9 +65,8 @@ class AwsS3Storage(BaseStorage):
|
||||||
def load_stream(self, filename: str) -> Generator:
|
def load_stream(self, filename: str) -> Generator:
|
||||||
def generate(filename: str = filename) -> Generator:
|
def generate(filename: str = filename) -> Generator:
|
||||||
try:
|
try:
|
||||||
with closing(self.client) as client:
|
response = self.client.get_object(Bucket=self.bucket_name, Key=filename)
|
||||||
response = client.get_object(Bucket=self.bucket_name, Key=filename)
|
yield from response["Body"].iter_chunks()
|
||||||
yield from response["Body"].iter_chunks()
|
|
||||||
except ClientError as ex:
|
except ClientError as ex:
|
||||||
if ex.response["Error"]["Code"] == "NoSuchKey":
|
if ex.response["Error"]["Code"] == "NoSuchKey":
|
||||||
raise FileNotFoundError("File not found")
|
raise FileNotFoundError("File not found")
|
||||||
|
@ -79,16 +76,14 @@ class AwsS3Storage(BaseStorage):
|
||||||
return generate()
|
return generate()
|
||||||
|
|
||||||
def download(self, filename, target_filepath):
|
def download(self, filename, target_filepath):
|
||||||
with closing(self.client) as client:
|
self.client.download_file(self.bucket_name, filename, target_filepath)
|
||||||
client.download_file(self.bucket_name, filename, target_filepath)
|
|
||||||
|
|
||||||
def exists(self, filename):
|
def exists(self, filename):
|
||||||
with closing(self.client) as client:
|
try:
|
||||||
try:
|
self.client.head_object(Bucket=self.bucket_name, Key=filename)
|
||||||
client.head_object(Bucket=self.bucket_name, Key=filename)
|
return True
|
||||||
return True
|
except:
|
||||||
except:
|
return False
|
||||||
return False
|
|
||||||
|
|
||||||
def delete(self, filename):
|
def delete(self, filename):
|
||||||
self.client.delete_object(Bucket=self.bucket_name, Key=filename)
|
self.client.delete_object(Bucket=self.bucket_name, Key=filename)
|
||||||
|
|
|
@ -2,7 +2,6 @@ import base64
|
||||||
import io
|
import io
|
||||||
import json
|
import json
|
||||||
from collections.abc import Generator
|
from collections.abc import Generator
|
||||||
from contextlib import closing
|
|
||||||
|
|
||||||
from flask import Flask
|
from flask import Flask
|
||||||
from google.cloud import storage as google_cloud_storage
|
from google.cloud import storage as google_cloud_storage
|
||||||
|
@ -43,7 +42,7 @@ class GoogleCloudStorage(BaseStorage):
|
||||||
def generate(filename: str = filename) -> Generator:
|
def generate(filename: str = filename) -> Generator:
|
||||||
bucket = self.client.get_bucket(self.bucket_name)
|
bucket = self.client.get_bucket(self.bucket_name)
|
||||||
blob = bucket.get_blob(filename)
|
blob = bucket.get_blob(filename)
|
||||||
with closing(blob.open(mode="rb")) as blob_stream:
|
with blob.open(mode="rb") as blob_stream:
|
||||||
while chunk := blob_stream.read(4096):
|
while chunk := blob_stream.read(4096):
|
||||||
yield chunk
|
yield chunk
|
||||||
|
|
||||||
|
|
|
@ -1,5 +1,4 @@
|
||||||
from collections.abc import Generator
|
from collections.abc import Generator
|
||||||
from contextlib import closing
|
|
||||||
|
|
||||||
import boto3
|
import boto3
|
||||||
from botocore.exceptions import ClientError
|
from botocore.exceptions import ClientError
|
||||||
|
@ -28,8 +27,7 @@ class OracleOCIStorage(BaseStorage):
|
||||||
|
|
||||||
def load_once(self, filename: str) -> bytes:
|
def load_once(self, filename: str) -> bytes:
|
||||||
try:
|
try:
|
||||||
with closing(self.client) as client:
|
data = self.client.get_object(Bucket=self.bucket_name, Key=filename)["Body"].read()
|
||||||
data = client.get_object(Bucket=self.bucket_name, Key=filename)["Body"].read()
|
|
||||||
except ClientError as ex:
|
except ClientError as ex:
|
||||||
if ex.response["Error"]["Code"] == "NoSuchKey":
|
if ex.response["Error"]["Code"] == "NoSuchKey":
|
||||||
raise FileNotFoundError("File not found")
|
raise FileNotFoundError("File not found")
|
||||||
|
@ -40,9 +38,8 @@ class OracleOCIStorage(BaseStorage):
|
||||||
def load_stream(self, filename: str) -> Generator:
|
def load_stream(self, filename: str) -> Generator:
|
||||||
def generate(filename: str = filename) -> Generator:
|
def generate(filename: str = filename) -> Generator:
|
||||||
try:
|
try:
|
||||||
with closing(self.client) as client:
|
response = self.client.get_object(Bucket=self.bucket_name, Key=filename)
|
||||||
response = client.get_object(Bucket=self.bucket_name, Key=filename)
|
yield from response["Body"].iter_chunks()
|
||||||
yield from response["Body"].iter_chunks()
|
|
||||||
except ClientError as ex:
|
except ClientError as ex:
|
||||||
if ex.response["Error"]["Code"] == "NoSuchKey":
|
if ex.response["Error"]["Code"] == "NoSuchKey":
|
||||||
raise FileNotFoundError("File not found")
|
raise FileNotFoundError("File not found")
|
||||||
|
@ -52,16 +49,14 @@ class OracleOCIStorage(BaseStorage):
|
||||||
return generate()
|
return generate()
|
||||||
|
|
||||||
def download(self, filename, target_filepath):
|
def download(self, filename, target_filepath):
|
||||||
with closing(self.client) as client:
|
self.client.download_file(self.bucket_name, filename, target_filepath)
|
||||||
client.download_file(self.bucket_name, filename, target_filepath)
|
|
||||||
|
|
||||||
def exists(self, filename):
|
def exists(self, filename):
|
||||||
with closing(self.client) as client:
|
try:
|
||||||
try:
|
self.client.head_object(Bucket=self.bucket_name, Key=filename)
|
||||||
client.head_object(Bucket=self.bucket_name, Key=filename)
|
return True
|
||||||
return True
|
except:
|
||||||
except:
|
return False
|
||||||
return False
|
|
||||||
|
|
||||||
def delete(self, filename):
|
def delete(self, filename):
|
||||||
self.client.delete_object(Bucket=self.bucket_name, Key=filename)
|
self.client.delete_object(Bucket=self.bucket_name, Key=filename)
|
||||||
|
|
|
@ -1,15 +1,25 @@
|
||||||
from services.auth.firecrawl import FirecrawlAuth
|
from services.auth.api_key_auth_base import ApiKeyAuthBase
|
||||||
from services.auth.jina import JinaAuth
|
from services.auth.auth_type import AuthType
|
||||||
|
|
||||||
|
|
||||||
class ApiKeyAuthFactory:
|
class ApiKeyAuthFactory:
|
||||||
def __init__(self, provider: str, credentials: dict):
|
def __init__(self, provider: str, credentials: dict):
|
||||||
if provider == "firecrawl":
|
auth_factory = self.get_apikey_auth_factory(provider)
|
||||||
self.auth = FirecrawlAuth(credentials)
|
self.auth = auth_factory(credentials)
|
||||||
elif provider == "jinareader":
|
|
||||||
self.auth = JinaAuth(credentials)
|
|
||||||
else:
|
|
||||||
raise ValueError("Invalid provider")
|
|
||||||
|
|
||||||
def validate_credentials(self):
|
def validate_credentials(self):
|
||||||
return self.auth.validate_credentials()
|
return self.auth.validate_credentials()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_apikey_auth_factory(provider: str) -> type[ApiKeyAuthBase]:
|
||||||
|
match provider:
|
||||||
|
case AuthType.FIRECRAWL:
|
||||||
|
from services.auth.firecrawl.firecrawl import FirecrawlAuth
|
||||||
|
|
||||||
|
return FirecrawlAuth
|
||||||
|
case AuthType.JINA:
|
||||||
|
from services.auth.jina.jina import JinaAuth
|
||||||
|
|
||||||
|
return JinaAuth
|
||||||
|
case _:
|
||||||
|
raise ValueError("Invalid provider")
|
||||||
|
|
6
api/services/auth/auth_type.py
Normal file
6
api/services/auth/auth_type.py
Normal file
|
@ -0,0 +1,6 @@
|
||||||
|
from enum import Enum
|
||||||
|
|
||||||
|
|
||||||
|
class AuthType(str, Enum):
|
||||||
|
FIRECRAWL = "firecrawl"
|
||||||
|
JINA = "jinareader"
|
0
api/services/auth/firecrawl/__init__.py
Normal file
0
api/services/auth/firecrawl/__init__.py
Normal file
0
api/services/auth/jina/__init__.py
Normal file
0
api/services/auth/jina/__init__.py
Normal file
|
@ -1050,6 +1050,151 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
|
<Heading
|
||||||
|
url='/datasets/{dataset_id}/hit_testing'
|
||||||
|
method='POST'
|
||||||
|
title='Dataset hit testing'
|
||||||
|
name='#dataset_hit_testing'
|
||||||
|
/>
|
||||||
|
<Row>
|
||||||
|
<Col>
|
||||||
|
### Path
|
||||||
|
<Properties>
|
||||||
|
<Property name='dataset_id' type='string' key='dataset_id'>
|
||||||
|
Dataset ID
|
||||||
|
</Property>
|
||||||
|
</Properties>
|
||||||
|
|
||||||
|
### Request Body
|
||||||
|
<Properties>
|
||||||
|
<Property name='query' type='string' key='query'>
|
||||||
|
Retrieval keyword
|
||||||
|
</Property>
|
||||||
|
<Property name='retrieval_model' type='object' key='retrieval_model'>
|
||||||
|
Retrieval parameters (optional; if omitted, the default retrieval method is used)
|
||||||
|
- <code>search_method</code> (text) Search method: One of the following four keywords is required
|
||||||
|
- <code>keyword_search</code> Keyword search
|
||||||
|
- <code>semantic_search</code> Semantic search
|
||||||
|
- <code>full_text_search</code> Full-text search
|
||||||
|
- <code>hybrid_search</code> Hybrid search
|
||||||
|
- <code>reranking_enable</code> (bool) Whether to enable reranking, optional, required if the search mode is semantic_search or hybrid_search
|
||||||
|
- <code>reranking_model</code> (object) Rerank model configuration, optional, required if reranking is enabled
|
||||||
|
- <code>reranking_provider_name</code> (string) Rerank model provider
|
||||||
|
- <code>reranking_model_name</code> (string) Rerank model name
|
||||||
|
- <code>weights</code> (double) Semantic search weight setting in hybrid search mode
|
||||||
|
- <code>top_k</code> (integer) Number of results to return, optional
|
||||||
|
- <code>score_threshold_enabled</code> (bool) Whether to enable score threshold
|
||||||
|
- <code>score_threshold</code> (double) Score threshold
|
||||||
|
</Property>
|
||||||
|
<Property name='external_retrieval_model' type='object' key='external_retrieval_model'>
|
||||||
|
Unused field
|
||||||
|
</Property>
|
||||||
|
</Properties>
|
||||||
|
</Col>
|
||||||
|
<Col sticky>
|
||||||
|
<CodeGroup
|
||||||
|
title="Request"
|
||||||
|
tag="POST"
|
||||||
|
label="/datasets/{dataset_id}/hit_testing"
|
||||||
|
targetCode={`curl --location --request POST '${props.apiBaseUrl}/datasets/{dataset_id}/hit_testing' \\\n--header 'Authorization: Bearer {api_key}'\\\n--header 'Content-Type: application/json'\\\n--data-raw '{
|
||||||
|
"query": "test",
|
||||||
|
"retrieval_model": {
|
||||||
|
"search_method": "keyword_search",
|
||||||
|
"reranking_enable": false,
|
||||||
|
"reranking_mode": null,
|
||||||
|
"reranking_model": {
|
||||||
|
"reranking_provider_name": "",
|
||||||
|
"reranking_model_name": ""
|
||||||
|
},
|
||||||
|
"weights": null,
|
||||||
|
"top_k": 1,
|
||||||
|
"score_threshold_enabled": false,
|
||||||
|
"score_threshold": null
|
||||||
|
}
|
||||||
|
}'`}
|
||||||
|
>
|
||||||
|
```bash {{ title: 'cURL' }}
|
||||||
|
curl --location --request POST '${props.apiBaseUrl}/datasets/{dataset_id}/hit_testing' \
|
||||||
|
--header 'Authorization: Bearer {api_key}' \
|
||||||
|
--header 'Content-Type: application/json' \
|
||||||
|
--data-raw '{
|
||||||
|
"query": "test",
|
||||||
|
"retrieval_model": {
|
||||||
|
"search_method": "keyword_search",
|
||||||
|
"reranking_enable": false,
|
||||||
|
"reranking_mode": null,
|
||||||
|
"reranking_model": {
|
||||||
|
"reranking_provider_name": "",
|
||||||
|
"reranking_model_name": ""
|
||||||
|
},
|
||||||
|
"weights": null,
|
||||||
|
"top_k": 2,
|
||||||
|
"score_threshold_enabled": false,
|
||||||
|
"score_threshold": null
|
||||||
|
}
|
||||||
|
}'
|
||||||
|
```
|
||||||
|
</CodeGroup>
|
||||||
|
<CodeGroup title="Response">
|
||||||
|
```json {{ title: 'Response' }}
|
||||||
|
{
|
||||||
|
"query": {
|
||||||
|
"content": "test"
|
||||||
|
},
|
||||||
|
"records": [
|
||||||
|
{
|
||||||
|
"segment": {
|
||||||
|
"id": "7fa6f24f-8679-48b3-bc9d-bdf28d73f218",
|
||||||
|
"position": 1,
|
||||||
|
"document_id": "a8c6c36f-9f5d-4d7a-8472-f5d7b75d71d2",
|
||||||
|
"content": "Operation guide",
|
||||||
|
"answer": null,
|
||||||
|
"word_count": 847,
|
||||||
|
"tokens": 280,
|
||||||
|
"keywords": [
|
||||||
|
"install",
|
||||||
|
"java",
|
||||||
|
"base",
|
||||||
|
"scripts",
|
||||||
|
"jdk",
|
||||||
|
"manual",
|
||||||
|
"internal",
|
||||||
|
"opens",
|
||||||
|
"add",
|
||||||
|
"vmoptions"
|
||||||
|
],
|
||||||
|
"index_node_id": "39dd8443-d960-45a8-bb46-7275ad7fbc8e",
|
||||||
|
"index_node_hash": "0189157697b3c6a418ccf8264a09699f25858975578f3467c76d6bfc94df1d73",
|
||||||
|
"hit_count": 0,
|
||||||
|
"enabled": true,
|
||||||
|
"disabled_at": null,
|
||||||
|
"disabled_by": null,
|
||||||
|
"status": "completed",
|
||||||
|
"created_by": "dbcb1ab5-90c8-41a7-8b78-73b235eb6f6f",
|
||||||
|
"created_at": 1728734540,
|
||||||
|
"indexing_at": 1728734552,
|
||||||
|
"completed_at": 1728734584,
|
||||||
|
"error": null,
|
||||||
|
"stopped_at": null,
|
||||||
|
"document": {
|
||||||
|
"id": "a8c6c36f-9f5d-4d7a-8472-f5d7b75d71d2",
|
||||||
|
"data_source_type": "upload_file",
|
||||||
|
"name": "readme.txt",
|
||||||
|
"doc_type": null
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"score": 3.730463140527718e-05,
|
||||||
|
"tsne_position": null
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
```
|
||||||
|
</CodeGroup>
|
||||||
|
</Col>
|
||||||
|
</Row>
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
<Row>
|
<Row>
|
||||||
<Col>
|
<Col>
|
||||||
### Error message
|
### Error message
|
||||||
|
|
|
@ -1049,6 +1049,152 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from
|
||||||
</Col>
|
</Col>
|
||||||
</Row>
|
</Row>
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
<Heading
|
||||||
|
url='/datasets/{dataset_id}/hit_testing'
|
||||||
|
method='POST'
|
||||||
|
title='知识库召回测试'
|
||||||
|
name='#dataset_hit_testing'
|
||||||
|
/>
|
||||||
|
<Row>
|
||||||
|
<Col>
|
||||||
|
### Path
|
||||||
|
<Properties>
|
||||||
|
<Property name='dataset_id' type='string' key='dataset_id'>
|
||||||
|
知识库 ID
|
||||||
|
</Property>
|
||||||
|
</Properties>
|
||||||
|
|
||||||
|
### Request Body
|
||||||
|
<Properties>
|
||||||
|
<Property name='query' type='string' key='query'>
|
||||||
|
召回关键词
|
||||||
|
</Property>
|
||||||
|
<Property name='retrieval_model' type='object' key='retrieval_model'>
|
||||||
|
召回参数(选填,如不填,按照默认方式召回)
|
||||||
|
- <code>search_method</code> (text) 检索方法:以下四个关键字之一,必填
|
||||||
|
- <code>keyword_search</code> 关键字检索
|
||||||
|
- <code>semantic_search</code> 语义检索
|
||||||
|
- <code>full_text_search</code> 全文检索
|
||||||
|
- <code>hybrid_search</code> 混合检索
|
||||||
|
- <code>reranking_enable</code> (bool) 是否启用 Reranking,非必填,如果检索模式为semantic_search模式或者hybrid_search则传值
|
||||||
|
- <code>reranking_model</code> (object) Rerank模型配置,非必填,如果启用了 reranking 则传值
|
||||||
|
- <code>reranking_provider_name</code> (string) Rerank 模型提供商
|
||||||
|
- <code>reranking_model_name</code> (string) Rerank 模型名称
|
||||||
|
- <code>weights</code> (double) 混合检索模式下语义检索的权重设置
|
||||||
|
- <code>top_k</code> (integer) 返回结果数量,非必填
|
||||||
|
- <code>score_threshold_enabled</code> (bool) 是否开启Score阈值
|
||||||
|
- <code>score_threshold</code> (double) Score阈值
|
||||||
|
</Property>
|
||||||
|
<Property name='external_retrieval_model' type='object' key='external_retrieval_model'>
|
||||||
|
未启用字段
|
||||||
|
</Property>
|
||||||
|
</Properties>
|
||||||
|
</Col>
|
||||||
|
<Col sticky>
|
||||||
|
<CodeGroup
|
||||||
|
title="Request"
|
||||||
|
tag="POST"
|
||||||
|
label="/datasets/{dataset_id}/hit_testing"
|
||||||
|
targetCode={`curl --location --request POST '${props.apiBaseUrl}/datasets/{dataset_id}/hit_testing' \\\n--header 'Authorization: Bearer {api_key}'\\\n--header 'Content-Type: application/json'\\\n--data-raw '{
|
||||||
|
"query": "test",
|
||||||
|
"retrieval_model": {
|
||||||
|
"search_method": "keyword_search",
|
||||||
|
"reranking_enable": false,
|
||||||
|
"reranking_mode": null,
|
||||||
|
"reranking_model": {
|
||||||
|
"reranking_provider_name": "",
|
||||||
|
"reranking_model_name": ""
|
||||||
|
},
|
||||||
|
"weights": null,
|
||||||
|
"top_k": 1,
|
||||||
|
"score_threshold_enabled": false,
|
||||||
|
"score_threshold": null
|
||||||
|
}
|
||||||
|
}'`}
|
||||||
|
>
|
||||||
|
```bash {{ title: 'cURL' }}
|
||||||
|
curl --location --request POST '${props.apiBaseUrl}/datasets/{dataset_id}/hit_testing' \
|
||||||
|
--header 'Authorization: Bearer {api_key}' \
|
||||||
|
--header 'Content-Type: application/json' \
|
||||||
|
--data-raw '{
|
||||||
|
"query": "test",
|
||||||
|
"retrieval_model": {
|
||||||
|
"search_method": "keyword_search",
|
||||||
|
"reranking_enable": false,
|
||||||
|
"reranking_mode": null,
|
||||||
|
"reranking_model": {
|
||||||
|
"reranking_provider_name": "",
|
||||||
|
"reranking_model_name": ""
|
||||||
|
},
|
||||||
|
"weights": null,
|
||||||
|
"top_k": 2,
|
||||||
|
"score_threshold_enabled": false,
|
||||||
|
"score_threshold": null
|
||||||
|
}
|
||||||
|
}'
|
||||||
|
```
|
||||||
|
</CodeGroup>
|
||||||
|
<CodeGroup title="Response">
|
||||||
|
```json {{ title: 'Response' }}
|
||||||
|
{
|
||||||
|
"query": {
|
||||||
|
"content": "test"
|
||||||
|
},
|
||||||
|
"records": [
|
||||||
|
{
|
||||||
|
"segment": {
|
||||||
|
"id": "7fa6f24f-8679-48b3-bc9d-bdf28d73f218",
|
||||||
|
"position": 1,
|
||||||
|
"document_id": "a8c6c36f-9f5d-4d7a-8472-f5d7b75d71d2",
|
||||||
|
"content": "Operation guide",
|
||||||
|
"answer": null,
|
||||||
|
"word_count": 847,
|
||||||
|
"tokens": 280,
|
||||||
|
"keywords": [
|
||||||
|
"install",
|
||||||
|
"java",
|
||||||
|
"base",
|
||||||
|
"scripts",
|
||||||
|
"jdk",
|
||||||
|
"manual",
|
||||||
|
"internal",
|
||||||
|
"opens",
|
||||||
|
"add",
|
||||||
|
"vmoptions"
|
||||||
|
],
|
||||||
|
"index_node_id": "39dd8443-d960-45a8-bb46-7275ad7fbc8e",
|
||||||
|
"index_node_hash": "0189157697b3c6a418ccf8264a09699f25858975578f3467c76d6bfc94df1d73",
|
||||||
|
"hit_count": 0,
|
||||||
|
"enabled": true,
|
||||||
|
"disabled_at": null,
|
||||||
|
"disabled_by": null,
|
||||||
|
"status": "completed",
|
||||||
|
"created_by": "dbcb1ab5-90c8-41a7-8b78-73b235eb6f6f",
|
||||||
|
"created_at": 1728734540,
|
||||||
|
"indexing_at": 1728734552,
|
||||||
|
"completed_at": 1728734584,
|
||||||
|
"error": null,
|
||||||
|
"stopped_at": null,
|
||||||
|
"document": {
|
||||||
|
"id": "a8c6c36f-9f5d-4d7a-8472-f5d7b75d71d2",
|
||||||
|
"data_source_type": "upload_file",
|
||||||
|
"name": "readme.txt",
|
||||||
|
"doc_type": null
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"score": 3.730463140527718e-05,
|
||||||
|
"tsne_position": null
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
```
|
||||||
|
</CodeGroup>
|
||||||
|
</Col>
|
||||||
|
</Row>
|
||||||
|
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
<Row>
|
<Row>
|
||||||
|
|
|
@ -63,7 +63,7 @@ const ConfigContent: FC<Props> = ({
|
||||||
} = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.rerank)
|
} = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.rerank)
|
||||||
|
|
||||||
const {
|
const {
|
||||||
currentModel,
|
currentModel: currentRerankModel,
|
||||||
} = useCurrentProviderAndModel(
|
} = useCurrentProviderAndModel(
|
||||||
rerankModelList,
|
rerankModelList,
|
||||||
rerankDefaultModel
|
rerankDefaultModel
|
||||||
|
@ -74,11 +74,6 @@ const ConfigContent: FC<Props> = ({
|
||||||
: undefined,
|
: undefined,
|
||||||
)
|
)
|
||||||
|
|
||||||
const handleDisabledSwitchClick = useCallback(() => {
|
|
||||||
if (!currentModel)
|
|
||||||
Toast.notify({ type: 'error', message: t('workflow.errorMsg.rerankModelRequired') })
|
|
||||||
}, [currentModel, rerankDefaultModel, t])
|
|
||||||
|
|
||||||
const rerankModel = (() => {
|
const rerankModel = (() => {
|
||||||
if (datasetConfigs.reranking_model?.reranking_provider_name) {
|
if (datasetConfigs.reranking_model?.reranking_provider_name) {
|
||||||
return {
|
return {
|
||||||
|
@ -164,12 +159,33 @@ const ConfigContent: FC<Props> = ({
|
||||||
const showWeightedScorePanel = showWeightedScore && datasetConfigs.reranking_mode === RerankingModeEnum.WeightedScore && datasetConfigs.weights
|
const showWeightedScorePanel = showWeightedScore && datasetConfigs.reranking_mode === RerankingModeEnum.WeightedScore && datasetConfigs.weights
|
||||||
const selectedRerankMode = datasetConfigs.reranking_mode || RerankingModeEnum.RerankingModel
|
const selectedRerankMode = datasetConfigs.reranking_mode || RerankingModeEnum.RerankingModel
|
||||||
|
|
||||||
|
const canManuallyToggleRerank = useMemo(() => {
|
||||||
|
return !(
|
||||||
|
(selectedDatasetsMode.allInternal && selectedDatasetsMode.allEconomic)
|
||||||
|
|| selectedDatasetsMode.allExternal
|
||||||
|
)
|
||||||
|
}, [selectedDatasetsMode.allEconomic, selectedDatasetsMode.allExternal, selectedDatasetsMode.allInternal])
|
||||||
|
|
||||||
const showRerankModel = useMemo(() => {
|
const showRerankModel = useMemo(() => {
|
||||||
if (datasetConfigs.reranking_enable === false && selectedDatasetsMode.allEconomic)
|
if (!canManuallyToggleRerank)
|
||||||
return false
|
return false
|
||||||
|
|
||||||
return true
|
return datasetConfigs.reranking_enable
|
||||||
}, [datasetConfigs.reranking_enable, selectedDatasetsMode.allEconomic])
|
}, [canManuallyToggleRerank, datasetConfigs.reranking_enable])
|
||||||
|
|
||||||
|
const handleDisabledSwitchClick = useCallback(() => {
|
||||||
|
if (!currentRerankModel && !showRerankModel)
|
||||||
|
Toast.notify({ type: 'error', message: t('workflow.errorMsg.rerankModelRequired') })
|
||||||
|
}, [currentRerankModel, showRerankModel, t])
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
if (!canManuallyToggleRerank && showRerankModel !== datasetConfigs.reranking_enable) {
|
||||||
|
onChange({
|
||||||
|
...datasetConfigs,
|
||||||
|
reranking_enable: showRerankModel,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}, [canManuallyToggleRerank, showRerankModel, datasetConfigs, onChange])
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<div>
|
<div>
|
||||||
|
@ -256,13 +272,15 @@ const ConfigContent: FC<Props> = ({
|
||||||
>
|
>
|
||||||
<Switch
|
<Switch
|
||||||
size='md'
|
size='md'
|
||||||
defaultValue={currentModel ? showRerankModel : false}
|
defaultValue={showRerankModel}
|
||||||
disabled={!currentModel}
|
disabled={!currentRerankModel || !canManuallyToggleRerank}
|
||||||
onChange={(v) => {
|
onChange={(v) => {
|
||||||
onChange({
|
if (canManuallyToggleRerank) {
|
||||||
...datasetConfigs,
|
onChange({
|
||||||
reranking_enable: v,
|
...datasetConfigs,
|
||||||
})
|
reranking_enable: v,
|
||||||
|
})
|
||||||
|
}
|
||||||
}}
|
}}
|
||||||
/>
|
/>
|
||||||
</div>
|
</div>
|
||||||
|
|
|
@ -42,6 +42,7 @@ const ParamsConfig = ({
|
||||||
allHighQuality,
|
allHighQuality,
|
||||||
allHighQualityFullTextSearch,
|
allHighQualityFullTextSearch,
|
||||||
allHighQualityVectorSearch,
|
allHighQualityVectorSearch,
|
||||||
|
allInternal,
|
||||||
allExternal,
|
allExternal,
|
||||||
mixtureHighQualityAndEconomic,
|
mixtureHighQualityAndEconomic,
|
||||||
inconsistentEmbeddingModel,
|
inconsistentEmbeddingModel,
|
||||||
|
@ -50,7 +51,7 @@ const ParamsConfig = ({
|
||||||
const { datasets, retrieval_model, score_threshold_enabled, ...restConfigs } = datasetConfigs
|
const { datasets, retrieval_model, score_threshold_enabled, ...restConfigs } = datasetConfigs
|
||||||
let rerankEnable = restConfigs.reranking_enable
|
let rerankEnable = restConfigs.reranking_enable
|
||||||
|
|
||||||
if ((allEconomic && !restConfigs.reranking_model?.reranking_provider_name && rerankEnable === undefined) || allExternal)
|
if (((allInternal && allEconomic) || allExternal) && !restConfigs.reranking_model?.reranking_provider_name && rerankEnable === undefined)
|
||||||
rerankEnable = false
|
rerankEnable = false
|
||||||
|
|
||||||
if (allEconomic || allHighQuality || allHighQualityFullTextSearch || allHighQualityVectorSearch || (allExternal && selectedDatasets.length === 1))
|
if (allEconomic || allHighQuality || allHighQualityFullTextSearch || allHighQualityVectorSearch || (allExternal && selectedDatasets.length === 1))
|
||||||
|
|
|
@ -1,25 +1,17 @@
|
||||||
import { useCallback } from 'react'
|
import { useCallback } from 'react'
|
||||||
import { useStoreApi } from 'reactflow'
|
import { useStoreApi } from 'reactflow'
|
||||||
import { useTranslation } from 'react-i18next'
|
|
||||||
import { useWorkflowStore } from '../store'
|
import { useWorkflowStore } from '../store'
|
||||||
import {
|
import {
|
||||||
BlockEnum,
|
BlockEnum,
|
||||||
WorkflowRunningStatus,
|
WorkflowRunningStatus,
|
||||||
} from '../types'
|
} from '../types'
|
||||||
import type { KnowledgeRetrievalNodeType } from '../nodes/knowledge-retrieval/types'
|
|
||||||
import type { Node } from '../types'
|
|
||||||
import { useWorkflow } from './use-workflow'
|
|
||||||
import {
|
import {
|
||||||
useIsChatMode,
|
useIsChatMode,
|
||||||
useNodesSyncDraft,
|
useNodesSyncDraft,
|
||||||
useWorkflowInteractions,
|
useWorkflowInteractions,
|
||||||
useWorkflowRun,
|
useWorkflowRun,
|
||||||
} from './index'
|
} from './index'
|
||||||
import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations'
|
|
||||||
import { useCurrentProviderAndModel, useModelListAndDefaultModelAndCurrentProviderAndModel } from '@/app/components/header/account-setting/model-provider-page/hooks'
|
|
||||||
import { useFeaturesStore } from '@/app/components/base/features/hooks'
|
import { useFeaturesStore } from '@/app/components/base/features/hooks'
|
||||||
import KnowledgeRetrievalDefault from '@/app/components/workflow/nodes/knowledge-retrieval/default'
|
|
||||||
import Toast from '@/app/components/base/toast'
|
|
||||||
|
|
||||||
export const useWorkflowStartRun = () => {
|
export const useWorkflowStartRun = () => {
|
||||||
const store = useStoreApi()
|
const store = useStoreApi()
|
||||||
|
@ -28,26 +20,7 @@ export const useWorkflowStartRun = () => {
|
||||||
const isChatMode = useIsChatMode()
|
const isChatMode = useIsChatMode()
|
||||||
const { handleCancelDebugAndPreviewPanel } = useWorkflowInteractions()
|
const { handleCancelDebugAndPreviewPanel } = useWorkflowInteractions()
|
||||||
const { handleRun } = useWorkflowRun()
|
const { handleRun } = useWorkflowRun()
|
||||||
const { isFromStartNode } = useWorkflow()
|
|
||||||
const { doSyncWorkflowDraft } = useNodesSyncDraft()
|
const { doSyncWorkflowDraft } = useNodesSyncDraft()
|
||||||
const { checkValid: checkKnowledgeRetrievalValid } = KnowledgeRetrievalDefault
|
|
||||||
const { t } = useTranslation()
|
|
||||||
const {
|
|
||||||
modelList: rerankModelList,
|
|
||||||
defaultModel: rerankDefaultModel,
|
|
||||||
} = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.rerank)
|
|
||||||
|
|
||||||
const {
|
|
||||||
currentModel,
|
|
||||||
} = useCurrentProviderAndModel(
|
|
||||||
rerankModelList,
|
|
||||||
rerankDefaultModel
|
|
||||||
? {
|
|
||||||
...rerankDefaultModel,
|
|
||||||
provider: rerankDefaultModel.provider.provider,
|
|
||||||
}
|
|
||||||
: undefined,
|
|
||||||
)
|
|
||||||
|
|
||||||
const handleWorkflowStartRunInWorkflow = useCallback(async () => {
|
const handleWorkflowStartRunInWorkflow = useCallback(async () => {
|
||||||
const {
|
const {
|
||||||
|
@ -60,9 +33,6 @@ export const useWorkflowStartRun = () => {
|
||||||
const { getNodes } = store.getState()
|
const { getNodes } = store.getState()
|
||||||
const nodes = getNodes()
|
const nodes = getNodes()
|
||||||
const startNode = nodes.find(node => node.data.type === BlockEnum.Start)
|
const startNode = nodes.find(node => node.data.type === BlockEnum.Start)
|
||||||
const knowledgeRetrievalNodes = nodes.filter((node: Node<KnowledgeRetrievalNodeType>) =>
|
|
||||||
node.data.type === BlockEnum.KnowledgeRetrieval,
|
|
||||||
)
|
|
||||||
const startVariables = startNode?.data.variables || []
|
const startVariables = startNode?.data.variables || []
|
||||||
const fileSettings = featuresStore!.getState().features.file
|
const fileSettings = featuresStore!.getState().features.file
|
||||||
const {
|
const {
|
||||||
|
@ -72,31 +42,6 @@ export const useWorkflowStartRun = () => {
|
||||||
setShowEnvPanel,
|
setShowEnvPanel,
|
||||||
} = workflowStore.getState()
|
} = workflowStore.getState()
|
||||||
|
|
||||||
if (knowledgeRetrievalNodes.length > 0) {
|
|
||||||
for (const node of knowledgeRetrievalNodes) {
|
|
||||||
if (isFromStartNode(node.id)) {
|
|
||||||
const res = checkKnowledgeRetrievalValid(node.data, t)
|
|
||||||
if (!res.isValid || !currentModel || !rerankDefaultModel) {
|
|
||||||
const errorMessage = res.errorMessage
|
|
||||||
if (errorMessage) {
|
|
||||||
Toast.notify({
|
|
||||||
type: 'error',
|
|
||||||
message: errorMessage,
|
|
||||||
})
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
else {
|
|
||||||
Toast.notify({
|
|
||||||
type: 'error',
|
|
||||||
message: t('appDebug.datasetConfig.rerankModelRequired'),
|
|
||||||
})
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
setShowEnvPanel(false)
|
setShowEnvPanel(false)
|
||||||
|
|
||||||
if (showDebugAndPreviewPanel) {
|
if (showDebugAndPreviewPanel) {
|
||||||
|
|
|
@ -23,7 +23,7 @@ import type { DataSet } from '@/models/datasets'
|
||||||
import { fetchDatasets } from '@/service/datasets'
|
import { fetchDatasets } from '@/service/datasets'
|
||||||
import useNodeCrud from '@/app/components/workflow/nodes/_base/hooks/use-node-crud'
|
import useNodeCrud from '@/app/components/workflow/nodes/_base/hooks/use-node-crud'
|
||||||
import useOneStepRun from '@/app/components/workflow/nodes/_base/hooks/use-one-step-run'
|
import useOneStepRun from '@/app/components/workflow/nodes/_base/hooks/use-one-step-run'
|
||||||
import { useModelListAndDefaultModelAndCurrentProviderAndModel } from '@/app/components/header/account-setting/model-provider-page/hooks'
|
import { useCurrentProviderAndModel, useModelListAndDefaultModelAndCurrentProviderAndModel } from '@/app/components/header/account-setting/model-provider-page/hooks'
|
||||||
import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations'
|
import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations'
|
||||||
|
|
||||||
const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => {
|
const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => {
|
||||||
|
@ -34,6 +34,8 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => {
|
||||||
const startNodeId = startNode?.id
|
const startNodeId = startNode?.id
|
||||||
const { inputs, setInputs: doSetInputs } = useNodeCrud<KnowledgeRetrievalNodeType>(id, payload)
|
const { inputs, setInputs: doSetInputs } = useNodeCrud<KnowledgeRetrievalNodeType>(id, payload)
|
||||||
|
|
||||||
|
const inputRef = useRef(inputs)
|
||||||
|
|
||||||
const setInputs = useCallback((s: KnowledgeRetrievalNodeType) => {
|
const setInputs = useCallback((s: KnowledgeRetrievalNodeType) => {
|
||||||
const newInputs = produce(s, (draft) => {
|
const newInputs = produce(s, (draft) => {
|
||||||
if (s.retrieval_mode === RETRIEVE_TYPE.multiWay)
|
if (s.retrieval_mode === RETRIEVE_TYPE.multiWay)
|
||||||
|
@ -43,13 +45,9 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => {
|
||||||
})
|
})
|
||||||
// not work in pass to draft...
|
// not work in pass to draft...
|
||||||
doSetInputs(newInputs)
|
doSetInputs(newInputs)
|
||||||
|
inputRef.current = newInputs
|
||||||
}, [doSetInputs])
|
}, [doSetInputs])
|
||||||
|
|
||||||
const inputRef = useRef(inputs)
|
|
||||||
useEffect(() => {
|
|
||||||
inputRef.current = inputs
|
|
||||||
}, [inputs])
|
|
||||||
|
|
||||||
const handleQueryVarChange = useCallback((newVar: ValueSelector | string) => {
|
const handleQueryVarChange = useCallback((newVar: ValueSelector | string) => {
|
||||||
const newInputs = produce(inputs, (draft) => {
|
const newInputs = produce(inputs, (draft) => {
|
||||||
draft.query_variable_selector = newVar as ValueSelector
|
draft.query_variable_selector = newVar as ValueSelector
|
||||||
|
@ -63,9 +61,22 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => {
|
||||||
} = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.textGeneration)
|
} = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.textGeneration)
|
||||||
|
|
||||||
const {
|
const {
|
||||||
|
modelList: rerankModelList,
|
||||||
defaultModel: rerankDefaultModel,
|
defaultModel: rerankDefaultModel,
|
||||||
} = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.rerank)
|
} = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.rerank)
|
||||||
|
|
||||||
|
const {
|
||||||
|
currentModel: currentRerankModel,
|
||||||
|
} = useCurrentProviderAndModel(
|
||||||
|
rerankModelList,
|
||||||
|
rerankDefaultModel
|
||||||
|
? {
|
||||||
|
...rerankDefaultModel,
|
||||||
|
provider: rerankDefaultModel.provider.provider,
|
||||||
|
}
|
||||||
|
: undefined,
|
||||||
|
)
|
||||||
|
|
||||||
const handleModelChanged = useCallback((model: { provider: string; modelId: string; mode?: string }) => {
|
const handleModelChanged = useCallback((model: { provider: string; modelId: string; mode?: string }) => {
|
||||||
const newInputs = produce(inputRef.current, (draft) => {
|
const newInputs = produce(inputRef.current, (draft) => {
|
||||||
if (!draft.single_retrieval_config) {
|
if (!draft.single_retrieval_config) {
|
||||||
|
@ -110,7 +121,7 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => {
|
||||||
// set defaults models
|
// set defaults models
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
const inputs = inputRef.current
|
const inputs = inputRef.current
|
||||||
if (inputs.retrieval_mode === RETRIEVE_TYPE.multiWay && inputs.multiple_retrieval_config?.reranking_model?.provider)
|
if (inputs.retrieval_mode === RETRIEVE_TYPE.multiWay && inputs.multiple_retrieval_config?.reranking_model?.provider && currentRerankModel && rerankDefaultModel)
|
||||||
return
|
return
|
||||||
|
|
||||||
if (inputs.retrieval_mode === RETRIEVE_TYPE.oneWay && inputs.single_retrieval_config?.model?.provider)
|
if (inputs.retrieval_mode === RETRIEVE_TYPE.oneWay && inputs.single_retrieval_config?.model?.provider)
|
||||||
|
@ -130,7 +141,6 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
const multipleRetrievalConfig = draft.multiple_retrieval_config
|
const multipleRetrievalConfig = draft.multiple_retrieval_config
|
||||||
draft.multiple_retrieval_config = {
|
draft.multiple_retrieval_config = {
|
||||||
top_k: multipleRetrievalConfig?.top_k || DATASET_DEFAULT.top_k,
|
top_k: multipleRetrievalConfig?.top_k || DATASET_DEFAULT.top_k,
|
||||||
|
@ -138,6 +148,9 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => {
|
||||||
reranking_model: multipleRetrievalConfig?.reranking_model,
|
reranking_model: multipleRetrievalConfig?.reranking_model,
|
||||||
reranking_mode: multipleRetrievalConfig?.reranking_mode,
|
reranking_mode: multipleRetrievalConfig?.reranking_mode,
|
||||||
weights: multipleRetrievalConfig?.weights,
|
weights: multipleRetrievalConfig?.weights,
|
||||||
|
reranking_enable: multipleRetrievalConfig?.reranking_enable !== undefined
|
||||||
|
? multipleRetrievalConfig.reranking_enable
|
||||||
|
: Boolean(currentRerankModel && rerankDefaultModel),
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
setInputs(newInput)
|
setInputs(newInput)
|
||||||
|
@ -194,14 +207,14 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => {
|
||||||
}, [])
|
}, [])
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
|
const inputs = inputRef.current
|
||||||
let query_variable_selector: ValueSelector = inputs.query_variable_selector
|
let query_variable_selector: ValueSelector = inputs.query_variable_selector
|
||||||
if (isChatMode && inputs.query_variable_selector.length === 0 && startNodeId)
|
if (isChatMode && inputs.query_variable_selector.length === 0 && startNodeId)
|
||||||
query_variable_selector = [startNodeId, 'sys.query']
|
query_variable_selector = [startNodeId, 'sys.query']
|
||||||
|
|
||||||
setInputs({
|
setInputs(produce(inputs, (draft) => {
|
||||||
...inputs,
|
draft.query_variable_selector = query_variable_selector
|
||||||
query_variable_selector,
|
}))
|
||||||
})
|
|
||||||
// eslint-disable-next-line react-hooks/exhaustive-deps
|
// eslint-disable-next-line react-hooks/exhaustive-deps
|
||||||
}, [])
|
}, [])
|
||||||
|
|
||||||
|
|
|
@ -113,7 +113,7 @@ export const getMultipleRetrievalConfig = (multipleRetrievalConfig: MultipleRetr
|
||||||
reranking_mode,
|
reranking_mode,
|
||||||
reranking_model,
|
reranking_model,
|
||||||
weights,
|
weights,
|
||||||
reranking_enable: allEconomic ? reranking_enable : true,
|
reranking_enable: ((allInternal && allEconomic) || allExternal) ? reranking_enable : true,
|
||||||
}
|
}
|
||||||
|
|
||||||
if (allEconomic || mixtureHighQualityAndEconomic || inconsistentEmbeddingModel || allExternal || mixtureInternalAndExternal)
|
if (allEconomic || mixtureHighQualityAndEconomic || inconsistentEmbeddingModel || allExternal || mixtureInternalAndExternal)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user