2024-02-09 15:21:33 +08:00
|
|
|
from typing import Optional
|
2024-01-02 23:42:00 +08:00
|
|
|
|
2024-02-06 13:21:13 +08:00
|
|
|
from core.model_manager import ModelInstance
|
2024-02-23 14:16:44 +08:00
|
|
|
from core.rag.models.document import Document
|
2024-10-17 19:12:42 +08:00
|
|
|
from core.rag.rerank.rerank_base import BaseRerankRunner
|
2024-02-06 13:21:13 +08:00
|
|
|
|
2024-01-02 23:42:00 +08:00
|
|
|
|
2024-10-17 19:12:42 +08:00
|
|
|
class RerankModelRunner(BaseRerankRunner):
    """Rerank retrieved documents with a dedicated rerank model.

    Input documents are de-duplicated, sent to the rerank model through the
    wrapped ``ModelInstance``, and returned as new ``Document`` objects whose
    metadata carries the model's relevance score under ``"score"``.
    """

    def __init__(self, rerank_model_instance: ModelInstance) -> None:
        """
        :param rerank_model_instance: model instance used to invoke the rerank model
        """
        self.rerank_model_instance = rerank_model_instance

    @staticmethod
    def _deduplicate(documents: list[Document]) -> list[Document]:
        """Return the documents to rerank, duplicates removed, first-seen order kept.

        - ``"dify"`` documents are keyed on ``metadata["doc_id"]``. Ones without
          metadata or without a ``doc_id`` cannot be identified and are skipped
          (the previous implementation raised on them).
        - ``"external"`` documents are compared by equality, since they carry no
          id this code relies on.
        - Documents with any other provider are dropped, as before.
        """
        seen_doc_ids: set = set()
        unique_documents: list[Document] = []
        for document in documents:
            if document.provider == "dify":
                metadata = document.metadata
                if not metadata or "doc_id" not in metadata:
                    continue
                doc_id = metadata["doc_id"]
                if doc_id in seen_doc_ids:
                    continue
                seen_doc_ids.add(doc_id)
                unique_documents.append(document)
            elif document.provider == "external":
                # Linear membership test: equality comparison, no hashing assumed.
                if document not in unique_documents:
                    unique_documents.append(document)
        return unique_documents

    def run(
        self,
        query: str,
        documents: list[Document],
        score_threshold: Optional[float] = None,
        top_n: Optional[int] = None,
        user: Optional[str] = None,
    ) -> list[Document]:
        """
        Run rerank model

        :param query: search query
        :param documents: documents for reranking
        :param score_threshold: minimum score a document must reach to be kept;
            ``None`` keeps every scored document
        :param top_n: maximum number of documents to return; ``None`` or a
            non-positive value means no limit
        :param user: unique user id if needed
        :return: new ``Document`` objects ordered by descending score; each one's
            metadata is a copy of its source document's metadata plus ``"score"``
        """
        unique_documents = self._deduplicate(documents)
        if not unique_documents:
            # Nothing to rank: skip the model call entirely.
            return []

        rerank_result = self.rerank_model_instance.invoke_rerank(
            query=query,
            docs=[document.page_content for document in unique_documents],
            score_threshold=score_threshold,
            top_n=top_n,
            user=user,
        )

        rerank_documents = []
        for result in rerank_result.docs:
            # Re-check the threshold locally in case the provider ignored it.
            if score_threshold is not None and result.score < score_threshold:
                continue
            source = unique_documents[result.index]
            rerank_documents.append(
                Document(
                    page_content=result.text,
                    # Build a fresh dict so adding "score" never mutates the
                    # caller's document, and tolerate metadata that is None.
                    metadata={**(source.metadata or {}), "score": result.score},
                    provider=source.provider,
                )
            )

        # Stable sort: equal scores keep the provider's order. Ordering and the
        # top_n cap are enforced here too, in case the provider ignored them.
        rerank_documents.sort(key=lambda doc: doc.metadata["score"], reverse=True)
        if top_n is not None and top_n > 0:
            rerank_documents = rerank_documents[:top_n]
        return rerank_documents
|