From a5b80c9d1fd07a30b9800a38b413b918cc81105c Mon Sep 17 00:00:00 2001 From: Jyong <76649700+JohnJyong@users.noreply.github.com> Date: Wed, 22 Nov 2023 18:31:29 +0800 Subject: [PATCH] Fix/multi thread parameter (#1604) --- api/core/tool/dataset_multi_retriever_tool.py | 4 ++-- api/core/tool/dataset_retriever_tool.py | 4 ++-- api/services/hit_testing_service.py | 4 ++-- api/services/retrieval_service.py | 11 +++++++++-- 4 files changed, 15 insertions(+), 8 deletions(-) diff --git a/api/core/tool/dataset_multi_retriever_tool.py b/api/core/tool/dataset_multi_retriever_tool.py index 07174b1d71..5cf120b63b 100644 --- a/api/core/tool/dataset_multi_retriever_tool.py +++ b/api/core/tool/dataset_multi_retriever_tool.py @@ -192,7 +192,7 @@ class DatasetMultiRetrieverTool(BaseTool): 'search_method'] == 'hybrid_search': embedding_thread = threading.Thread(target=RetrievalService.embedding_search, kwargs={ 'flask_app': current_app._get_current_object(), - 'dataset': dataset, + 'dataset_id': str(dataset.id), 'query': query, 'top_k': self.top_k, 'score_threshold': self.score_threshold, @@ -210,7 +210,7 @@ class DatasetMultiRetrieverTool(BaseTool): full_text_index_thread = threading.Thread(target=RetrievalService.full_text_index_search, kwargs={ 'flask_app': current_app._get_current_object(), - 'dataset': dataset, + 'dataset_id': str(dataset.id), 'query': query, 'search_method': 'hybrid_search', 'embeddings': embeddings, diff --git a/api/core/tool/dataset_retriever_tool.py b/api/core/tool/dataset_retriever_tool.py index cc8b8e1386..822a6562be 100644 --- a/api/core/tool/dataset_retriever_tool.py +++ b/api/core/tool/dataset_retriever_tool.py @@ -106,7 +106,7 @@ class DatasetRetrieverTool(BaseTool): if retrieval_model['search_method'] == 'semantic_search' or retrieval_model['search_method'] == 'hybrid_search': embedding_thread = threading.Thread(target=RetrievalService.embedding_search, kwargs={ 'flask_app': current_app._get_current_object(), - 'dataset': dataset, + 'dataset_id': str(dataset.id), 'query': query, 'top_k': self.top_k, 'score_threshold': retrieval_model['score_threshold'] if retrieval_model[ @@ -124,7 +124,7 @@ class DatasetRetrieverTool(BaseTool): if retrieval_model['search_method'] == 'full_text_search' or retrieval_model['search_method'] == 'hybrid_search': full_text_index_thread = threading.Thread(target=RetrievalService.full_text_index_search, kwargs={ 'flask_app': current_app._get_current_object(), - 'dataset': dataset, + 'dataset_id': str(dataset.id), 'query': query, 'search_method': retrieval_model['search_method'], 'embeddings': embeddings, diff --git a/api/services/hit_testing_service.py b/api/services/hit_testing_service.py index d9725a66d8..831a37d670 100644 --- a/api/services/hit_testing_service.py +++ b/api/services/hit_testing_service.py @@ -61,7 +61,7 @@ class HitTestingService: if retrieval_model['search_method'] == 'semantic_search' or retrieval_model['search_method'] == 'hybrid_search': embedding_thread = threading.Thread(target=RetrievalService.embedding_search, kwargs={ 'flask_app': current_app._get_current_object(), - 'dataset': dataset, + 'dataset_id': str(dataset.id), 'query': query, 'top_k': retrieval_model['top_k'], 'score_threshold': retrieval_model['score_threshold'] if retrieval_model['score_threshold_enable'] else None, @@ -77,7 +77,7 @@ class HitTestingService: if retrieval_model['search_method'] == 'full_text_search' or retrieval_model['search_method'] == 'hybrid_search': full_text_index_thread = threading.Thread(target=RetrievalService.full_text_index_search, kwargs={ 'flask_app': current_app._get_current_object(), - 'dataset': dataset, + 'dataset_id': str(dataset.id), 'query': query, 'search_method': retrieval_model['search_method'], 'embeddings': embeddings, diff --git a/api/services/retrieval_service.py b/api/services/retrieval_service.py index 3e6b93f862..f12533f2b0 100644 --- a/api/services/retrieval_service.py +++ b/api/services/retrieval_service.py @@ -4,6 +4,7 @@ from flask import current_app, Flask from langchain.embeddings.base import Embeddings from core.index.vector_index.vector_index import VectorIndex from core.model_providers.model_factory import ModelFactory +from extensions.ext_database import db from models.dataset import Dataset default_retrieval_model = { @@ -21,10 +22,13 @@ default_retrieval_model = { class RetrievalService: @classmethod - def embedding_search(cls, flask_app: Flask, dataset: Dataset, query: str, + def embedding_search(cls, flask_app: Flask, dataset_id: str, query: str, top_k: int, score_threshold: Optional[float], reranking_model: Optional[dict], all_documents: list, search_method: str, embeddings: Embeddings): with flask_app.app_context(): + dataset = db.session.query(Dataset).filter( + Dataset.id == dataset_id + ).first() vector_index = VectorIndex( dataset=dataset, @@ -56,10 +60,13 @@ class RetrievalService: all_documents.extend(documents) @classmethod - def full_text_index_search(cls, flask_app: Flask, dataset: Dataset, query: str, + def full_text_index_search(cls, flask_app: Flask, dataset_id: str, query: str, top_k: int, score_threshold: Optional[float], reranking_model: Optional[dict], all_documents: list, search_method: str, embeddings: Embeddings): with flask_app.app_context(): + dataset = db.session.query(Dataset).filter( + Dataset.id == dataset_id + ).first() vector_index = VectorIndex( dataset=dataset,