mirror of
https://github.com/langgenius/dify.git
synced 2024-11-16 03:32:23 +08:00
feat: add voyage ai as a new model provider (#8747)
Some checks are pending
Build and Push API & Web / build (api, DIFY_API_IMAGE_NAME, linux/amd64, build-api-amd64) (push) Waiting to run
Build and Push API & Web / build (api, DIFY_API_IMAGE_NAME, linux/arm64, build-api-arm64) (push) Waiting to run
Build and Push API & Web / build (web, DIFY_WEB_IMAGE_NAME, linux/amd64, build-web-amd64) (push) Waiting to run
Build and Push API & Web / build (web, DIFY_WEB_IMAGE_NAME, linux/arm64, build-web-arm64) (push) Waiting to run
Build and Push API & Web / create-manifest (api, DIFY_API_IMAGE_NAME, merge-api-images) (push) Blocked by required conditions
Build and Push API & Web / create-manifest (web, DIFY_WEB_IMAGE_NAME, merge-web-images) (push) Blocked by required conditions
Some checks are pending
Build and Push API & Web / build (api, DIFY_API_IMAGE_NAME, linux/amd64, build-api-amd64) (push) Waiting to run
Build and Push API & Web / build (api, DIFY_API_IMAGE_NAME, linux/arm64, build-api-arm64) (push) Waiting to run
Build and Push API & Web / build (web, DIFY_WEB_IMAGE_NAME, linux/amd64, build-web-amd64) (push) Waiting to run
Build and Push API & Web / build (web, DIFY_WEB_IMAGE_NAME, linux/arm64, build-web-arm64) (push) Waiting to run
Build and Push API & Web / create-manifest (api, DIFY_API_IMAGE_NAME, merge-api-images) (push) Blocked by required conditions
Build and Push API & Web / create-manifest (web, DIFY_WEB_IMAGE_NAME, merge-web-images) (push) Blocked by required conditions
This commit is contained in:
parent
42dfde6546
commit
fb49413a41
|
@ -40,3 +40,4 @@
|
||||||
- fireworks
|
- fireworks
|
||||||
- mixedbread
|
- mixedbread
|
||||||
- nomic
|
- nomic
|
||||||
|
- voyage
|
||||||
|
|
|
@ -0,0 +1,21 @@
|
||||||
|
<svg version="1.0" xmlns="http://www.w3.org/2000/svg" width="100.000000pt" height="19.000000pt" viewBox="0 0 300.000000 57.000000" preserveAspectRatio="xMidYMid meet"><g transform="translate(0.000000,57.000000) scale(0.100000,-0.100000)" fill="#000000" stroke="none"><path d="M2505 368 c-38 -84 -86 -188 -106 -230 l-38 -78 27 0 c24 0 30 7 55
|
||||||
|
75 l28 75 100 0 100 0 25 -55 c13 -31 24 -64 24 -75 0 -17 7 -20 44 -20 l43 0
|
||||||
|
-37 73 c-20 39 -68 143 -106 229 -38 87 -74 158 -80 158 -5 0 -41 -69 -79
|
||||||
|
-152z m110 -30 c22 -51 41 -95 42 -98 2 -3 -36 -6 -83 -7 -76 -1 -85 0 -81 15
|
||||||
|
12 40 72 182 77 182 3 0 24 -41 45 -92z"/><path d="M63 493 c19 -61 197 -438 209 -440 10 -2 147 282 216 449 2 4 -10 8
|
||||||
|
-27 8 -23 0 -31 -5 -31 -17 0 -16 -142 -365 -146 -360 -8 11 -144 329 -149
|
||||||
|
350 -6 23 -12 27 -42 27 -29 0 -34 -3 -30 -17z"/><path d="M2855 285 l0 -225 30 0 30 0 0 225 0 225 -30 0 -30 0 0 -225z"/><path d="M588 380 c-55 -30 -82 -74 -86 -145 -3 -50 0 -66 20 -95 39 -58 82
|
||||||
|
-80 153 -80 68 0 110 21 149 73 32 43 30 150 -3 196 -47 66 -158 90 -233 51z
|
||||||
|
m133 -16 c59 -30 89 -156 54 -224 -45 -87 -162 -78 -201 16 -18 44 -18 128 1
|
||||||
|
164 28 55 90 73 146 44z"/><path d="M935 303 l76 -98 -7 -72 -6 -73 33 0 34 0 -3 78 -4 77 71 93 c65 85
|
||||||
|
68 92 46 92 -15 0 -29 -9 -36 -22 -18 -33 -90 -128 -98 -128 -6 1 -67 85 -88
|
||||||
|
122 -8 15 -24 23 -53 25 l-41 4 76 -98z"/><path d="M1257 230 c-82 -169 -83 -170 -57 -170 17 0 27 6 27 15 0 8 7 31 17
|
||||||
|
52 l17 38 79 0 78 1 16 -34 c9 -18 16 -42 16 -52 0 -17 7 -20 41 -20 22 0 39
|
||||||
|
3 37 8 -2 4 -39 80 -83 170 -43 89 -84 162 -92 162 -7 0 -50 -76 -96 -170z
|
||||||
|
m90 -38 c-33 -2 -61 -1 -63 1 -2 2 10 34 26 71 l31 68 33 -68 33 -69 -60 -3z"/><path d="M1665 386 c-37 -16 -84 -63 -97 -96 -13 -35 -12 -104 2 -132 49 -94
|
||||||
|
182 -134 280 -83 24 12 29 22 32 64 3 49 3 49 -30 53 l-33 4 3 -45 c4 -61 -5
|
||||||
|
-71 -60 -71 -93 0 -142 57 -142 164 0 44 5 60 25 85 47 55 136 65 184 20 30
|
||||||
|
-28 35 -20 11 19 -19 31 -22 32 -82 32 -35 -1 -76 -7 -93 -14z"/><path d="M1955 230 l0 -170 91 0 c76 0 93 3 98 16 4 9 5 18 4 20 -2 1 -31 -1
|
||||||
|
-66 -5 -34 -4 -64 -5 -67 -3 -3 3 -5 36 -5 73 l0 68 55 -6 c49 -5 55 -4 55 13
|
||||||
|
0 17 -6 19 -55 16 l-55 -4 0 61 0 61 64 0 c48 0 65 4 70 15 4 13 -10 15 -92
|
||||||
|
15 l-97 0 0 -170z"/></g></svg>
|
After Width: | Height: | Size: 2.2 KiB |
|
@ -0,0 +1,8 @@
|
||||||
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
|
<svg width="64px" height="64px" viewBox="0 0 64 64" version="1.1" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink">
|
||||||
|
<title>voyage</title>
|
||||||
|
<g id="voyage" stroke="none" stroke-width="1" fill="none" fill-rule="evenodd">
|
||||||
|
<rect id="矩形" fill="#333333" x="0" y="0" width="64" height="64" rx="12"></rect>
|
||||||
|
<path d="M12.1128004,51.4376727 C13.8950799,45.8316747 30.5922254,11.1847688 31.7178757,11.0009656 C32.6559176,10.8171624 45.5070913,36.9172188 51.9795803,52.2647871 C52.1671887,52.6323936 51.0415384,53 49.4468672,53 C47.2893709,53 46.5389374,52.540492 46.5389374,51.4376727 C46.5389374,49.967247 33.2187427,17.8935861 32.8435259,18.3530942 C32.0930924,19.3640118 19.3357228,48.5887229 18.8667019,50.5186566 C18.3038768,52.6323936 17.7410516,53 14.926926,53 C12.2066045,53 11.7375836,52.7242952 12.1128004,51.4376727 Z" id="路径" fill="#FFFFFF" transform="translate(32, 32) scale(1, -1) translate(-32, -32)"></path>
|
||||||
|
</g>
|
||||||
|
</svg>
|
After Width: | Height: | Size: 1.0 KiB |
|
@ -0,0 +1,4 @@
|
||||||
|
model: rerank-1
|
||||||
|
model_type: rerank
|
||||||
|
model_properties:
|
||||||
|
context_size: 8000
|
|
@ -0,0 +1,4 @@
|
||||||
|
model: rerank-lite-1
|
||||||
|
model_type: rerank
|
||||||
|
model_properties:
|
||||||
|
context_size: 4000
|
123
api/core/model_runtime/model_providers/voyage/rerank/rerank.py
Normal file
123
api/core/model_runtime/model_providers/voyage/rerank/rerank.py
Normal file
|
@ -0,0 +1,123 @@
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
from core.model_runtime.entities.common_entities import I18nObject
|
||||||
|
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType
|
||||||
|
from core.model_runtime.entities.rerank_entities import RerankDocument, RerankResult
|
||||||
|
from core.model_runtime.errors.invoke import (
|
||||||
|
InvokeAuthorizationError,
|
||||||
|
InvokeBadRequestError,
|
||||||
|
InvokeConnectionError,
|
||||||
|
InvokeError,
|
||||||
|
InvokeRateLimitError,
|
||||||
|
InvokeServerUnavailableError,
|
||||||
|
)
|
||||||
|
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||||
|
from core.model_runtime.model_providers.__base.rerank_model import RerankModel
|
||||||
|
|
||||||
|
|
||||||
|
class VoyageRerankModel(RerankModel):
|
||||||
|
"""
|
||||||
|
Model class for Voyage rerank model.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def _invoke(
|
||||||
|
self,
|
||||||
|
model: str,
|
||||||
|
credentials: dict,
|
||||||
|
query: str,
|
||||||
|
docs: list[str],
|
||||||
|
score_threshold: Optional[float] = None,
|
||||||
|
top_n: Optional[int] = None,
|
||||||
|
user: Optional[str] = None,
|
||||||
|
) -> RerankResult:
|
||||||
|
"""
|
||||||
|
Invoke rerank model
|
||||||
|
:param model: model name
|
||||||
|
:param credentials: model credentials
|
||||||
|
:param query: search query
|
||||||
|
:param docs: docs for reranking
|
||||||
|
:param score_threshold: score threshold
|
||||||
|
:param top_n: top n documents to return
|
||||||
|
:param user: unique user id
|
||||||
|
:return: rerank result
|
||||||
|
"""
|
||||||
|
if len(docs) == 0:
|
||||||
|
return RerankResult(model=model, docs=[])
|
||||||
|
|
||||||
|
base_url = credentials.get("base_url", "https://api.voyageai.com/v1")
|
||||||
|
base_url = base_url.removesuffix("/")
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = httpx.post(
|
||||||
|
base_url + "/rerank",
|
||||||
|
json={"model": model, "query": query, "documents": docs, "top_k": top_n, "return_documents": True},
|
||||||
|
headers={"Authorization": f"Bearer {credentials.get('api_key')}", "Content-Type": "application/json"},
|
||||||
|
)
|
||||||
|
response.raise_for_status()
|
||||||
|
results = response.json()
|
||||||
|
|
||||||
|
rerank_documents = []
|
||||||
|
for result in results["data"]:
|
||||||
|
rerank_document = RerankDocument(
|
||||||
|
index=result["index"],
|
||||||
|
text=result["document"],
|
||||||
|
score=result["relevance_score"],
|
||||||
|
)
|
||||||
|
if score_threshold is None or result["relevance_score"] >= score_threshold:
|
||||||
|
rerank_documents.append(rerank_document)
|
||||||
|
|
||||||
|
return RerankResult(model=model, docs=rerank_documents)
|
||||||
|
except httpx.HTTPStatusError as e:
|
||||||
|
raise InvokeServerUnavailableError(str(e))
|
||||||
|
|
||||||
|
def validate_credentials(self, model: str, credentials: dict) -> None:
|
||||||
|
"""
|
||||||
|
Validate model credentials
|
||||||
|
:param model: model name
|
||||||
|
:param credentials: model credentials
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
self._invoke(
|
||||||
|
model=model,
|
||||||
|
credentials=credentials,
|
||||||
|
query="What is the capital of the United States?",
|
||||||
|
docs=[
|
||||||
|
"Carson City is the capital city of the American state of Nevada. At the 2010 United States "
|
||||||
|
"Census, Carson City had a population of 55,274.",
|
||||||
|
"The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that "
|
||||||
|
"are a political division controlled by the United States. Its capital is Saipan.",
|
||||||
|
],
|
||||||
|
score_threshold=0.8,
|
||||||
|
)
|
||||||
|
except Exception as ex:
|
||||||
|
raise CredentialsValidateFailedError(str(ex))
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
|
||||||
|
"""
|
||||||
|
Map model invoke error to unified error
|
||||||
|
"""
|
||||||
|
return {
|
||||||
|
InvokeConnectionError: [httpx.ConnectError],
|
||||||
|
InvokeServerUnavailableError: [httpx.RemoteProtocolError],
|
||||||
|
InvokeRateLimitError: [],
|
||||||
|
InvokeAuthorizationError: [httpx.HTTPStatusError],
|
||||||
|
InvokeBadRequestError: [httpx.RequestError],
|
||||||
|
}
|
||||||
|
|
||||||
|
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity:
|
||||||
|
"""
|
||||||
|
generate custom model entities from credentials
|
||||||
|
"""
|
||||||
|
entity = AIModelEntity(
|
||||||
|
model=model,
|
||||||
|
label=I18nObject(en_US=model),
|
||||||
|
model_type=ModelType.RERANK,
|
||||||
|
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
|
||||||
|
model_properties={ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size", "8000"))},
|
||||||
|
)
|
||||||
|
|
||||||
|
return entity
|
|
@ -0,0 +1,172 @@
|
||||||
|
import time
|
||||||
|
from json import JSONDecodeError, dumps
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import requests
|
||||||
|
|
||||||
|
from core.embedding.embedding_constant import EmbeddingInputType
|
||||||
|
from core.model_runtime.entities.common_entities import I18nObject
|
||||||
|
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType, PriceType
|
||||||
|
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
|
||||||
|
from core.model_runtime.errors.invoke import (
|
||||||
|
InvokeAuthorizationError,
|
||||||
|
InvokeBadRequestError,
|
||||||
|
InvokeConnectionError,
|
||||||
|
InvokeError,
|
||||||
|
InvokeRateLimitError,
|
||||||
|
InvokeServerUnavailableError,
|
||||||
|
)
|
||||||
|
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||||
|
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
|
||||||
|
|
||||||
|
|
||||||
|
class VoyageTextEmbeddingModel(TextEmbeddingModel):
|
||||||
|
"""
|
||||||
|
Model class for Voyage text embedding model.
|
||||||
|
"""
|
||||||
|
|
||||||
|
api_base: str = "https://api.voyageai.com/v1"
|
||||||
|
|
||||||
|
def _invoke(
|
||||||
|
self,
|
||||||
|
model: str,
|
||||||
|
credentials: dict,
|
||||||
|
texts: list[str],
|
||||||
|
user: Optional[str] = None,
|
||||||
|
input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT,
|
||||||
|
) -> TextEmbeddingResult:
|
||||||
|
"""
|
||||||
|
Invoke text embedding model
|
||||||
|
|
||||||
|
:param model: model name
|
||||||
|
:param credentials: model credentials
|
||||||
|
:param texts: texts to embed
|
||||||
|
:param user: unique user id
|
||||||
|
:param input_type: input type
|
||||||
|
:return: embeddings result
|
||||||
|
"""
|
||||||
|
api_key = credentials["api_key"]
|
||||||
|
if not api_key:
|
||||||
|
raise CredentialsValidateFailedError("api_key is required")
|
||||||
|
|
||||||
|
base_url = credentials.get("base_url", self.api_base)
|
||||||
|
base_url = base_url.removesuffix("/")
|
||||||
|
|
||||||
|
url = base_url + "/embeddings"
|
||||||
|
headers = {"Authorization": "Bearer " + api_key, "Content-Type": "application/json"}
|
||||||
|
voyage_input_type = "null"
|
||||||
|
if input_type is not None:
|
||||||
|
voyage_input_type = input_type.value
|
||||||
|
data = {"model": model, "input": texts, "input_type": voyage_input_type}
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = requests.post(url, headers=headers, data=dumps(data))
|
||||||
|
except Exception as e:
|
||||||
|
raise InvokeConnectionError(str(e))
|
||||||
|
|
||||||
|
if response.status_code != 200:
|
||||||
|
try:
|
||||||
|
resp = response.json()
|
||||||
|
msg = resp["detail"]
|
||||||
|
if response.status_code == 401:
|
||||||
|
raise InvokeAuthorizationError(msg)
|
||||||
|
elif response.status_code == 429:
|
||||||
|
raise InvokeRateLimitError(msg)
|
||||||
|
elif response.status_code == 500:
|
||||||
|
raise InvokeServerUnavailableError(msg)
|
||||||
|
else:
|
||||||
|
raise InvokeBadRequestError(msg)
|
||||||
|
except JSONDecodeError as e:
|
||||||
|
raise InvokeServerUnavailableError(
|
||||||
|
f"Failed to convert response to json: {e} with text: {response.text}"
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
resp = response.json()
|
||||||
|
embeddings = resp["data"]
|
||||||
|
usage = resp["usage"]
|
||||||
|
except Exception as e:
|
||||||
|
raise InvokeServerUnavailableError(f"Failed to convert response to json: {e} with text: {response.text}")
|
||||||
|
|
||||||
|
usage = self._calc_response_usage(model=model, credentials=credentials, tokens=usage["total_tokens"])
|
||||||
|
|
||||||
|
result = TextEmbeddingResult(
|
||||||
|
model=model, embeddings=[[float(data) for data in x["embedding"]] for x in embeddings], usage=usage
|
||||||
|
)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
|
||||||
|
"""
|
||||||
|
Get number of tokens for given prompt messages
|
||||||
|
|
||||||
|
:param model: model name
|
||||||
|
:param credentials: model credentials
|
||||||
|
:param texts: texts to embed
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
return sum(self._get_num_tokens_by_gpt2(text) for text in texts)
|
||||||
|
|
||||||
|
def validate_credentials(self, model: str, credentials: dict) -> None:
|
||||||
|
"""
|
||||||
|
Validate model credentials
|
||||||
|
|
||||||
|
:param model: model name
|
||||||
|
:param credentials: model credentials
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
self._invoke(model=model, credentials=credentials, texts=["ping"])
|
||||||
|
except Exception as e:
|
||||||
|
raise CredentialsValidateFailedError(f"Credentials validation failed: {e}")
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
|
||||||
|
return {
|
||||||
|
InvokeConnectionError: [InvokeConnectionError],
|
||||||
|
InvokeServerUnavailableError: [InvokeServerUnavailableError],
|
||||||
|
InvokeRateLimitError: [InvokeRateLimitError],
|
||||||
|
InvokeAuthorizationError: [InvokeAuthorizationError],
|
||||||
|
InvokeBadRequestError: [KeyError, InvokeBadRequestError],
|
||||||
|
}
|
||||||
|
|
||||||
|
def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage:
|
||||||
|
"""
|
||||||
|
Calculate response usage
|
||||||
|
|
||||||
|
:param model: model name
|
||||||
|
:param credentials: model credentials
|
||||||
|
:param tokens: input tokens
|
||||||
|
:return: usage
|
||||||
|
"""
|
||||||
|
# get input price info
|
||||||
|
input_price_info = self.get_price(
|
||||||
|
model=model, credentials=credentials, price_type=PriceType.INPUT, tokens=tokens
|
||||||
|
)
|
||||||
|
|
||||||
|
# transform usage
|
||||||
|
usage = EmbeddingUsage(
|
||||||
|
tokens=tokens,
|
||||||
|
total_tokens=tokens,
|
||||||
|
unit_price=input_price_info.unit_price,
|
||||||
|
price_unit=input_price_info.unit,
|
||||||
|
total_price=input_price_info.total_amount,
|
||||||
|
currency=input_price_info.currency,
|
||||||
|
latency=time.perf_counter() - self.started_at,
|
||||||
|
)
|
||||||
|
|
||||||
|
return usage
|
||||||
|
|
||||||
|
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity:
|
||||||
|
"""
|
||||||
|
generate custom model entities from credentials
|
||||||
|
"""
|
||||||
|
entity = AIModelEntity(
|
||||||
|
model=model,
|
||||||
|
label=I18nObject(en_US=model),
|
||||||
|
model_type=ModelType.TEXT_EMBEDDING,
|
||||||
|
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
|
||||||
|
model_properties={ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size"))},
|
||||||
|
)
|
||||||
|
|
||||||
|
return entity
|
|
@ -0,0 +1,8 @@
|
||||||
|
model: voyage-3-lite
|
||||||
|
model_type: text-embedding
|
||||||
|
model_properties:
|
||||||
|
context_size: 32000
|
||||||
|
pricing:
|
||||||
|
input: '0.00002'
|
||||||
|
unit: '0.001'
|
||||||
|
currency: USD
|
|
@ -0,0 +1,8 @@
|
||||||
|
model: voyage-3
|
||||||
|
model_type: text-embedding
|
||||||
|
model_properties:
|
||||||
|
context_size: 32000
|
||||||
|
pricing:
|
||||||
|
input: '0.00006'
|
||||||
|
unit: '0.001'
|
||||||
|
currency: USD
|
28
api/core/model_runtime/model_providers/voyage/voyage.py
Normal file
28
api/core/model_runtime/model_providers/voyage/voyage.py
Normal file
|
@ -0,0 +1,28 @@
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from core.model_runtime.entities.model_entities import ModelType
|
||||||
|
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||||
|
from core.model_runtime.model_providers.__base.model_provider import ModelProvider
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class VoyageProvider(ModelProvider):
|
||||||
|
def validate_provider_credentials(self, credentials: dict) -> None:
|
||||||
|
"""
|
||||||
|
Validate provider credentials
|
||||||
|
if validate failed, raise exception
|
||||||
|
|
||||||
|
:param credentials: provider credentials, credentials form defined in `provider_credential_schema`.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
model_instance = self.get_model_instance(ModelType.TEXT_EMBEDDING)
|
||||||
|
|
||||||
|
# Use `voyage-3` model for validate,
|
||||||
|
# no matter what model you pass in, text completion model or chat model
|
||||||
|
model_instance.validate_credentials(model="voyage-3", credentials=credentials)
|
||||||
|
except CredentialsValidateFailedError as ex:
|
||||||
|
raise ex
|
||||||
|
except Exception as ex:
|
||||||
|
logger.exception(f"{self.get_provider_schema().provider} credentials validate failed")
|
||||||
|
raise ex
|
31
api/core/model_runtime/model_providers/voyage/voyage.yaml
Normal file
31
api/core/model_runtime/model_providers/voyage/voyage.yaml
Normal file
|
@ -0,0 +1,31 @@
|
||||||
|
provider: voyage
|
||||||
|
label:
|
||||||
|
en_US: Voyage
|
||||||
|
description:
|
||||||
|
en_US: Embedding and Rerank Model Supported
|
||||||
|
icon_small:
|
||||||
|
en_US: icon_s_en.svg
|
||||||
|
icon_large:
|
||||||
|
en_US: icon_l_en.svg
|
||||||
|
background: "#EFFDFD"
|
||||||
|
help:
|
||||||
|
title:
|
||||||
|
en_US: Get your API key from Voyage AI
|
||||||
|
zh_Hans: 从 Voyage 获取 API Key
|
||||||
|
url:
|
||||||
|
en_US: https://dash.voyageai.com/
|
||||||
|
supported_model_types:
|
||||||
|
- text-embedding
|
||||||
|
- rerank
|
||||||
|
configurate_methods:
|
||||||
|
- predefined-model
|
||||||
|
provider_credential_schema:
|
||||||
|
credential_form_schemas:
|
||||||
|
- variable: api_key
|
||||||
|
label:
|
||||||
|
en_US: API Key
|
||||||
|
type: secret-input
|
||||||
|
required: true
|
||||||
|
placeholder:
|
||||||
|
zh_Hans: 在此输入您的 API Key
|
||||||
|
en_US: Enter your API Key
|
|
@ -123,6 +123,7 @@ FIRECRAWL_API_KEY = "fc-"
|
||||||
TEI_EMBEDDING_SERVER_URL = "http://a.abc.com:11451"
|
TEI_EMBEDDING_SERVER_URL = "http://a.abc.com:11451"
|
||||||
TEI_RERANK_SERVER_URL = "http://a.abc.com:11451"
|
TEI_RERANK_SERVER_URL = "http://a.abc.com:11451"
|
||||||
MIXEDBREAD_API_KEY = "mk-aaaaaaaaaaaaaaaaaaaa"
|
MIXEDBREAD_API_KEY = "mk-aaaaaaaaaaaaaaaaaaaa"
|
||||||
|
VOYAGE_API_KEY = "va-aaaaaaaaaaaaaaaaaaaa"
|
||||||
|
|
||||||
[tool.poetry]
|
[tool.poetry]
|
||||||
name = "dify-api"
|
name = "dify-api"
|
||||||
|
|
|
@ -0,0 +1,25 @@
|
||||||
|
import os
|
||||||
|
from unittest.mock import Mock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||||
|
from core.model_runtime.model_providers.voyage.voyage import VoyageProvider
|
||||||
|
|
||||||
|
|
||||||
|
def test_validate_provider_credentials():
|
||||||
|
provider = VoyageProvider()
|
||||||
|
|
||||||
|
with pytest.raises(CredentialsValidateFailedError):
|
||||||
|
provider.validate_provider_credentials(credentials={"api_key": "hahahaha"})
|
||||||
|
with patch("requests.post") as mock_post:
|
||||||
|
mock_response = Mock()
|
||||||
|
mock_response.json.return_value = {
|
||||||
|
"object": "list",
|
||||||
|
"data": [{"object": "embedding", "embedding": [0.23333 for _ in range(1024)], "index": 0}],
|
||||||
|
"model": "voyage-3",
|
||||||
|
"usage": {"total_tokens": 1},
|
||||||
|
}
|
||||||
|
mock_response.status_code = 200
|
||||||
|
mock_post.return_value = mock_response
|
||||||
|
provider.validate_provider_credentials(credentials={"api_key": os.environ.get("VOYAGE_API_KEY")})
|
|
@ -0,0 +1,92 @@
|
||||||
|
import os
|
||||||
|
from unittest.mock import Mock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from core.model_runtime.entities.rerank_entities import RerankResult
|
||||||
|
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||||
|
from core.model_runtime.model_providers.voyage.rerank.rerank import VoyageRerankModel
|
||||||
|
|
||||||
|
|
||||||
|
def test_validate_credentials():
|
||||||
|
model = VoyageRerankModel()
|
||||||
|
|
||||||
|
with pytest.raises(CredentialsValidateFailedError):
|
||||||
|
model.validate_credentials(
|
||||||
|
model="rerank-lite-1",
|
||||||
|
credentials={"api_key": "invalid_key"},
|
||||||
|
)
|
||||||
|
with patch("httpx.post") as mock_post:
|
||||||
|
mock_response = Mock()
|
||||||
|
mock_response.json.return_value = {
|
||||||
|
"object": "list",
|
||||||
|
"data": [
|
||||||
|
{
|
||||||
|
"relevance_score": 0.546875,
|
||||||
|
"index": 0,
|
||||||
|
"document": "Carson City is the capital city of the American state of Nevada. At the 2010 United "
|
||||||
|
"States Census, Carson City had a population of 55,274.",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"relevance_score": 0.4765625,
|
||||||
|
"index": 1,
|
||||||
|
"document": "The Commonwealth of the Northern Mariana Islands is a group of islands in the "
|
||||||
|
"Pacific Ocean that are a political division controlled by the United States. Its "
|
||||||
|
"capital is Saipan.",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
"model": "rerank-lite-1",
|
||||||
|
"usage": {"total_tokens": 96},
|
||||||
|
}
|
||||||
|
mock_response.status_code = 200
|
||||||
|
mock_post.return_value = mock_response
|
||||||
|
model.validate_credentials(
|
||||||
|
model="rerank-lite-1",
|
||||||
|
credentials={
|
||||||
|
"api_key": os.environ.get("VOYAGE_API_KEY"),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_invoke_model():
|
||||||
|
model = VoyageRerankModel()
|
||||||
|
with patch("httpx.post") as mock_post:
|
||||||
|
mock_response = Mock()
|
||||||
|
mock_response.json.return_value = {
|
||||||
|
"object": "list",
|
||||||
|
"data": [
|
||||||
|
{
|
||||||
|
"relevance_score": 0.84375,
|
||||||
|
"index": 0,
|
||||||
|
"document": "Kasumi is a girl name of Japanese origin meaning mist.",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"relevance_score": 0.4765625,
|
||||||
|
"index": 1,
|
||||||
|
"document": "Her music is a kawaii bass, a mix of future bass, pop, and kawaii music and she "
|
||||||
|
"leads a team named PopiParty.",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
"model": "rerank-lite-1",
|
||||||
|
"usage": {"total_tokens": 59},
|
||||||
|
}
|
||||||
|
mock_response.status_code = 200
|
||||||
|
mock_post.return_value = mock_response
|
||||||
|
result = model.invoke(
|
||||||
|
model="rerank-lite-1",
|
||||||
|
credentials={
|
||||||
|
"api_key": os.environ.get("VOYAGE_API_KEY"),
|
||||||
|
},
|
||||||
|
query="Who is Kasumi?",
|
||||||
|
docs=[
|
||||||
|
"Kasumi is a girl name of Japanese origin meaning mist.",
|
||||||
|
"Her music is a kawaii bass, a mix of future bass, pop, and kawaii music and she leads a team named "
|
||||||
|
"PopiParty.",
|
||||||
|
],
|
||||||
|
score_threshold=0.5,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert isinstance(result, RerankResult)
|
||||||
|
assert len(result.docs) == 1
|
||||||
|
assert result.docs[0].index == 0
|
||||||
|
assert result.docs[0].score >= 0.5
|
|
@ -0,0 +1,70 @@
|
||||||
|
import os
|
||||||
|
from unittest.mock import Mock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
|
||||||
|
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||||
|
from core.model_runtime.model_providers.voyage.text_embedding.text_embedding import VoyageTextEmbeddingModel
|
||||||
|
|
||||||
|
|
||||||
|
def test_validate_credentials():
|
||||||
|
model = VoyageTextEmbeddingModel()
|
||||||
|
|
||||||
|
with pytest.raises(CredentialsValidateFailedError):
|
||||||
|
model.validate_credentials(model="voyage-3", credentials={"api_key": "invalid_key"})
|
||||||
|
with patch("requests.post") as mock_post:
|
||||||
|
mock_response = Mock()
|
||||||
|
mock_response.json.return_value = {
|
||||||
|
"object": "list",
|
||||||
|
"data": [{"object": "embedding", "embedding": [0.23333 for _ in range(1024)], "index": 0}],
|
||||||
|
"model": "voyage-3",
|
||||||
|
"usage": {"total_tokens": 1},
|
||||||
|
}
|
||||||
|
mock_response.status_code = 200
|
||||||
|
mock_post.return_value = mock_response
|
||||||
|
model.validate_credentials(model="voyage-3", credentials={"api_key": os.environ.get("VOYAGE_API_KEY")})
|
||||||
|
|
||||||
|
|
||||||
|
def test_invoke_model():
|
||||||
|
model = VoyageTextEmbeddingModel()
|
||||||
|
|
||||||
|
with patch("requests.post") as mock_post:
|
||||||
|
mock_response = Mock()
|
||||||
|
mock_response.json.return_value = {
|
||||||
|
"object": "list",
|
||||||
|
"data": [
|
||||||
|
{"object": "embedding", "embedding": [0.23333 for _ in range(1024)], "index": 0},
|
||||||
|
{"object": "embedding", "embedding": [0.23333 for _ in range(1024)], "index": 1},
|
||||||
|
],
|
||||||
|
"model": "voyage-3",
|
||||||
|
"usage": {"total_tokens": 2},
|
||||||
|
}
|
||||||
|
mock_response.status_code = 200
|
||||||
|
mock_post.return_value = mock_response
|
||||||
|
result = model.invoke(
|
||||||
|
model="voyage-3",
|
||||||
|
credentials={
|
||||||
|
"api_key": os.environ.get("VOYAGE_API_KEY"),
|
||||||
|
},
|
||||||
|
texts=["hello", "world"],
|
||||||
|
user="abc-123",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert isinstance(result, TextEmbeddingResult)
|
||||||
|
assert len(result.embeddings) == 2
|
||||||
|
assert result.usage.total_tokens == 2
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_num_tokens():
|
||||||
|
model = VoyageTextEmbeddingModel()
|
||||||
|
|
||||||
|
num_tokens = model.get_num_tokens(
|
||||||
|
model="voyage-3",
|
||||||
|
credentials={
|
||||||
|
"api_key": os.environ.get("VOYAGE_API_KEY"),
|
||||||
|
},
|
||||||
|
texts=["ping"],
|
||||||
|
)
|
||||||
|
|
||||||
|
assert num_tokens == 1
|
|
@ -9,4 +9,5 @@ pytest api/tests/integration_tests/model_runtime/anthropic \
|
||||||
api/tests/integration_tests/model_runtime/upstage \
|
api/tests/integration_tests/model_runtime/upstage \
|
||||||
api/tests/integration_tests/model_runtime/fireworks \
|
api/tests/integration_tests/model_runtime/fireworks \
|
||||||
api/tests/integration_tests/model_runtime/nomic \
|
api/tests/integration_tests/model_runtime/nomic \
|
||||||
api/tests/integration_tests/model_runtime/mixedbread
|
api/tests/integration_tests/model_runtime/mixedbread \
|
||||||
|
api/tests/integration_tests/model_runtime/voyage
|
Loading…
Reference in New Issue
Block a user