mirror of
https://github.com/langgenius/dify.git
synced 2024-11-16 03:32:23 +08:00
add rerank check when doing multi-retrieval (#9998)
This commit is contained in:
parent
5ad5d0cff4
commit
9ebd453b87
|
@ -1,6 +1,6 @@
|
|||
from enum import Enum
|
||||
|
||||
|
||||
class RerankMode(Enum):
|
||||
class RerankMode(str, Enum):
|
||||
RERANKING_MODEL = "reranking_model"
|
||||
WEIGHTED_SCORE = "weighted_score"
|
||||
|
|
|
@ -22,6 +22,7 @@ from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaK
|
|||
from core.rag.datasource.retrieval_service import RetrievalService
|
||||
from core.rag.entities.context_entities import DocumentContext
|
||||
from core.rag.models.document import Document
|
||||
from core.rag.rerank.rerank_type import RerankMode
|
||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||
from core.rag.retrieval.router.multi_dataset_function_call_router import FunctionCallMultiDatasetRouter
|
||||
from core.rag.retrieval.router.multi_dataset_react_route import ReactMultiDatasetRouter
|
||||
|
@ -361,10 +362,39 @@ class DatasetRetrieval:
|
|||
reranking_enable: bool = True,
|
||||
message_id: Optional[str] = None,
|
||||
):
|
||||
if not available_datasets:
|
||||
return []
|
||||
threads = []
|
||||
all_documents = []
|
||||
dataset_ids = [dataset.id for dataset in available_datasets]
|
||||
index_type = None
|
||||
index_type_check = all(
|
||||
item.indexing_technique == available_datasets[0].indexing_technique for item in available_datasets
|
||||
)
|
||||
if not index_type_check and (not reranking_enable or reranking_mode != RerankMode.RERANKING_MODEL):
|
||||
raise ValueError(
|
||||
"The configured knowledge base list have different indexing technique, please set reranking model."
|
||||
)
|
||||
index_type = available_datasets[0].indexing_technique
|
||||
if index_type == "high_quality":
|
||||
embedding_model_check = all(
|
||||
item.embedding_model == available_datasets[0].embedding_model for item in available_datasets
|
||||
)
|
||||
embedding_model_provider_check = all(
|
||||
item.embedding_model_provider == available_datasets[0].embedding_model_provider
|
||||
for item in available_datasets
|
||||
)
|
||||
if (
|
||||
reranking_enable
|
||||
and reranking_mode == "weighted_score"
|
||||
and (not embedding_model_check or not embedding_model_provider_check)
|
||||
):
|
||||
raise ValueError(
|
||||
"The configured knowledge base list have different embedding model, please set reranking model."
|
||||
)
|
||||
if reranking_enable and reranking_mode == RerankMode.WEIGHTED_SCORE:
|
||||
weights["vector_setting"]["embedding_provider_name"] = available_datasets[0].embedding_model_provider
|
||||
weights["vector_setting"]["embedding_model_name"] = available_datasets[0].embedding_model
|
||||
|
||||
for dataset in available_datasets:
|
||||
index_type = dataset.indexing_technique
|
||||
retrieval_thread = threading.Thread(
|
||||
|
|
Loading…
Reference in New Issue
Block a user