From 2d690801d19257a0c3dd01a56c4ee105dabe2826 Mon Sep 17 00:00:00 2001 From: Jyong <76649700+JohnJyong@users.noreply.github.com> Date: Tue, 10 Sep 2024 13:17:48 +0800 Subject: [PATCH] nvidia rerank top n missed (#8185) --- .../model_runtime/model_providers/nvidia/rerank/rerank.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/api/core/model_runtime/model_providers/nvidia/rerank/rerank.py b/api/core/model_runtime/model_providers/nvidia/rerank/rerank.py index 9d33f55bc2..80c24b0555 100644 --- a/api/core/model_runtime/model_providers/nvidia/rerank/rerank.py +++ b/api/core/model_runtime/model_providers/nvidia/rerank/rerank.py @@ -54,7 +54,6 @@ class NvidiaRerankModel(RerankModel): "query": {"text": query}, "passages": [{"text": doc} for doc in docs], } - session = requests.Session() response = session.post(invoke_url, headers=headers, json=payload) response.raise_for_status() @@ -71,7 +70,10 @@ class NvidiaRerankModel(RerankModel): ) rerank_documents.append(rerank_document) - + if rerank_documents: + rerank_documents = sorted(rerank_documents, key=lambda x: x.score, reverse=True) + if top_n: + rerank_documents = rerank_documents[:top_n] return RerankResult(model=model, docs=rerank_documents) except requests.HTTPError as e: raise InvokeServerUnavailableError(str(e))