Fix/multi thread parameter (#1604)

This commit is contained in:
Jyong 2023-11-22 18:31:29 +08:00 committed by GitHub
parent f704094a5f
commit a5b80c9d1f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 15 additions and 8 deletions

View File

@ -192,7 +192,7 @@ class DatasetMultiRetrieverTool(BaseTool):
'search_method'] == 'hybrid_search': 'search_method'] == 'hybrid_search':
embedding_thread = threading.Thread(target=RetrievalService.embedding_search, kwargs={ embedding_thread = threading.Thread(target=RetrievalService.embedding_search, kwargs={
'flask_app': current_app._get_current_object(), 'flask_app': current_app._get_current_object(),
'dataset': dataset, 'dataset_id': str(dataset.id),
'query': query, 'query': query,
'top_k': self.top_k, 'top_k': self.top_k,
'score_threshold': self.score_threshold, 'score_threshold': self.score_threshold,
@ -210,7 +210,7 @@ class DatasetMultiRetrieverTool(BaseTool):
full_text_index_thread = threading.Thread(target=RetrievalService.full_text_index_search, full_text_index_thread = threading.Thread(target=RetrievalService.full_text_index_search,
kwargs={ kwargs={
'flask_app': current_app._get_current_object(), 'flask_app': current_app._get_current_object(),
'dataset': dataset, 'dataset_id': str(dataset.id),
'query': query, 'query': query,
'search_method': 'hybrid_search', 'search_method': 'hybrid_search',
'embeddings': embeddings, 'embeddings': embeddings,

View File

@ -106,7 +106,7 @@ class DatasetRetrieverTool(BaseTool):
if retrieval_model['search_method'] == 'semantic_search' or retrieval_model['search_method'] == 'hybrid_search': if retrieval_model['search_method'] == 'semantic_search' or retrieval_model['search_method'] == 'hybrid_search':
embedding_thread = threading.Thread(target=RetrievalService.embedding_search, kwargs={ embedding_thread = threading.Thread(target=RetrievalService.embedding_search, kwargs={
'flask_app': current_app._get_current_object(), 'flask_app': current_app._get_current_object(),
'dataset': dataset, 'dataset_id': str(dataset.id),
'query': query, 'query': query,
'top_k': self.top_k, 'top_k': self.top_k,
'score_threshold': retrieval_model['score_threshold'] if retrieval_model[ 'score_threshold': retrieval_model['score_threshold'] if retrieval_model[
@ -124,7 +124,7 @@ class DatasetRetrieverTool(BaseTool):
if retrieval_model['search_method'] == 'full_text_search' or retrieval_model['search_method'] == 'hybrid_search': if retrieval_model['search_method'] == 'full_text_search' or retrieval_model['search_method'] == 'hybrid_search':
full_text_index_thread = threading.Thread(target=RetrievalService.full_text_index_search, kwargs={ full_text_index_thread = threading.Thread(target=RetrievalService.full_text_index_search, kwargs={
'flask_app': current_app._get_current_object(), 'flask_app': current_app._get_current_object(),
'dataset': dataset, 'dataset_id': str(dataset.id),
'query': query, 'query': query,
'search_method': retrieval_model['search_method'], 'search_method': retrieval_model['search_method'],
'embeddings': embeddings, 'embeddings': embeddings,

View File

@ -61,7 +61,7 @@ class HitTestingService:
if retrieval_model['search_method'] == 'semantic_search' or retrieval_model['search_method'] == 'hybrid_search': if retrieval_model['search_method'] == 'semantic_search' or retrieval_model['search_method'] == 'hybrid_search':
embedding_thread = threading.Thread(target=RetrievalService.embedding_search, kwargs={ embedding_thread = threading.Thread(target=RetrievalService.embedding_search, kwargs={
'flask_app': current_app._get_current_object(), 'flask_app': current_app._get_current_object(),
'dataset': dataset, 'dataset_id': str(dataset.id),
'query': query, 'query': query,
'top_k': retrieval_model['top_k'], 'top_k': retrieval_model['top_k'],
'score_threshold': retrieval_model['score_threshold'] if retrieval_model['score_threshold_enable'] else None, 'score_threshold': retrieval_model['score_threshold'] if retrieval_model['score_threshold_enable'] else None,
@ -77,7 +77,7 @@ class HitTestingService:
if retrieval_model['search_method'] == 'full_text_search' or retrieval_model['search_method'] == 'hybrid_search': if retrieval_model['search_method'] == 'full_text_search' or retrieval_model['search_method'] == 'hybrid_search':
full_text_index_thread = threading.Thread(target=RetrievalService.full_text_index_search, kwargs={ full_text_index_thread = threading.Thread(target=RetrievalService.full_text_index_search, kwargs={
'flask_app': current_app._get_current_object(), 'flask_app': current_app._get_current_object(),
'dataset': dataset, 'dataset_id': str(dataset.id),
'query': query, 'query': query,
'search_method': retrieval_model['search_method'], 'search_method': retrieval_model['search_method'],
'embeddings': embeddings, 'embeddings': embeddings,

View File

@ -4,6 +4,7 @@ from flask import current_app, Flask
from langchain.embeddings.base import Embeddings from langchain.embeddings.base import Embeddings
from core.index.vector_index.vector_index import VectorIndex from core.index.vector_index.vector_index import VectorIndex
from core.model_providers.model_factory import ModelFactory from core.model_providers.model_factory import ModelFactory
from extensions.ext_database import db
from models.dataset import Dataset from models.dataset import Dataset
default_retrieval_model = { default_retrieval_model = {
@ -21,10 +22,13 @@ default_retrieval_model = {
class RetrievalService: class RetrievalService:
@classmethod @classmethod
def embedding_search(cls, flask_app: Flask, dataset: Dataset, query: str, def embedding_search(cls, flask_app: Flask, dataset_id: str, query: str,
top_k: int, score_threshold: Optional[float], reranking_model: Optional[dict], top_k: int, score_threshold: Optional[float], reranking_model: Optional[dict],
all_documents: list, search_method: str, embeddings: Embeddings): all_documents: list, search_method: str, embeddings: Embeddings):
with flask_app.app_context(): with flask_app.app_context():
dataset = db.session.query(Dataset).filter(
Dataset.id == dataset_id
).first()
vector_index = VectorIndex( vector_index = VectorIndex(
dataset=dataset, dataset=dataset,
@ -56,10 +60,13 @@ class RetrievalService:
all_documents.extend(documents) all_documents.extend(documents)
@classmethod @classmethod
def full_text_index_search(cls, flask_app: Flask, dataset: Dataset, query: str, def full_text_index_search(cls, flask_app: Flask, dataset_id: str, query: str,
top_k: int, score_threshold: Optional[float], reranking_model: Optional[dict], top_k: int, score_threshold: Optional[float], reranking_model: Optional[dict],
all_documents: list, search_method: str, embeddings: Embeddings): all_documents: list, search_method: str, embeddings: Embeddings):
with flask_app.app_context(): with flask_app.app_context():
dataset = db.session.query(Dataset).filter(
Dataset.id == dataset_id
).first()
vector_index = VectorIndex( vector_index = VectorIndex(
dataset=dataset, dataset=dataset,