mirror of
https://github.com/langgenius/dify.git
synced 2024-11-16 11:42:29 +08:00
Fix/multi thread parameter (#1604)
This commit is contained in:
parent
f704094a5f
commit
a5b80c9d1f
|
@ -192,7 +192,7 @@ class DatasetMultiRetrieverTool(BaseTool):
|
||||||
'search_method'] == 'hybrid_search':
|
'search_method'] == 'hybrid_search':
|
||||||
embedding_thread = threading.Thread(target=RetrievalService.embedding_search, kwargs={
|
embedding_thread = threading.Thread(target=RetrievalService.embedding_search, kwargs={
|
||||||
'flask_app': current_app._get_current_object(),
|
'flask_app': current_app._get_current_object(),
|
||||||
'dataset': dataset,
|
'dataset_id': str(dataset.id),
|
||||||
'query': query,
|
'query': query,
|
||||||
'top_k': self.top_k,
|
'top_k': self.top_k,
|
||||||
'score_threshold': self.score_threshold,
|
'score_threshold': self.score_threshold,
|
||||||
|
@ -210,7 +210,7 @@ class DatasetMultiRetrieverTool(BaseTool):
|
||||||
full_text_index_thread = threading.Thread(target=RetrievalService.full_text_index_search,
|
full_text_index_thread = threading.Thread(target=RetrievalService.full_text_index_search,
|
||||||
kwargs={
|
kwargs={
|
||||||
'flask_app': current_app._get_current_object(),
|
'flask_app': current_app._get_current_object(),
|
||||||
'dataset': dataset,
|
'dataset_id': str(dataset.id),
|
||||||
'query': query,
|
'query': query,
|
||||||
'search_method': 'hybrid_search',
|
'search_method': 'hybrid_search',
|
||||||
'embeddings': embeddings,
|
'embeddings': embeddings,
|
||||||
|
|
|
@ -106,7 +106,7 @@ class DatasetRetrieverTool(BaseTool):
|
||||||
if retrieval_model['search_method'] == 'semantic_search' or retrieval_model['search_method'] == 'hybrid_search':
|
if retrieval_model['search_method'] == 'semantic_search' or retrieval_model['search_method'] == 'hybrid_search':
|
||||||
embedding_thread = threading.Thread(target=RetrievalService.embedding_search, kwargs={
|
embedding_thread = threading.Thread(target=RetrievalService.embedding_search, kwargs={
|
||||||
'flask_app': current_app._get_current_object(),
|
'flask_app': current_app._get_current_object(),
|
||||||
'dataset': dataset,
|
'dataset_id': str(dataset.id),
|
||||||
'query': query,
|
'query': query,
|
||||||
'top_k': self.top_k,
|
'top_k': self.top_k,
|
||||||
'score_threshold': retrieval_model['score_threshold'] if retrieval_model[
|
'score_threshold': retrieval_model['score_threshold'] if retrieval_model[
|
||||||
|
@ -124,7 +124,7 @@ class DatasetRetrieverTool(BaseTool):
|
||||||
if retrieval_model['search_method'] == 'full_text_search' or retrieval_model['search_method'] == 'hybrid_search':
|
if retrieval_model['search_method'] == 'full_text_search' or retrieval_model['search_method'] == 'hybrid_search':
|
||||||
full_text_index_thread = threading.Thread(target=RetrievalService.full_text_index_search, kwargs={
|
full_text_index_thread = threading.Thread(target=RetrievalService.full_text_index_search, kwargs={
|
||||||
'flask_app': current_app._get_current_object(),
|
'flask_app': current_app._get_current_object(),
|
||||||
'dataset': dataset,
|
'dataset_id': str(dataset.id),
|
||||||
'query': query,
|
'query': query,
|
||||||
'search_method': retrieval_model['search_method'],
|
'search_method': retrieval_model['search_method'],
|
||||||
'embeddings': embeddings,
|
'embeddings': embeddings,
|
||||||
|
|
|
@ -61,7 +61,7 @@ class HitTestingService:
|
||||||
if retrieval_model['search_method'] == 'semantic_search' or retrieval_model['search_method'] == 'hybrid_search':
|
if retrieval_model['search_method'] == 'semantic_search' or retrieval_model['search_method'] == 'hybrid_search':
|
||||||
embedding_thread = threading.Thread(target=RetrievalService.embedding_search, kwargs={
|
embedding_thread = threading.Thread(target=RetrievalService.embedding_search, kwargs={
|
||||||
'flask_app': current_app._get_current_object(),
|
'flask_app': current_app._get_current_object(),
|
||||||
'dataset': dataset,
|
'dataset_id': str(dataset.id),
|
||||||
'query': query,
|
'query': query,
|
||||||
'top_k': retrieval_model['top_k'],
|
'top_k': retrieval_model['top_k'],
|
||||||
'score_threshold': retrieval_model['score_threshold'] if retrieval_model['score_threshold_enable'] else None,
|
'score_threshold': retrieval_model['score_threshold'] if retrieval_model['score_threshold_enable'] else None,
|
||||||
|
@ -77,7 +77,7 @@ class HitTestingService:
|
||||||
if retrieval_model['search_method'] == 'full_text_search' or retrieval_model['search_method'] == 'hybrid_search':
|
if retrieval_model['search_method'] == 'full_text_search' or retrieval_model['search_method'] == 'hybrid_search':
|
||||||
full_text_index_thread = threading.Thread(target=RetrievalService.full_text_index_search, kwargs={
|
full_text_index_thread = threading.Thread(target=RetrievalService.full_text_index_search, kwargs={
|
||||||
'flask_app': current_app._get_current_object(),
|
'flask_app': current_app._get_current_object(),
|
||||||
'dataset': dataset,
|
'dataset_id': str(dataset.id),
|
||||||
'query': query,
|
'query': query,
|
||||||
'search_method': retrieval_model['search_method'],
|
'search_method': retrieval_model['search_method'],
|
||||||
'embeddings': embeddings,
|
'embeddings': embeddings,
|
||||||
|
|
|
@ -4,6 +4,7 @@ from flask import current_app, Flask
|
||||||
from langchain.embeddings.base import Embeddings
|
from langchain.embeddings.base import Embeddings
|
||||||
from core.index.vector_index.vector_index import VectorIndex
|
from core.index.vector_index.vector_index import VectorIndex
|
||||||
from core.model_providers.model_factory import ModelFactory
|
from core.model_providers.model_factory import ModelFactory
|
||||||
|
from extensions.ext_database import db
|
||||||
from models.dataset import Dataset
|
from models.dataset import Dataset
|
||||||
|
|
||||||
default_retrieval_model = {
|
default_retrieval_model = {
|
||||||
|
@ -21,10 +22,13 @@ default_retrieval_model = {
|
||||||
class RetrievalService:
|
class RetrievalService:
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def embedding_search(cls, flask_app: Flask, dataset: Dataset, query: str,
|
def embedding_search(cls, flask_app: Flask, dataset_id: str, query: str,
|
||||||
top_k: int, score_threshold: Optional[float], reranking_model: Optional[dict],
|
top_k: int, score_threshold: Optional[float], reranking_model: Optional[dict],
|
||||||
all_documents: list, search_method: str, embeddings: Embeddings):
|
all_documents: list, search_method: str, embeddings: Embeddings):
|
||||||
with flask_app.app_context():
|
with flask_app.app_context():
|
||||||
|
dataset = db.session.query(Dataset).filter(
|
||||||
|
Dataset.id == dataset_id
|
||||||
|
).first()
|
||||||
|
|
||||||
vector_index = VectorIndex(
|
vector_index = VectorIndex(
|
||||||
dataset=dataset,
|
dataset=dataset,
|
||||||
|
@ -56,10 +60,13 @@ class RetrievalService:
|
||||||
all_documents.extend(documents)
|
all_documents.extend(documents)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def full_text_index_search(cls, flask_app: Flask, dataset: Dataset, query: str,
|
def full_text_index_search(cls, flask_app: Flask, dataset_id: str, query: str,
|
||||||
top_k: int, score_threshold: Optional[float], reranking_model: Optional[dict],
|
top_k: int, score_threshold: Optional[float], reranking_model: Optional[dict],
|
||||||
all_documents: list, search_method: str, embeddings: Embeddings):
|
all_documents: list, search_method: str, embeddings: Embeddings):
|
||||||
with flask_app.app_context():
|
with flask_app.app_context():
|
||||||
|
dataset = db.session.query(Dataset).filter(
|
||||||
|
Dataset.id == dataset_id
|
||||||
|
).first()
|
||||||
|
|
||||||
vector_index = VectorIndex(
|
vector_index = VectorIndex(
|
||||||
dataset=dataset,
|
dataset=dataset,
|
||||||
|
|
Loading…
Reference in New Issue
Block a user