add rerank check when doing mutil-retrieval (#9998)

This commit is contained in:
Jyong 2024-10-30 11:17:39 +08:00 committed by GitHub
parent 5ad5d0cff4
commit 9ebd453b87
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 32 additions and 2 deletions

View File

@ -1,6 +1,6 @@
from enum import Enum
class RerankMode(Enum):
class RerankMode(str, Enum):
RERANKING_MODEL = "reranking_model"
WEIGHTED_SCORE = "weighted_score"

View File

@ -22,6 +22,7 @@ from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaK
from core.rag.datasource.retrieval_service import RetrievalService
from core.rag.entities.context_entities import DocumentContext
from core.rag.models.document import Document
from core.rag.rerank.rerank_type import RerankMode
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.rag.retrieval.router.multi_dataset_function_call_router import FunctionCallMultiDatasetRouter
from core.rag.retrieval.router.multi_dataset_react_route import ReactMultiDatasetRouter
@ -361,10 +362,39 @@ class DatasetRetrieval:
reranking_enable: bool = True,
message_id: Optional[str] = None,
):
if not available_datasets:
return []
threads = []
all_documents = []
dataset_ids = [dataset.id for dataset in available_datasets]
index_type = None
index_type_check = all(
item.indexing_technique == available_datasets[0].indexing_technique for item in available_datasets
)
if not index_type_check and (not reranking_enable or reranking_mode != RerankMode.RERANKING_MODEL):
raise ValueError(
"The configured knowledge base list have different indexing technique, please set reranking model."
)
index_type = available_datasets[0].indexing_technique
if index_type == "high_quality":
embedding_model_check = all(
item.embedding_model == available_datasets[0].embedding_model for item in available_datasets
)
embedding_model_provider_check = all(
item.embedding_model_provider == available_datasets[0].embedding_model_provider
for item in available_datasets
)
if (
reranking_enable
and reranking_mode == "weighted_score"
and (not embedding_model_check or not embedding_model_provider_check)
):
raise ValueError(
"The configured knowledge base list have different embedding model, please set reranking model."
)
if reranking_enable and reranking_mode == RerankMode.WEIGHTED_SCORE:
weights["vector_setting"]["embedding_provider_name"] = available_datasets[0].embedding_model_provider
weights["vector_setting"]["embedding_model_name"] = available_datasets[0].embedding_model
for dataset in available_datasets:
index_type = dataset.indexing_technique
retrieval_thread = threading.Thread(