Fix/agent external knowledge retrieval (#9241)

This commit is contained in:
Jyong 2024-10-11 19:21:03 +08:00 committed by GitHub
parent 44f6a536d2
commit 42b02b3a5f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 225 additions and 92 deletions

View File

@ -191,6 +191,22 @@ class CeleryConfig(DatabaseConfig):
return self.CELERY_BROKER_URL.startswith("rediss://") if self.CELERY_BROKER_URL else False
class InternalTestConfig(BaseSettings):
"""
Configuration settings for Internal Test
"""
AWS_SECRET_ACCESS_KEY: Optional[str] = Field(
description="Internal test AWS secret access key",
default=None,
)
AWS_ACCESS_KEY_ID: Optional[str] = Field(
description="Internal test AWS access key ID",
default=None,
)
class MiddlewareConfig(
# place the configs in alphabet order
CeleryConfig,
@ -224,5 +240,6 @@ class MiddlewareConfig(
TiDBVectorConfig,
WeaviateConfig,
ElasticsearchConfig,
InternalTestConfig,
):
pass

View File

@ -13,6 +13,7 @@ from libs.login import login_required
from services.dataset_service import DatasetService
from services.external_knowledge_service import ExternalDatasetService
from services.hit_testing_service import HitTestingService
from services.knowledge_service import ExternalDatasetTestService
def _validate_name(name):
@ -232,8 +233,31 @@ class ExternalKnowledgeHitTestingApi(Resource):
raise InternalServerError(str(e))
class BedrockRetrievalApi(Resource):
# this api is only for internal testing
def post(self):
parser = reqparse.RequestParser()
parser.add_argument("retrieval_setting", nullable=False, required=True, type=dict, location="json")
parser.add_argument(
"query",
nullable=False,
required=True,
type=str,
)
parser.add_argument("knowledge_id", nullable=False, required=True, type=str)
args = parser.parse_args()
# Call the knowledge retrieval service
result = ExternalDatasetTestService.knowledge_retrieval(
args["retrieval_setting"], args["query"], args["knowledge_id"]
)
return result, 200
api.add_resource(ExternalKnowledgeHitTestingApi, "/datasets/<uuid:dataset_id>/external-hit-testing")
api.add_resource(ExternalDatasetCreateApi, "/datasets/external")
api.add_resource(ExternalApiTemplateListApi, "/datasets/external-knowledge-api")
api.add_resource(ExternalApiTemplateApi, "/datasets/external-knowledge-api/<uuid:external_knowledge_api_id>")
api.add_resource(ExternalApiUseCheckApi, "/datasets/external-knowledge-api/<uuid:external_knowledge_api_id>/use-check")
# this api is only for internal test
api.add_resource(BedrockRetrievalApi, "/test/retrieval")

View File

@ -539,7 +539,7 @@ class DatasetRetrieval:
continue
# pass if dataset is not available
if dataset and dataset.available_document_count == 0:
if dataset and dataset.provider != "external" and dataset.available_document_count == 0:
continue
available_datasets.append(dataset)

View File

@ -1,10 +1,12 @@
from pydantic import BaseModel, Field
from core.rag.datasource.retrieval_service import RetrievalService
from core.rag.models.document import Document as RetrievalDocument
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.tools.tool.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool
from extensions.ext_database import db
from models.dataset import Dataset, Document, DocumentSegment
from services.external_knowledge_service import ExternalDatasetService
default_retrieval_model = {
"search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
@ -53,97 +55,137 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
for hit_callback in self.hit_callbacks:
hit_callback.on_query(query, dataset.id)
# get retrieval model , if the model is not setting , using default
retrieval_model = dataset.retrieval_model or default_retrieval_model
if dataset.indexing_technique == "economy":
# use keyword table query
documents = RetrievalService.retrieve(
retrieval_method="keyword_search", dataset_id=dataset.id, query=query, top_k=self.top_k
if dataset.provider == "external":
results = []
external_documents = ExternalDatasetService.fetch_external_knowledge_retrieval(
tenant_id=dataset.tenant_id,
dataset_id=dataset.id,
query=query,
external_retrieval_parameters=dataset.retrieval_model,
)
return str("\n".join([document.page_content for document in documents]))
else:
if self.top_k > 0:
# retrieval source
documents = RetrievalService.retrieve(
retrieval_method=retrieval_model.get("search_method", "semantic_search"),
dataset_id=dataset.id,
query=query,
top_k=self.top_k,
score_threshold=retrieval_model.get("score_threshold", 0.0)
if retrieval_model["score_threshold_enabled"]
else 0.0,
reranking_model=retrieval_model.get("reranking_model", None)
if retrieval_model["reranking_enable"]
else None,
reranking_mode=retrieval_model.get("reranking_mode") or "reranking_model",
weights=retrieval_model.get("weights", None),
for external_document in external_documents:
document = RetrievalDocument(
page_content=external_document.get("content"),
metadata=external_document.get("metadata"),
provider="external",
)
else:
documents = []
document.metadata["score"] = external_document.get("score")
document.metadata["title"] = external_document.get("title")
document.metadata["dataset_id"] = dataset.id
document.metadata["dataset_name"] = dataset.name
results.append(document)
# deal with external documents
context_list = []
for position, item in enumerate(results, start=1):
source = {
"position": position,
"dataset_id": item.metadata.get("dataset_id"),
"dataset_name": item.metadata.get("dataset_name"),
"document_name": item.metadata.get("title"),
"data_source_type": "external",
"retriever_from": self.retriever_from,
"score": item.metadata.get("score"),
"title": item.metadata.get("title"),
"content": item.page_content,
}
context_list.append(source)
for hit_callback in self.hit_callbacks:
hit_callback.on_tool_end(documents)
document_score_list = {}
if dataset.indexing_technique != "economy":
for item in documents:
if item.metadata.get("score"):
document_score_list[item.metadata["doc_id"]] = item.metadata["score"]
document_context_list = []
index_node_ids = [document.metadata["doc_id"] for document in documents]
segments = DocumentSegment.query.filter(
DocumentSegment.dataset_id == self.dataset_id,
DocumentSegment.completed_at.isnot(None),
DocumentSegment.status == "completed",
DocumentSegment.enabled == True,
DocumentSegment.index_node_id.in_(index_node_ids),
).all()
hit_callback.return_retriever_resource_info(context_list)
if segments:
index_node_id_to_position = {id: position for position, id in enumerate(index_node_ids)}
sorted_segments = sorted(
segments, key=lambda segment: index_node_id_to_position.get(segment.index_node_id, float("inf"))
return str("\n".join([item.page_content for item in results]))
else:
# get retrieval model , if the model is not setting , using default
retrieval_model = dataset.retrieval_model or default_retrieval_model
if dataset.indexing_technique == "economy":
# use keyword table query
documents = RetrievalService.retrieve(
retrieval_method="keyword_search", dataset_id=dataset.id, query=query, top_k=self.top_k
)
for segment in sorted_segments:
if segment.answer:
document_context_list.append(f"question:{segment.get_sign_content()} answer:{segment.answer}")
else:
document_context_list.append(segment.get_sign_content())
if self.return_resource:
context_list = []
resource_number = 1
return str("\n".join([document.page_content for document in documents]))
else:
if self.top_k > 0:
# retrieval source
documents = RetrievalService.retrieve(
retrieval_method=retrieval_model.get("search_method", "semantic_search"),
dataset_id=dataset.id,
query=query,
top_k=self.top_k,
score_threshold=retrieval_model.get("score_threshold", 0.0)
if retrieval_model["score_threshold_enabled"]
else 0.0,
reranking_model=retrieval_model.get("reranking_model", None)
if retrieval_model["reranking_enable"]
else None,
reranking_mode=retrieval_model.get("reranking_mode") or "reranking_model",
weights=retrieval_model.get("weights", None),
)
else:
documents = []
for hit_callback in self.hit_callbacks:
hit_callback.on_tool_end(documents)
document_score_list = {}
if dataset.indexing_technique != "economy":
for item in documents:
if item.metadata.get("score"):
document_score_list[item.metadata["doc_id"]] = item.metadata["score"]
document_context_list = []
index_node_ids = [document.metadata["doc_id"] for document in documents]
segments = DocumentSegment.query.filter(
DocumentSegment.dataset_id == self.dataset_id,
DocumentSegment.completed_at.isnot(None),
DocumentSegment.status == "completed",
DocumentSegment.enabled == True,
DocumentSegment.index_node_id.in_(index_node_ids),
).all()
if segments:
index_node_id_to_position = {id: position for position, id in enumerate(index_node_ids)}
sorted_segments = sorted(
segments, key=lambda segment: index_node_id_to_position.get(segment.index_node_id, float("inf"))
)
for segment in sorted_segments:
context = {}
document = Document.query.filter(
Document.id == segment.document_id,
Document.enabled == True,
Document.archived == False,
).first()
if dataset and document:
source = {
"position": resource_number,
"dataset_id": dataset.id,
"dataset_name": dataset.name,
"document_id": document.id,
"document_name": document.name,
"data_source_type": document.data_source_type,
"segment_id": segment.id,
"retriever_from": self.retriever_from,
"score": document_score_list.get(segment.index_node_id, None),
}
if self.retriever_from == "dev":
source["hit_count"] = segment.hit_count
source["word_count"] = segment.word_count
source["segment_position"] = segment.position
source["index_node_hash"] = segment.index_node_hash
if segment.answer:
source["content"] = f"question:{segment.content} \nanswer:{segment.answer}"
else:
source["content"] = segment.content
context_list.append(source)
resource_number += 1
if segment.answer:
document_context_list.append(
f"question:{segment.get_sign_content()} answer:{segment.answer}"
)
else:
document_context_list.append(segment.get_sign_content())
if self.return_resource:
context_list = []
resource_number = 1
for segment in sorted_segments:
context = {}
document = Document.query.filter(
Document.id == segment.document_id,
Document.enabled == True,
Document.archived == False,
).first()
if dataset and document:
source = {
"position": resource_number,
"dataset_id": dataset.id,
"dataset_name": dataset.name,
"document_id": document.id,
"document_name": document.name,
"data_source_type": document.data_source_type,
"segment_id": segment.id,
"retriever_from": self.retriever_from,
"score": document_score_list.get(segment.index_node_id, None),
}
if self.retriever_from == "dev":
source["hit_count"] = segment.hit_count
source["word_count"] = segment.word_count
source["segment_position"] = segment.position
source["index_node_hash"] = segment.index_node_hash
if segment.answer:
source["content"] = f"question:{segment.content} \nanswer:{segment.answer}"
else:
source["content"] = segment.content
context_list.append(source)
resource_number += 1
for hit_callback in self.hit_callbacks:
hit_callback.return_retriever_resource_info(context_list)
for hit_callback in self.hit_callbacks:
hit_callback.return_retriever_resource_info(context_list)
return str("\n".join(document_context_list))
return str("\n".join(document_context_list))

View File

@ -79,8 +79,9 @@ class KnowledgeRetrievalNode(BaseNode):
results = (
db.session.query(Dataset)
.join(subquery, Dataset.id == subquery.c.dataset_id)
.outerjoin(subquery, Dataset.id == subquery.c.dataset_id)
.filter(Dataset.tenant_id == self.tenant_id, Dataset.id.in_(dataset_ids))
.filter((subquery.c.available_document_count > 0) | (Dataset.provider == "external"))
.all()
)
@ -121,10 +122,13 @@ class KnowledgeRetrievalNode(BaseNode):
)
elif node_data.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE.value:
if node_data.multiple_retrieval_config.reranking_mode == "reranking_model":
reranking_model = {
"reranking_provider_name": node_data.multiple_retrieval_config.reranking_model.provider,
"reranking_model_name": node_data.multiple_retrieval_config.reranking_model.model,
}
if node_data.multiple_retrieval_config.reranking_model:
reranking_model = {
"reranking_provider_name": node_data.multiple_retrieval_config.reranking_model.provider,
"reranking_model_name": node_data.multiple_retrieval_config.reranking_model.model,
}
else:
reranking_model = None
weights = None
elif node_data.multiple_retrieval_config.reranking_mode == "weighted_score":
reranking_model = None

View File

@ -234,6 +234,7 @@ class DatasetService:
dataset.name = data.get("name", dataset.name)
dataset.description = data.get("description", "")
external_knowledge_id = data.get("external_knowledge_id", None)
dataset.permission = data.get("permission")
db.session.add(dataset)
if not external_knowledge_id:
raise ValueError("External knowledge id is required.")

View File

@ -0,0 +1,45 @@
import boto3
from configs import dify_config
class ExternalDatasetTestService:
# this service is only for internal testing
@staticmethod
def knowledge_retrieval(retrieval_setting: dict, query: str, knowledge_id: str):
# get bedrock client
client = boto3.client(
"bedrock-agent-runtime",
aws_secret_access_key=dify_config.AWS_SECRET_ACCESS_KEY,
aws_access_key_id=dify_config.AWS_ACCESS_KEY_ID,
# example: us-east-1
region_name="us-east-1",
)
# fetch external knowledge retrieval
response = client.retrieve(
knowledgeBaseId=knowledge_id,
retrievalConfiguration={
"vectorSearchConfiguration": {
"numberOfResults": retrieval_setting.get("top_k"),
"overrideSearchType": "HYBRID",
}
},
retrievalQuery={"text": query},
)
# parse response
results = []
if response.get("ResponseMetadata") and response.get("ResponseMetadata").get("HTTPStatusCode") == 200:
if response.get("retrievalResults"):
retrieval_results = response.get("retrievalResults")
for retrieval_result in retrieval_results:
# filter out results with score less than threshold
if retrieval_result.get("score") < retrieval_setting.get("score_threshold", 0.0):
continue
result = {
"metadata": retrieval_result.get("metadata"),
"score": retrieval_result.get("score"),
"title": retrieval_result.get("metadata").get("x-amz-bedrock-kb-source-uri"),
"content": retrieval_result.get("content").get("text"),
}
results.append(result)
return {"records": results}