feat: add wenxin rerank (#9431)

Co-authored-by: cuihz <cuihz@knowbox.cn>
Co-authored-by: crazywoola <427733928@qq.com>
chzphoenix 2024-10-17 19:18:32 +08:00 committed by GitHub
parent b90ad587c2
commit 211f416806
6 changed files with 178 additions and 0 deletions

View File

@@ -120,6 +120,7 @@ class _CommonWenxin:
        "bge-large-en": "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/bge_large_en",
        "bge-large-zh": "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/bge_large_zh",
        "tao-8k": "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/tao_8k",
        "bce-reranker-base_v1": "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/reranker/bce_reranker_base",
    }
    function_calling_supports = [
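The new mapping points bce-reranker-base_v1 at Baidu's bce_reranker_base endpoint; the rerank class added below resolves it from api_bases and appends the OAuth access token as a query parameter. A minimal sketch of the raw exchange against this endpoint, assuming an access token has already been obtained through the provider's existing credential flow (rerank_raw is a hypothetical helper; field names mirror the parsing code in this commit):

from typing import Optional

import httpx

RERANK_URL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/reranker/bce_reranker_base"


def rerank_raw(access_token: str, query: str, docs: list[str], top_n: Optional[int] = None) -> list[dict]:
    # POST the query and candidate documents; the token is passed as a query parameter.
    response = httpx.post(
        f"{RERANK_URL}?access_token={access_token}",
        json={"model": "bce-reranker-base_v1", "query": query, "documents": docs, "top_n": top_n},
        headers={"Content-Type": "application/json"},
    )
    response.raise_for_status()
    # Each entry in "results" carries the scored document's index and relevance_score;
    # "document" (the original text) may or may not be echoed back by the service.
    return [
        {"index": r["index"], "score": r["relevance_score"], "text": r.get("document", docs[r["index"]])}
        for r in response.json()["results"]
    ]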

View File

@@ -0,0 +1,8 @@
model: bce-reranker-base_v1
model_type: rerank
model_properties:
  context_size: 4096
pricing:
  input: '0.0005'
  unit: '0.001'
  currency: RMB

View File

@@ -0,0 +1,147 @@
from typing import Optional

import httpx

from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType
from core.model_runtime.entities.rerank_entities import RerankDocument, RerankResult
from core.model_runtime.errors.invoke import (
    InvokeAuthorizationError,
    InvokeBadRequestError,
    InvokeConnectionError,
    InvokeError,
    InvokeRateLimitError,
    InvokeServerUnavailableError,
)
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.rerank_model import RerankModel
from core.model_runtime.model_providers.wenxin._common import _CommonWenxin


class WenxinRerank(_CommonWenxin):
    def rerank(self, model: str, query: str, docs: list[str], top_n: Optional[int] = None):
        access_token = self._get_access_token()
        url = f"{self.api_bases[model]}?access_token={access_token}"

        try:
            response = httpx.post(
                url,
                json={"model": model, "query": query, "documents": docs, "top_n": top_n},
                headers={"Content-Type": "application/json"},
            )
            response.raise_for_status()
            return response.json()
        except httpx.HTTPStatusError as e:
            raise InvokeServerUnavailableError(str(e))


class WenxinRerankModel(RerankModel):
    """
    Model class for wenxin rerank model.
    """

    def _invoke(
        self,
        model: str,
        credentials: dict,
        query: str,
        docs: list[str],
        score_threshold: Optional[float] = None,
        top_n: Optional[int] = None,
        user: Optional[str] = None,
    ) -> RerankResult:
        """
        Invoke rerank model

        :param model: model name
        :param credentials: model credentials
        :param query: search query
        :param docs: docs for reranking
        :param score_threshold: score threshold
        :param top_n: top n documents to return
        :param user: unique user id
        :return: rerank result
        """
        if len(docs) == 0:
            return RerankResult(model=model, docs=[])

        api_key = credentials["api_key"]
        secret_key = credentials["secret_key"]

        wenxin_rerank: WenxinRerank = WenxinRerank(api_key, secret_key)

        try:
            results = wenxin_rerank.rerank(model, query, docs, top_n)

            rerank_documents = []
            for result in results["results"]:
                index = result["index"]
                if "document" in result:
                    text = result["document"]
                else:
                    # the rerank API may not return the original document text; fall back to the input docs
                    text = docs[index]
                rerank_document = RerankDocument(
                    index=index,
                    text=text,
                    score=result["relevance_score"],
                )
                if score_threshold is None or result["relevance_score"] >= score_threshold:
                    rerank_documents.append(rerank_document)

            return RerankResult(model=model, docs=rerank_documents)
        except httpx.HTTPStatusError as e:
            raise InvokeServerUnavailableError(str(e))

    def validate_credentials(self, model: str, credentials: dict) -> None:
        """
        Validate model credentials

        :param model: model name
        :param credentials: model credentials
        :return:
        """
        try:
            self._invoke(
                model=model,
                credentials=credentials,
                query="What is the capital of the United States?",
                docs=[
                    "Carson City is the capital city of the American state of Nevada. At the 2010 United States "
                    "Census, Carson City had a population of 55,274.",
                    "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that "
                    "are a political division controlled by the United States. Its capital is Saipan.",
                ],
                score_threshold=0.8,
            )
        except Exception as ex:
            raise CredentialsValidateFailedError(str(ex))

    @property
    def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
        """
        Map model invoke error to unified error
        """
        return {
            InvokeConnectionError: [httpx.ConnectError],
            InvokeServerUnavailableError: [httpx.RemoteProtocolError],
            InvokeRateLimitError: [],
            InvokeAuthorizationError: [httpx.HTTPStatusError],
            InvokeBadRequestError: [httpx.RequestError],
        }

    def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity:
        """
        generate custom model entities from credentials
        """
        entity = AIModelEntity(
            model=model,
            label=I18nObject(en_US=model),
            model_type=ModelType.RERANK,
            fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
            model_properties={ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size"))},
        )

        return entity
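A brief, hypothetical illustration of the customizable-model path above; only context_size is read from the credentials when the schema is built:

from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
from core.model_runtime.model_providers.wenxin.rerank.rerank import WenxinRerankModel

# Hypothetical credential values for illustration only.
schema = WenxinRerankModel().get_customizable_model_schema(
    model="bce-reranker-base_v1",
    credentials={"context_size": "4096"},
)
assert schema.model_type == ModelType.RERANK
assert schema.model_properties[ModelPropertyKey.CONTEXT_SIZE] == 4096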

View File

@@ -18,6 +18,7 @@ help:
supported_model_types:
  - llm
  - text-embedding
  - rerank
configurate_methods:
  - predefined-model
provider_credential_schema:

View File

@@ -0,0 +1,21 @@
import os
from time import sleep

from core.model_runtime.entities.rerank_entities import RerankResult
from core.model_runtime.model_providers.wenxin.rerank.rerank import WenxinRerankModel


def test_invoke_bce_reranker_base_v1():
    sleep(3)
    model = WenxinRerankModel()

    response = model.invoke(
        model="bce-reranker-base_v1",
        credentials={"api_key": os.environ.get("WENXIN_API_KEY"), "secret_key": os.environ.get("WENXIN_SECRET_KEY")},
        query="What is Deep Learning?",
        docs=["Deep Learning is ...", "My Book is ..."],
        user="abc-123",
    )

    assert isinstance(response, RerankResult)
    assert len(response.docs) == 2
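A hypothetical companion test (not part of this commit) sketching how the score_threshold filtering in _invoke could be exercised as well, under the same environment-variable assumptions:

def test_invoke_with_score_threshold():
    # Assumes the same imports as the test module above.
    model = WenxinRerankModel()

    response = model.invoke(
        model="bce-reranker-base_v1",
        credentials={"api_key": os.environ.get("WENXIN_API_KEY"), "secret_key": os.environ.get("WENXIN_SECRET_KEY")},
        query="What is Deep Learning?",
        docs=["Deep Learning is ...", "My Book is ..."],
        score_threshold=0.5,
        user="abc-123",
    )

    assert isinstance(response, RerankResult)
    # Documents scoring below the threshold should have been dropped.
    assert all(doc.score >= 0.5 for doc in response.docs)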