Add nomic embedding model provider (#8640)

This commit is contained in:
ice yao 2024-09-23 19:57:21 +08:00 committed by GitHub
parent 4f69adc8ab
commit d7aada38a1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
17 changed files with 506 additions and 2 deletions

View File

@ -0,0 +1,13 @@
<svg width="93" height="31" viewBox="0 0 93 31" fill="none" xmlns="http://www.w3.org/2000/svg">
<path d="M89.6037 29.888C88.9464 29.888 88.3667 29.7302 87.8647 29.4145C87.3626 29.0936 86.9719 28.6407 86.6924 28.0559L87.2979 27.4037C87.5464 27.9109 87.8698 28.3069 88.2684 28.5915C88.6669 28.871 89.1094 29.0108 89.5959 29.0108C89.922 29.0108 90.2196 28.9435 90.4887 28.8089C90.763 28.6744 90.9804 28.4829 91.1408 28.2344C91.3064 27.9808 91.3892 27.6806 91.3892 27.3339C91.3892 27.0182 91.3116 26.7697 91.1563 26.5886C91.0062 26.4074 90.7837 26.2522 90.4887 26.1228C90.1988 25.9882 89.8366 25.8381 89.4018 25.6725C89.0654 25.5379 88.7393 25.3853 88.4236 25.2145C88.1079 25.0437 87.8465 24.8289 87.6395 24.5701C87.4377 24.3061 87.3367 23.9723 87.3367 23.5686C87.3367 23.1598 87.4454 22.7975 87.6628 22.4817C87.8802 22.1609 88.1804 21.9098 88.5634 21.7287C88.9464 21.5424 89.3811 21.4492 89.8676 21.4492C90.3127 21.4492 90.7293 21.545 91.1175 21.7365C91.5109 21.928 91.8628 22.1997 92.1733 22.5516L91.6532 23.2115C91.177 22.5853 90.5844 22.2721 89.8754 22.2721C89.4406 22.2721 89.0861 22.386 88.8118 22.6137C88.5427 22.8415 88.4081 23.1391 88.4081 23.5065C88.4081 23.7705 88.4935 23.9904 88.6643 24.1664C88.8351 24.3424 89.0576 24.4925 89.3319 24.6167C89.6114 24.7409 89.9116 24.8651 90.2325 24.9893C90.6983 25.1653 91.102 25.3413 91.4436 25.5172C91.7903 25.6932 92.0595 25.9183 92.251 26.1927C92.4425 26.4618 92.5382 26.8293 92.5382 27.2951C92.5382 27.8281 92.414 28.2888 92.1656 28.6769C91.9171 29.0651 91.5704 29.3653 91.1253 29.5775C90.6854 29.7845 90.1781 29.888 89.6037 29.888Z" fill="#3C593D"/>
<path d="M79.8324 29.8841C79.0871 29.8841 78.4143 29.7029 77.8139 29.3406C77.2187 28.9732 76.7451 28.4711 76.3932 27.8345C76.0464 27.1979 75.873 26.4708 75.873 25.653C75.873 24.8456 76.0438 24.1262 76.3854 23.4948C76.7322 22.8582 77.2032 22.3562 77.7984 21.9887C78.3987 21.6212 79.0767 21.4375 79.8324 21.4375C80.5518 21.4375 81.2039 21.6057 81.7888 21.9421C82.3736 22.2785 82.8187 22.7443 83.1241 23.3395V21.6859H84.2575V29.6356H83.1241V27.9587C82.7825 28.5591 82.3244 29.0301 81.7499 29.3717C81.1754 29.7133 80.5363 29.8841 79.8324 29.8841ZM80.1119 28.8981C80.7071 28.8981 81.2324 28.761 81.6878 28.4867C82.1485 28.2072 82.5107 27.8242 82.7747 27.3377C83.0387 26.846 83.1706 26.287 83.1706 25.6608C83.1706 25.0294 83.0387 24.4704 82.7747 23.9839C82.5159 23.4974 82.1562 23.117 81.6956 22.8427C81.235 22.5632 80.7071 22.4235 80.1119 22.4235C79.5167 22.4235 78.9888 22.5632 78.5281 22.8427C78.0675 23.117 77.7052 23.4974 77.4413 23.9839C77.1773 24.4704 77.0453 25.0294 77.0453 25.6608C77.0453 26.287 77.1773 26.846 77.4413 27.3377C77.7052 27.8242 78.0675 28.2072 78.5281 28.4867C78.9888 28.761 79.5167 28.8981 80.1119 28.8981Z" fill="#3C593D"/>
<path d="M71.9658 29.6382V16.2852H73.0993V29.6382H71.9658Z" fill="#3C593D"/>
<path d="M68.1539 29.8864C67.5587 29.8864 67.0955 29.6871 66.7643 29.2886C66.4382 28.8849 66.2752 28.3182 66.2752 27.5884V22.5422H65.4678V21.6882H66.2752V18.7148H67.4086V21.6882H69.3883V22.5422H67.4086V27.5263C67.4086 27.9662 67.494 28.3026 67.6648 28.5355C67.8356 28.7684 68.0789 28.8849 68.3946 28.8849C68.6999 28.8849 68.9691 28.7995 69.202 28.6287L69.4892 29.5292C69.3132 29.6379 69.1062 29.7233 68.8681 29.7854C68.6301 29.8527 68.392 29.8864 68.1539 29.8864Z" fill="#3C593D"/>
<path d="M58.513 29.8841C57.7678 29.8841 57.0949 29.7029 56.4946 29.3406C55.8994 28.9732 55.4258 28.4711 55.0739 27.8345C54.7271 27.1979 54.5537 26.4708 54.5537 25.653C54.5537 24.8456 54.7245 24.1262 55.0661 23.4948C55.4129 22.8582 55.8838 22.3562 56.479 21.9887C57.0794 21.6212 57.7574 21.4375 58.513 21.4375C59.2324 21.4375 59.8846 21.6057 60.4694 21.9421C61.0543 22.2785 61.4994 22.7443 61.8047 23.3395V21.6859H62.9382V29.6356H61.8047V27.9587C61.4631 28.5591 61.0051 29.0301 60.4306 29.3717C59.8561 29.7133 59.2169 29.8841 58.513 29.8841ZM58.7925 28.8981C59.3877 28.8981 59.913 28.761 60.3685 28.4867C60.8291 28.2072 61.1914 27.8242 61.4554 27.3377C61.7193 26.846 61.8513 26.287 61.8513 25.6608C61.8513 25.0294 61.7193 24.4704 61.4554 23.9839C61.1966 23.4974 60.8369 23.117 60.3763 22.8427C59.9156 22.5632 59.3877 22.4235 58.7925 22.4235C58.1973 22.4235 57.6694 22.5632 57.2088 22.8427C56.7482 23.117 56.3859 23.4974 56.1219 23.9839C55.858 24.4704 55.726 25.0294 55.726 25.6608C55.726 26.287 55.858 26.846 56.1219 27.3377C56.3859 27.8242 56.7482 28.2072 57.2088 28.4867C57.6694 28.761 58.1973 28.8981 58.7925 28.8981Z" fill="#3C593D"/>
<path d="M5.41228 22.6607V0H6.76535V30.2143H5.41228L1.35307 7.55357V30.2143H0V0H1.35307L5.41228 22.6607Z" fill="#3C593D"/>
<path d="M13.6575 28.9006C14.024 28.9006 14.3341 28.7775 14.5878 28.5312C14.8697 28.2848 15.0106 27.9701 15.0106 27.587V2.62733C15.0106 2.27154 14.8697 1.9705 14.5878 1.72418C14.3341 1.4505 14.024 1.31366 13.6575 1.31366C13.2629 1.31366 12.9387 1.4505 12.685 1.72418C12.4313 1.9705 12.3045 2.27154 12.3045 2.62733V27.587C12.3045 27.9701 12.4313 28.2848 12.685 28.5312C12.9387 28.7775 13.2629 28.9006 13.6575 28.9006ZM13.6575 30.2143C12.8964 30.2143 12.2481 29.968 11.7125 29.4753C11.2051 28.9554 10.9514 28.3259 10.9514 27.587V2.62733C10.9514 1.91576 11.2051 1.29998 11.7125 0.779988C12.2481 0.259996 12.8964 0 13.6575 0C14.3905 0 15.0247 0.259996 15.5603 0.779988C16.0959 1.29998 16.3637 1.91576 16.3637 2.62733V27.587C16.3637 28.3259 16.0959 28.9554 15.5603 29.4753C15.0247 29.968 14.3905 30.2143 13.6575 30.2143Z" fill="#3C593D"/>
<path d="M28.3299 0H29.683V30.2143H28.3299V5.25466L24.9472 18.3913L21.5645 5.25466V30.2143H20.2115V0H21.5645L24.9472 13.1366L28.3299 0Z" fill="#3C593D"/>
<path d="M33.6999 30.2143V0H35.0529V30.2143H33.6999Z" fill="#3C593D"/>
<path d="M41.776 30.2143C41.0149 30.2143 40.3666 29.968 39.831 29.4753C39.3236 28.9554 39.0699 28.3259 39.0699 27.587V2.62733C39.0699 1.91576 39.3236 1.29998 39.831 0.779988C40.3666 0.259996 41.0149 0 41.776 0C42.5089 0 43.1432 0.259996 43.6788 0.779988C44.2143 1.29998 44.4821 1.91576 44.4821 2.62733V5.25466H43.1291V2.62733C43.1291 2.27154 42.9881 1.9705 42.7062 1.72418C42.4525 1.4505 42.1425 1.31366 41.776 1.31366C41.3814 1.31366 41.0572 1.4505 40.8035 1.72418C40.5498 1.9705 40.4229 2.27154 40.4229 2.62733V27.587C40.4229 27.9701 40.5498 28.2848 40.8035 28.5312C41.0572 28.7775 41.3814 28.9006 41.776 28.9006C42.1425 28.9006 42.4525 28.7775 42.7062 28.5312C42.9881 28.2848 43.1291 27.9701 43.1291 27.587V24.9596H44.4821V27.587C44.4821 28.3259 44.2143 28.9554 43.6788 29.4753C43.1432 29.968 42.5089 30.2143 41.776 30.2143Z" fill="#3C593D"/>
<path d="M56 1H91" stroke="#3C593D" stroke-linecap="round" stroke-dasharray="0.1 2"/>
</svg>

After

Width:  |  Height:  |  Size: 6.4 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 25 KiB

View File

@ -0,0 +1,28 @@
from core.model_runtime.errors.invoke import (
InvokeAuthorizationError,
InvokeBadRequestError,
InvokeConnectionError,
InvokeError,
InvokeRateLimitError,
InvokeServerUnavailableError,
)
class _CommonNomic:
@property
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
"""
Map model invoke error to unified error
The key is the error type thrown to the caller
The value is the error type thrown by the model,
which needs to be converted into a unified error type for the caller.
:return: Invoke error mapping
"""
return {
InvokeConnectionError: [InvokeConnectionError],
InvokeServerUnavailableError: [InvokeServerUnavailableError],
InvokeRateLimitError: [InvokeRateLimitError],
InvokeAuthorizationError: [InvokeAuthorizationError],
InvokeBadRequestError: [KeyError, InvokeBadRequestError],
}

View File

@ -0,0 +1,26 @@
import logging
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.model_provider import ModelProvider
logger = logging.getLogger(__name__)
class NomicAtlasProvider(ModelProvider):
def validate_provider_credentials(self, credentials: dict) -> None:
"""
Validate provider credentials
if validate failed, raise exception
:param credentials: provider credentials, credentials form defined in `provider_credential_schema`.
"""
try:
model_instance = self.get_model_instance(ModelType.TEXT_EMBEDDING)
model_instance.validate_credentials(model="nomic-embed-text-v1.5", credentials=credentials)
except CredentialsValidateFailedError as ex:
raise ex
except Exception as ex:
logger.exception(f"{self.get_provider_schema().provider} credentials validate failed")
raise ex

View File

@ -0,0 +1,29 @@
provider: nomic
label:
zh_Hans: Nomic Atlas
en_US: Nomic Atlas
icon_small:
en_US: icon_s_en.png
icon_large:
en_US: icon_l_en.svg
background: "#EFF1FE"
help:
title:
en_US: Get your API key from Nomic Atlas
zh_Hans: 从Nomic Atlas获取 API Key
url:
en_US: https://atlas.nomic.ai/data
supported_model_types:
- text-embedding
configurate_methods:
- predefined-model
provider_credential_schema:
credential_form_schemas:
- variable: nomic_api_key
label:
en_US: API Key
type: secret-input
required: true
placeholder:
zh_Hans: 在此输入您的 API Key
en_US: Enter your API Key

View File

@ -0,0 +1,8 @@
model: nomic-embed-text-v1.5
model_type: text-embedding
model_properties:
context_size: 8192
pricing:
input: "0.1"
unit: "0.000001"
currency: USD

View File

@ -0,0 +1,8 @@
model: nomic-embed-text-v1
model_type: text-embedding
model_properties:
context_size: 8192
pricing:
input: "0.1"
unit: "0.000001"
currency: USD

View File

@ -0,0 +1,170 @@
import time
from functools import wraps
from typing import Optional
from nomic import embed
from nomic import login as nomic_login
from core.model_runtime.entities.model_entities import PriceType
from core.model_runtime.entities.text_embedding_entities import (
EmbeddingUsage,
TextEmbeddingResult,
)
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.text_embedding_model import (
TextEmbeddingModel,
)
from core.model_runtime.model_providers.nomic._common import _CommonNomic
def nomic_login_required(func):
@wraps(func)
def wrapper(*args, **kwargs):
try:
if not kwargs.get("credentials"):
raise ValueError("missing credentials parameters")
credentials = kwargs.get("credentials")
if "nomic_api_key" not in credentials:
raise ValueError("missing nomic_api_key in credentials parameters")
# nomic login
nomic_login(credentials["nomic_api_key"])
except Exception as ex:
raise CredentialsValidateFailedError(str(ex))
return func(*args, **kwargs)
return wrapper
class NomicTextEmbeddingModel(_CommonNomic, TextEmbeddingModel):
"""
Model class for nomic text embedding model.
"""
def _invoke(
self,
model: str,
credentials: dict,
texts: list[str],
user: Optional[str] = None,
) -> TextEmbeddingResult:
"""
Invoke text embedding model
:param model: model name
:param credentials: model credentials
:param texts: texts to embed
:param user: unique user id
:return: embeddings result
"""
embeddings, prompt_tokens, total_tokens = self.embed_text(
model=model,
credentials=credentials,
texts=texts,
)
# calc usage
usage = self._calc_response_usage(
model=model, credentials=credentials, tokens=prompt_tokens, total_tokens=total_tokens
)
return TextEmbeddingResult(embeddings=embeddings, usage=usage, model=model)
def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
"""
Get number of tokens for given prompt messages
:param model: model name
:param credentials: model credentials
:param texts: texts to embed
:return:
"""
if len(texts) == 0:
return 0
_, prompt_tokens, _ = self.embed_text(
model=model,
credentials=credentials,
texts=texts,
)
return prompt_tokens
def validate_credentials(self, model: str, credentials: dict) -> None:
"""
Validate model credentials
:param model: model name
:param credentials: model credentials
:return:
"""
try:
# call embedding model
self.embed_text(model=model, credentials=credentials, texts=["ping"])
except Exception as ex:
raise CredentialsValidateFailedError(str(ex))
@nomic_login_required
def embed_text(self, model: str, credentials: dict, texts: list[str]) -> tuple[list[list[float]], int, int]:
"""Call out to Nomic's embedding endpoint.
Args:
model: The model to use for embedding.
texts: The list of texts to embed.
Returns:
List of embeddings, one for each text, and tokens usage.
"""
embeddings: list[list[float]] = []
prompt_tokens = 0
total_tokens = 0
response = embed.text(
model=model,
texts=texts,
)
if not (response and "embeddings" in response):
raise ValueError("Embedding data is missing in the response.")
if not (response and "usage" in response):
raise ValueError("Response usage is missing.")
if "prompt_tokens" not in response["usage"]:
raise ValueError("Response usage does not contain prompt tokens.")
if "total_tokens" not in response["usage"]:
raise ValueError("Response usage does not contain total tokens.")
embeddings = [list(map(float, e)) for e in response["embeddings"]]
total_tokens = response["usage"]["total_tokens"]
prompt_tokens = response["usage"]["prompt_tokens"]
return embeddings, prompt_tokens, total_tokens
def _calc_response_usage(self, model: str, credentials: dict, tokens: int, total_tokens: int) -> EmbeddingUsage:
"""
Calculate response usage
:param model: model name
:param credentials: model credentials
:param tokens: prompt tokens
:param total_tokens: total tokens
:return: usage
"""
# get input price info
input_price_info = self.get_price(
model=model,
credentials=credentials,
price_type=PriceType.INPUT,
tokens=tokens,
)
# transform usage
usage = EmbeddingUsage(
tokens=tokens,
total_tokens=total_tokens,
unit_price=input_price_info.unit_price,
price_unit=input_price_info.unit,
total_price=input_price_info.total_amount,
currency=input_price_info.currency,
latency=time.perf_counter() - self.started_at,
)
return usage

78
api/poetry.lock generated
View File

@ -4135,6 +4135,20 @@ files = [
{file = "joblib-1.4.2.tar.gz", hash = "sha256:2382c5816b2636fbd20a09e0f4e9dad4736765fdfb7dca582943b9c1366b3f0e"},
]
[[package]]
name = "jsonlines"
version = "4.0.0"
description = "Library with helpers for the jsonlines file format"
optional = false
python-versions = ">=3.8"
files = [
{file = "jsonlines-4.0.0-py3-none-any.whl", hash = "sha256:185b334ff2ca5a91362993f42e83588a360cf95ce4b71a73548502bda52a7c55"},
{file = "jsonlines-4.0.0.tar.gz", hash = "sha256:0c6d2c09117550c089995247f605ae4cf77dd1533041d366351f6f298822ea74"},
]
[package.dependencies]
attrs = ">=19.2.0"
[[package]]
name = "jsonpath-ng"
version = "1.6.1"
@ -4469,6 +4483,24 @@ files = [
{file = "llvmlite-0.43.0.tar.gz", hash = "sha256:ae2b5b5c3ef67354824fb75517c8db5fbe93bc02cd9671f3c62271626bc041d5"},
]
[[package]]
name = "loguru"
version = "0.7.2"
description = "Python logging made (stupidly) simple"
optional = false
python-versions = ">=3.5"
files = [
{file = "loguru-0.7.2-py3-none-any.whl", hash = "sha256:003d71e3d3ed35f0f8984898359d65b79e5b21943f78af86aa5491210429b8eb"},
{file = "loguru-0.7.2.tar.gz", hash = "sha256:e671a53522515f34fd406340ee968cb9ecafbc4b36c679da03c18fd8d0bd51ac"},
]
[package.dependencies]
colorama = {version = ">=0.3.4", markers = "sys_platform == \"win32\""}
win32-setctime = {version = ">=1.0.0", markers = "sys_platform == \"win32\""}
[package.extras]
dev = ["Sphinx (==7.2.5)", "colorama (==0.4.5)", "colorama (==0.4.6)", "exceptiongroup (==1.1.3)", "freezegun (==1.1.0)", "freezegun (==1.2.2)", "mypy (==v0.910)", "mypy (==v0.971)", "mypy (==v1.4.1)", "mypy (==v1.5.1)", "pre-commit (==3.4.0)", "pytest (==6.1.2)", "pytest (==7.4.0)", "pytest-cov (==2.12.1)", "pytest-cov (==4.1.0)", "pytest-mypy-plugins (==1.9.3)", "pytest-mypy-plugins (==3.0.0)", "sphinx-autobuild (==2021.3.14)", "sphinx-rtd-theme (==1.3.0)", "tox (==3.27.1)", "tox (==4.11.0)"]
[[package]]
name = "lxml"
version = "5.3.0"
@ -5320,6 +5352,36 @@ plot = ["matplotlib"]
tgrep = ["pyparsing"]
twitter = ["twython"]
[[package]]
name = "nomic"
version = "3.1.2"
description = "The official Nomic python client."
optional = false
python-versions = "*"
files = [
{file = "nomic-3.1.2.tar.gz", hash = "sha256:2de1ab1dcf2429011c92987bb2f1eafe1a3a4901c3185b18f994bf89616f606d"},
]
[package.dependencies]
click = "*"
jsonlines = "*"
loguru = "*"
numpy = "*"
pandas = "*"
pillow = "*"
pyarrow = "*"
pydantic = "*"
pyjwt = "*"
requests = "*"
rich = "*"
tqdm = "*"
[package.extras]
all = ["nomic[aws,local]"]
aws = ["boto3", "sagemaker"]
dev = ["black (==24.3.0)", "cairosvg", "coverage", "isort", "mkautodoc", "mkdocs-jupyter", "mkdocs-material", "mkdocstrings[python]", "myst-parser", "nomic[all]", "pandas", "pillow", "pylint", "pyright", "pytest", "pytorch-lightning", "twine"]
local = ["gpt4all (>=2.5.0,<3)"]
[[package]]
name = "novita-client"
version = "0.5.7"
@ -9919,6 +9981,20 @@ files = [
beautifulsoup4 = "*"
requests = ">=2.0.0,<3.0.0"
[[package]]
name = "win32-setctime"
version = "1.1.0"
description = "A small Python utility to set file creation time on Windows"
optional = false
python-versions = ">=3.5"
files = [
{file = "win32_setctime-1.1.0-py3-none-any.whl", hash = "sha256:231db239e959c2fe7eb1d7dc129f11172354f98361c4fa2d6d2d7e278baa8aad"},
{file = "win32_setctime-1.1.0.tar.gz", hash = "sha256:15cf5750465118d6929ae4de4eb46e8edae9a5634350c01ba582df868e932cb2"},
]
[package.extras]
dev = ["black (>=19.3b0)", "pytest (>=4.6.2)"]
[[package]]
name = "wrapt"
version = "1.16.0"
@ -10422,4 +10498,4 @@ cffi = ["cffi (>=1.11)"]
[metadata]
lock-version = "2.0"
python-versions = ">=3.10,<3.13"
content-hash = "eb7ef7be5c7790e214f37f17f92b69407ad557cb80055ef7e49e36eb51b3fca6"
content-hash = "17c4108d92c415d987f8b437ea3e0484c5601a05bfe175339a8546c93c159bc5"

View File

@ -100,6 +100,7 @@ exclude = [
OPENAI_API_KEY = "sk-IamNotARealKeyJustForMockTestKawaiiiiiiiiii"
UPSTAGE_API_KEY = "up-aaaaaaaaaaaaaaaaaaaa"
FIREWORKS_API_KEY = "fw_aaaaaaaaaaaaaaaaaaaa"
NOMIC_API_KEY = "nk-aaaaaaaaaaaaaaaaaaaa"
AZURE_OPENAI_API_BASE = "https://difyai-openai.openai.azure.com"
AZURE_OPENAI_API_KEY = "xxxxb1707exxxxxxxxxxaaxxxxxf94"
ANTHROPIC_API_KEY = "sk-ant-api11-IamNotARealKeyJustForMockTestKawaiiiiiiiiii-NotBaka-ASkksz"
@ -217,6 +218,7 @@ azure-ai-inference = "^1.0.0b3"
volcengine-python-sdk = {extras = ["ark"], version = "^1.0.98"}
oci = "^2.133.0"
tos = "^2.7.1"
nomic = "^3.1.2"
[tool.poetry.group.indriect.dependencies]
kaleido = "0.2.1"
rank-bm25 = "~0.2.2"

View File

@ -0,0 +1,59 @@
import os
from collections.abc import Callable
from typing import Any, Literal, Union
import pytest
# import monkeypatch
from _pytest.monkeypatch import MonkeyPatch
from nomic import embed
def create_embedding(texts: list[str], model: str, **kwargs: Any) -> dict:
texts_len = len(texts)
foo_embedding_sample = 0.123456
combined = {
"embeddings": [[foo_embedding_sample for _ in range(768)] for _ in range(texts_len)],
"usage": {"prompt_tokens": texts_len, "total_tokens": texts_len},
"model": model,
"inference_mode": "remote",
}
return combined
def mock_nomic(
monkeypatch: MonkeyPatch,
methods: list[Literal["text_embedding"]],
) -> Callable[[], None]:
"""
mock nomic module
:param monkeypatch: pytest monkeypatch fixture
:return: unpatch function
"""
def unpatch() -> None:
monkeypatch.undo()
if "text_embedding" in methods:
monkeypatch.setattr(embed, "text", create_embedding)
return unpatch
MOCK = os.getenv("MOCK_SWITCH", "false").lower() == "true"
@pytest.fixture
def setup_nomic_mock(request, monkeypatch):
methods = request.param if hasattr(request, "param") else []
if MOCK:
unpatch = mock_nomic(monkeypatch, methods=methods)
yield
if MOCK:
unpatch()

View File

@ -0,0 +1,62 @@
import os
import pytest
from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.nomic.text_embedding.text_embedding import NomicTextEmbeddingModel
from tests.integration_tests.model_runtime.__mock.nomic_embeddings import setup_nomic_mock
@pytest.mark.parametrize("setup_nomic_mock", [["text_embedding"]], indirect=True)
def test_validate_credentials(setup_nomic_mock):
model = NomicTextEmbeddingModel()
with pytest.raises(CredentialsValidateFailedError):
model.validate_credentials(
model="nomic-embed-text-v1.5",
credentials={
"nomic_api_key": "invalid_key",
},
)
model.validate_credentials(
model="nomic-embed-text-v1.5",
credentials={
"nomic_api_key": os.environ.get("NOMIC_API_KEY"),
},
)
@pytest.mark.parametrize("setup_nomic_mock", [["text_embedding"]], indirect=True)
def test_invoke_model(setup_nomic_mock):
model = NomicTextEmbeddingModel()
result = model.invoke(
model="nomic-embed-text-v1.5",
credentials={
"nomic_api_key": os.environ.get("NOMIC_API_KEY"),
},
texts=["hello", "world"],
user="foo",
)
assert isinstance(result, TextEmbeddingResult)
assert result.model == "nomic-embed-text-v1.5"
assert len(result.embeddings) == 2
assert result.usage.total_tokens == 2
@pytest.mark.parametrize("setup_nomic_mock", [["text_embedding"]], indirect=True)
def test_get_num_tokens(setup_nomic_mock):
model = NomicTextEmbeddingModel()
num_tokens = model.get_num_tokens(
model="nomic-embed-text-v1.5",
credentials={
"nomic_api_key": os.environ.get("NOMIC_API_KEY"),
},
texts=["hello", "world"],
)
assert num_tokens == 2

View File

@ -0,0 +1,22 @@
import os
import pytest
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.nomic.nomic import NomicAtlasProvider
from core.model_runtime.model_providers.nomic.text_embedding.text_embedding import NomicTextEmbeddingModel
from tests.integration_tests.model_runtime.__mock.nomic_embeddings import setup_nomic_mock
@pytest.mark.parametrize("setup_nomic_mock", [["text_embedding"]], indirect=True)
def test_validate_provider_credentials(setup_nomic_mock):
provider = NomicAtlasProvider()
with pytest.raises(CredentialsValidateFailedError):
provider.validate_provider_credentials(credentials={})
provider.validate_provider_credentials(
credentials={
"nomic_api_key": os.environ.get("NOMIC_API_KEY"),
},
)

View File

@ -7,4 +7,5 @@ pytest api/tests/integration_tests/model_runtime/anthropic \
api/tests/integration_tests/model_runtime/google api/tests/integration_tests/model_runtime/xinference \
api/tests/integration_tests/model_runtime/huggingface_hub/test_llm.py \
api/tests/integration_tests/model_runtime/upstage \
api/tests/integration_tests/model_runtime/fireworks
api/tests/integration_tests/model_runtime/fireworks \
api/tests/integration_tests/model_runtime/nomic