mirror of
https://github.com/langgenius/dify.git
synced 2024-11-16 11:42:29 +08:00
Add support for AWS Bedrock Cohere embedding (#3444)
This commit is contained in:
parent
5e02a83b53
commit
200010be19
|
@ -1 +1,3 @@
|
|||
- amazon.titan-embed-text-v1
|
||||
- cohere.embed-english-v3
|
||||
- cohere.embed-multilingual-v3
|
||||
|
|
|
@ -0,0 +1,8 @@
|
|||
model: cohere.embed-english-v3
|
||||
model_type: text-embedding
|
||||
model_properties:
|
||||
context_size: 512
|
||||
pricing:
|
||||
input: '0.1'
|
||||
unit: '0.000001'
|
||||
currency: USD
|
|
@ -0,0 +1,8 @@
|
|||
model: cohere.embed-multilingual-v3
|
||||
model_type: text-embedding
|
||||
model_properties:
|
||||
context_size: 512
|
||||
pricing:
|
||||
input: '0.1'
|
||||
unit: '0.000001'
|
||||
currency: USD
|
|
@ -1,4 +1,5 @@
|
|||
import json
|
||||
import logging
|
||||
import time
|
||||
from typing import Optional
|
||||
|
||||
|
@ -24,6 +25,7 @@ from core.model_runtime.errors.invoke import (
|
|||
)
|
||||
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class BedrockTextEmbeddingModel(TextEmbeddingModel):
|
||||
|
||||
|
@ -55,7 +57,8 @@ class BedrockTextEmbeddingModel(TextEmbeddingModel):
|
|||
token_usage = 0
|
||||
|
||||
model_prefix = model.split('.')[0]
|
||||
if model_prefix == "amazon":
|
||||
|
||||
if model_prefix == "amazon" :
|
||||
for text in texts:
|
||||
body = {
|
||||
"inputText": text,
|
||||
|
@ -63,6 +66,7 @@ class BedrockTextEmbeddingModel(TextEmbeddingModel):
|
|||
response_body = self._invoke_bedrock_embedding(model, bedrock_runtime, body)
|
||||
embeddings.extend([response_body.get('embedding')])
|
||||
token_usage += response_body.get('inputTextTokenCount')
|
||||
logger.warning(f'Total Tokens: {token_usage}')
|
||||
result = TextEmbeddingResult(
|
||||
model=model,
|
||||
embeddings=embeddings,
|
||||
|
@ -72,11 +76,32 @@ class BedrockTextEmbeddingModel(TextEmbeddingModel):
|
|||
tokens=token_usage
|
||||
)
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Got unknown model prefix {model_prefix} when handling block response")
|
||||
|
||||
return result
|
||||
|
||||
if model_prefix == "cohere" :
|
||||
input_type = 'search_document' if len(texts) > 1 else 'search_query'
|
||||
for text in texts:
|
||||
body = {
|
||||
"texts": [text],
|
||||
"input_type": input_type,
|
||||
}
|
||||
response_body = self._invoke_bedrock_embedding(model, bedrock_runtime, body)
|
||||
embeddings.extend(response_body.get('embeddings'))
|
||||
token_usage += len(text)
|
||||
result = TextEmbeddingResult(
|
||||
model=model,
|
||||
embeddings=embeddings,
|
||||
usage=self._calc_response_usage(
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
tokens=token_usage
|
||||
)
|
||||
)
|
||||
return result
|
||||
|
||||
#others
|
||||
raise ValueError(f"Got unknown model prefix {model_prefix} when handling block response")
|
||||
|
||||
|
||||
def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
|
||||
"""
|
||||
|
|
Loading…
Reference in New Issue
Block a user