Fix/new RAG bugs (#2547)

Co-authored-by: jyong <jyong@dify.ai>
This commit is contained in:
Jyong 2024-02-23 16:54:15 +08:00 committed by GitHub
parent 49da8a23a8
commit 4be3087642
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 9 additions and 6 deletions

View File

@ -365,7 +365,7 @@ class IndexingRunner:
notion_info={ notion_info={
"notion_workspace_id": data_source_info['notion_workspace_id'], "notion_workspace_id": data_source_info['notion_workspace_id'],
"notion_obj_id": data_source_info['notion_page_id'], "notion_obj_id": data_source_info['notion_page_id'],
"notion_page_type": data_source_info['notion_page_type'], "notion_page_type": data_source_info['type'],
"document": dataset_document "document": dataset_document
}, },
document_model=dataset_document.doc_form document_model=dataset_document.doc_form

View File

@ -2,7 +2,6 @@ import threading
from typing import Optional from typing import Optional
from flask import Flask, current_app from flask import Flask, current_app
from flask_login import current_user
from core.rag.data_post_processor.data_post_processor import DataPostProcessor from core.rag.data_post_processor.data_post_processor import DataPostProcessor
from core.rag.datasource.keyword.keyword_factory import Keyword from core.rag.datasource.keyword.keyword_factory import Keyword
@ -27,6 +26,11 @@ class RetrievalService:
@classmethod @classmethod
def retrieve(cls, retrival_method: str, dataset_id: str, query: str, def retrieve(cls, retrival_method: str, dataset_id: str, query: str,
top_k: int, score_threshold: Optional[float] = .0, reranking_model: Optional[dict] = None): top_k: int, score_threshold: Optional[float] = .0, reranking_model: Optional[dict] = None):
dataset = db.session.query(Dataset).filter(
Dataset.id == dataset_id
).first()
if not dataset or dataset.available_document_count == 0 or dataset.available_segment_count == 0:
return []
all_documents = [] all_documents = []
threads = [] threads = []
# retrieval_model source with keyword # retrieval_model source with keyword
@ -73,7 +77,7 @@ class RetrievalService:
thread.join() thread.join()
if retrival_method == 'hybrid_search': if retrival_method == 'hybrid_search':
data_post_processor = DataPostProcessor(str(current_user.current_tenant_id), reranking_model, False) data_post_processor = DataPostProcessor(str(dataset.tenant_id), reranking_model, False)
all_documents = data_post_processor.invoke( all_documents = data_post_processor.invoke(
query=query, query=query,
documents=all_documents, documents=all_documents,

View File

@ -171,7 +171,7 @@ class DatasetMultiRetrieverTool(BaseTool):
if dataset.indexing_technique == "economy": if dataset.indexing_technique == "economy":
# use keyword table query # use keyword table query
documents = RetrievalService.retrieve(retrival_method=retrieval_model['search_method'], documents = RetrievalService.retrieve(retrival_method='keyword_search',
dataset_id=dataset.id, dataset_id=dataset.id,
query=query, query=query,
top_k=self.top_k top_k=self.top_k

View File

@ -69,7 +69,7 @@ class DatasetRetrieverTool(BaseTool):
retrieval_model = dataset.retrieval_model if dataset.retrieval_model else default_retrieval_model retrieval_model = dataset.retrieval_model if dataset.retrieval_model else default_retrieval_model
if dataset.indexing_technique == "economy": if dataset.indexing_technique == "economy":
# use keyword table query # use keyword table query
documents = RetrievalService.retrieve(retrival_method=retrieval_model['search_method'], documents = RetrievalService.retrieve(retrival_method='keyword_search',
dataset_id=dataset.id, dataset_id=dataset.id,
query=query, query=query,
top_k=self.top_k top_k=self.top_k

View File

@ -40,7 +40,6 @@ def clean_dataset_task(dataset_id: str, tenant_id: str, indexing_technique: str,
indexing_technique=indexing_technique, indexing_technique=indexing_technique,
index_struct=index_struct, index_struct=index_struct,
collection_binding_id=collection_binding_id, collection_binding_id=collection_binding_id,
doc_form=doc_form
) )
documents = db.session.query(Document).filter(Document.dataset_id == dataset_id).all() documents = db.session.query(Document).filter(Document.dataset_id == dataset_id).all()
segments = db.session.query(DocumentSegment).filter(DocumentSegment.dataset_id == dataset_id).all() segments = db.session.query(DocumentSegment).filter(DocumentSegment.dataset_id == dataset_id).all()