diff --git a/api/core/indexing_runner.py b/api/core/indexing_runner.py index c8a2e09443..d2d04c984b 100644 --- a/api/core/indexing_runner.py +++ b/api/core/indexing_runner.py @@ -365,7 +365,7 @@ class IndexingRunner: notion_info={ "notion_workspace_id": data_source_info['notion_workspace_id'], "notion_obj_id": data_source_info['notion_page_id'], - "notion_page_type": data_source_info['notion_page_type'], + "notion_page_type": data_source_info['type'], "document": dataset_document }, document_model=dataset_document.doc_form diff --git a/api/core/rag/datasource/retrieval_service.py b/api/core/rag/datasource/retrieval_service.py index 79673ffa83..c0205d1aa9 100644 --- a/api/core/rag/datasource/retrieval_service.py +++ b/api/core/rag/datasource/retrieval_service.py @@ -2,7 +2,6 @@ import threading from typing import Optional from flask import Flask, current_app -from flask_login import current_user from core.rag.data_post_processor.data_post_processor import DataPostProcessor from core.rag.datasource.keyword.keyword_factory import Keyword @@ -27,6 +26,11 @@ class RetrievalService: @classmethod def retrieve(cls, retrival_method: str, dataset_id: str, query: str, top_k: int, score_threshold: Optional[float] = .0, reranking_model: Optional[dict] = None): + dataset = db.session.query(Dataset).filter( + Dataset.id == dataset_id + ).first() + if not dataset or dataset.available_document_count == 0 or dataset.available_segment_count == 0: + return [] all_documents = [] threads = [] # retrieval_model source with keyword @@ -73,7 +77,7 @@ class RetrievalService: thread.join() if retrival_method == 'hybrid_search': - data_post_processor = DataPostProcessor(str(current_user.current_tenant_id), reranking_model, False) + data_post_processor = DataPostProcessor(str(dataset.tenant_id), reranking_model, False) all_documents = data_post_processor.invoke( query=query, documents=all_documents, diff --git a/api/core/tools/tool/dataset_retriever/dataset_multi_retriever_tool.py b/api/core/tools/tool/dataset_retriever/dataset_multi_retriever_tool.py index 57b6e090c4..d9934acff9 100644 --- a/api/core/tools/tool/dataset_retriever/dataset_multi_retriever_tool.py +++ b/api/core/tools/tool/dataset_retriever/dataset_multi_retriever_tool.py @@ -171,7 +171,7 @@ class DatasetMultiRetrieverTool(BaseTool): if dataset.indexing_technique == "economy": # use keyword table query - documents = RetrievalService.retrieve(retrival_method=retrieval_model['search_method'], + documents = RetrievalService.retrieve(retrival_method='keyword_search', dataset_id=dataset.id, query=query, top_k=self.top_k diff --git a/api/core/tools/tool/dataset_retriever/dataset_retriever_tool.py b/api/core/tools/tool/dataset_retriever/dataset_retriever_tool.py index d3ec0fba69..13331d981b 100644 --- a/api/core/tools/tool/dataset_retriever/dataset_retriever_tool.py +++ b/api/core/tools/tool/dataset_retriever/dataset_retriever_tool.py @@ -69,7 +69,7 @@ class DatasetRetrieverTool(BaseTool): retrieval_model = dataset.retrieval_model if dataset.retrieval_model else default_retrieval_model if dataset.indexing_technique == "economy": # use keyword table query - documents = RetrievalService.retrieve(retrival_method=retrieval_model['search_method'], + documents = RetrievalService.retrieve(retrival_method='keyword_search', dataset_id=dataset.id, query=query, top_k=self.top_k diff --git a/api/tasks/clean_dataset_task.py b/api/tasks/clean_dataset_task.py index 16e4affc91..37e109c847 100644 --- a/api/tasks/clean_dataset_task.py +++ b/api/tasks/clean_dataset_task.py @@ -40,7 +40,6 @@ def clean_dataset_task(dataset_id: str, tenant_id: str, indexing_technique: str, indexing_technique=indexing_technique, index_struct=index_struct, collection_binding_id=collection_binding_id, - doc_form=doc_form ) documents = db.session.query(Document).filter(Document.dataset_id == dataset_id).all() segments = db.session.query(DocumentSegment).filter(DocumentSegment.dataset_id == dataset_id).all()