From 200010be19fc434e3010718255a5cb0323c8243a Mon Sep 17 00:00:00 2001 From: kerlion <40377268+kerlion@users.noreply.github.com> Date: Tue, 16 Apr 2024 13:22:38 +0800 Subject: [PATCH] Add suuport for AWS Bedrock Cohere embedding (#3444) --- .../bedrock/text_embedding/_position.yaml | 2 + .../cohere.embed-english-v3.yaml | 8 +++ .../cohere.embed-multilingual-v3.yaml | 8 +++ .../bedrock/text_embedding/text_embedding.py | 55 ++++++++++++++----- 4 files changed, 58 insertions(+), 15 deletions(-) create mode 100644 api/core/model_runtime/model_providers/bedrock/text_embedding/cohere.embed-english-v3.yaml create mode 100644 api/core/model_runtime/model_providers/bedrock/text_embedding/cohere.embed-multilingual-v3.yaml diff --git a/api/core/model_runtime/model_providers/bedrock/text_embedding/_position.yaml b/api/core/model_runtime/model_providers/bedrock/text_embedding/_position.yaml index e657e2a270..5419ff530b 100644 --- a/api/core/model_runtime/model_providers/bedrock/text_embedding/_position.yaml +++ b/api/core/model_runtime/model_providers/bedrock/text_embedding/_position.yaml @@ -1 +1,3 @@ - amazon.titan-embed-text-v1 +- cohere.embed-english-v3 +- cohere.embed-multilingual-v3 diff --git a/api/core/model_runtime/model_providers/bedrock/text_embedding/cohere.embed-english-v3.yaml b/api/core/model_runtime/model_providers/bedrock/text_embedding/cohere.embed-english-v3.yaml new file mode 100644 index 0000000000..d49aa2a99c --- /dev/null +++ b/api/core/model_runtime/model_providers/bedrock/text_embedding/cohere.embed-english-v3.yaml @@ -0,0 +1,8 @@ +model: cohere.embed-english-v3 +model_type: text-embedding +model_properties: + context_size: 512 +pricing: + input: '0.1' + unit: '0.000001' + currency: USD diff --git a/api/core/model_runtime/model_providers/bedrock/text_embedding/cohere.embed-multilingual-v3.yaml b/api/core/model_runtime/model_providers/bedrock/text_embedding/cohere.embed-multilingual-v3.yaml new file mode 100644 index 0000000000..63bab59d2c --- /dev/null +++ b/api/core/model_runtime/model_providers/bedrock/text_embedding/cohere.embed-multilingual-v3.yaml @@ -0,0 +1,8 @@ +model: cohere.embed-multilingual-v3 +model_type: text-embedding +model_properties: + context_size: 512 +pricing: + input: '0.1' + unit: '0.000001' + currency: USD diff --git a/api/core/model_runtime/model_providers/bedrock/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/bedrock/text_embedding/text_embedding.py index 0cd28e3655..69436cd737 100644 --- a/api/core/model_runtime/model_providers/bedrock/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/bedrock/text_embedding/text_embedding.py @@ -1,4 +1,5 @@ import json +import logging import time from typing import Optional @@ -24,6 +25,7 @@ from core.model_runtime.errors.invoke import ( ) from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel +logger = logging.getLogger(__name__) class BedrockTextEmbeddingModel(TextEmbeddingModel): @@ -53,17 +55,19 @@ class BedrockTextEmbeddingModel(TextEmbeddingModel): embeddings = [] token_usage = 0 - + model_prefix = model.split('.')[0] - if model_prefix == "amazon": - for text in texts: - body = { - "inputText": text, - } - response_body = self._invoke_bedrock_embedding(model, bedrock_runtime, body) - embeddings.extend([response_body.get('embedding')]) - token_usage += response_body.get('inputTextTokenCount') - result = TextEmbeddingResult( + + if model_prefix == "amazon" : + for text in texts: + body = { + "inputText": text, + } + response_body = self._invoke_bedrock_embedding(model, bedrock_runtime, body) + embeddings.extend([response_body.get('embedding')]) + token_usage += response_body.get('inputTextTokenCount') + logger.warning(f'Total Tokens: {token_usage}') + result = TextEmbeddingResult( model=model, embeddings=embeddings, usage=self._calc_response_usage( @@ -71,11 +75,32 @@ class BedrockTextEmbeddingModel(TextEmbeddingModel): credentials=credentials, tokens=token_usage ) - ) - else: - raise ValueError(f"Got unknown model prefix {model_prefix} when handling block response") - - return result + ) + return result + + if model_prefix == "cohere" : + input_type = 'search_document' if len(texts) > 1 else 'search_query' + for text in texts: + body = { + "texts": [text], + "input_type": input_type, + } + response_body = self._invoke_bedrock_embedding(model, bedrock_runtime, body) + embeddings.extend(response_body.get('embeddings')) + token_usage += len(text) + result = TextEmbeddingResult( + model=model, + embeddings=embeddings, + usage=self._calc_response_usage( + model=model, + credentials=credentials, + tokens=token_usage + ) + ) + return result + + #others + raise ValueError(f"Got unknown model prefix {model_prefix} when handling block response") def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int: