Mirror of https://github.com/langgenius/dify.git (synced 2024-11-16 03:32:23 +08:00)

feat: add cohere llm and embedding (#2115)

This commit is contained in:
parent 8438d820ad
commit a18dde9b0d
@@ -1,5 +1,6 @@
import logging
import os
import re
import time
from abc import abstractmethod
from typing import Generator, List, Optional, Union
@@ -212,6 +213,10 @@ class LargeLanguageModel(AIModel):
        """
        raise NotImplementedError

    def enforce_stop_tokens(self, text: str, stop: List[str]) -> str:
        """Cut off the text as soon as any stop words occur."""
        return re.split("|".join(stop), text, maxsplit=1)[0]

    def _llm_result_to_stream(self, result: LLMResult) -> Generator:
        """
        Transform llm result to stream
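A quick illustration of the enforce_stop_tokens helper added above (a minimal standalone sketch, not part of the diff): re.split with maxsplit=1 keeps only the text before the first stop sequence it finds.

import re

text = "Hello world. How are you?"
stop = ["How", "Bye"]
print(re.split("|".join(stop), text, maxsplit=1)[0])  # prints "Hello world. "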
@@ -14,9 +14,12 @@ help:
  url:
    en_US: https://dashboard.cohere.com/api-keys
supported_model_types:
  - llm
  - text-embedding
  - rerank
configurate_methods:
  - predefined-model
  - customizable-model
provider_credential_schema:
  credential_form_schemas:
    - variable: api_key
@@ -26,6 +29,44 @@ provider_credential_schema:
      type: secret-input
      required: true
      placeholder:
        zh_Hans: 请填写 API Key
        en_US: Please fill in API Key
        zh_Hans: 在此输入您的 API Key
        en_US: Enter your API Key
      show_on: [ ]
model_credential_schema:
  model:
    label:
      en_US: Model Name
      zh_Hans: 模型名称
    placeholder:
      en_US: Enter your model name
      zh_Hans: 输入模型名称
  credential_form_schemas:
    - variable: mode
      show_on:
        - variable: __model_type
          value: llm
      label:
        en_US: Completion mode
      type: select
      required: false
      default: chat
      placeholder:
        zh_Hans: 选择对话类型
        en_US: Select completion mode
      options:
        - value: completion
          label:
            en_US: Completion
            zh_Hans: 补全
        - value: chat
          label:
            en_US: Chat
            zh_Hans: 对话
    - variable: api_key
      label:
        en_US: API Key
      type: secret-input
      required: true
      placeholder:
        zh_Hans: 在此输入您的 API Key
        en_US: Enter your API Key
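For orientation, a hedged sketch (not from this commit) of the credentials dict these form schemas produce, which the Cohere model classes later in the diff read via credentials.get('api_key') and credentials.get('mode'):

credentials = {
    'api_key': 'xxx',   # secret-input field defined above (value is hypothetical)
    'mode': 'chat',     # select field shown for customizable llm models: 'chat' or 'completion'
}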
@@ -0,0 +1,8 @@
- command-chat
- command-light-chat
- command-nightly-chat
- command-light-nightly-chat
- command
- command-light
- command-nightly
- command-light-nightly
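The "*-chat" entries above are not separate Cohere models; as a hedged note based on the llm.py change later in this diff, the suffix only selects the chat endpoint and is stripped before the API call:

model = 'command-light-chat'
real_model = model.removesuffix('-chat')  # 'command-light' is what the Cohere API receives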
@@ -0,0 +1,62 @@
model: command-chat
label:
  zh_Hans: command-chat
  en_US: command-chat
model_type: llm
features:
  - agent-thought
model_properties:
  mode: chat
  context_size: 4096
parameter_rules:
  - name: temperature
    use_template: temperature
    max: 5.0
  - name: p
    use_template: top_p
    default: 0.75
    min: 0.01
    max: 0.99
  - name: k
    label:
      zh_Hans: 取样数量
      en_US: Top k
    type: int
    help:
      zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
      en_US: Only sample from the top K options for each subsequent token.
    required: false
    default: 0
    min: 0
    max: 500
  - name: max_tokens
    use_template: max_tokens
    default: 256
    max: 4096
  - name: preamble_override
    label:
      zh_Hans: 前导文本
      en_US: Preamble
    type: string
    help:
      zh_Hans: 当指定时,将使用提供的前导文本替换默认的 Cohere 前导文本。
      en_US: When specified, the default Cohere preamble will be replaced with the provided one.
    required: false
  - name: prompt_truncation
    label:
      zh_Hans: 提示截断
      en_US: Prompt Truncation
    type: string
    help:
      zh_Hans: 指定如何构造 Prompt。当 prompt_truncation 设置为 "AUTO" 时,将会丢弃一些来自聊天记录的元素,以尝试构造一个符合模型上下文长度限制的 Prompt。
      en_US: Dictates how the prompt will be constructed. With prompt_truncation set to "AUTO", some elements from chat histories will be dropped in an attempt to construct a prompt that fits within the model's context length limit.
    required: true
    default: 'AUTO'
    options:
      - 'AUTO'
      - 'OFF'
pricing:
  input: '1.0'
  output: '2.0'
  unit: '0.000001'
  currency: USD
@@ -0,0 +1,62 @@
model: command-light-chat
label:
  zh_Hans: command-light-chat
  en_US: command-light-chat
model_type: llm
features:
  - agent-thought
model_properties:
  mode: chat
  context_size: 4096
parameter_rules:
  - name: temperature
    use_template: temperature
    max: 5.0
  - name: p
    use_template: top_p
    default: 0.75
    min: 0.01
    max: 0.99
  - name: k
    label:
      zh_Hans: 取样数量
      en_US: Top k
    type: int
    help:
      zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
      en_US: Only sample from the top K options for each subsequent token.
    required: false
    default: 0
    min: 0
    max: 500
  - name: max_tokens
    use_template: max_tokens
    default: 256
    max: 4096
  - name: preamble_override
    label:
      zh_Hans: 前导文本
      en_US: Preamble
    type: string
    help:
      zh_Hans: 当指定时,将使用提供的前导文本替换默认的 Cohere 前导文本。
      en_US: When specified, the default Cohere preamble will be replaced with the provided one.
    required: false
  - name: prompt_truncation
    label:
      zh_Hans: 提示截断
      en_US: Prompt Truncation
    type: string
    help:
      zh_Hans: 指定如何构造 Prompt。当 prompt_truncation 设置为 "AUTO" 时,将会丢弃一些来自聊天记录的元素,以尝试构造一个符合模型上下文长度限制的 Prompt。
      en_US: Dictates how the prompt will be constructed. With prompt_truncation set to "AUTO", some elements from chat histories will be dropped in an attempt to construct a prompt that fits within the model's context length limit.
    required: true
    default: 'AUTO'
    options:
      - 'AUTO'
      - 'OFF'
pricing:
  input: '0.3'
  output: '0.6'
  unit: '0.000001'
  currency: USD
@@ -0,0 +1,62 @@
model: command-light-nightly-chat
label:
  zh_Hans: command-light-nightly-chat
  en_US: command-light-nightly-chat
model_type: llm
features:
  - agent-thought
model_properties:
  mode: chat
  context_size: 4096
parameter_rules:
  - name: temperature
    use_template: temperature
    max: 5.0
  - name: p
    use_template: top_p
    default: 0.75
    min: 0.01
    max: 0.99
  - name: k
    label:
      zh_Hans: 取样数量
      en_US: Top k
    type: int
    help:
      zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
      en_US: Only sample from the top K options for each subsequent token.
    required: false
    default: 0
    min: 0
    max: 500
  - name: max_tokens
    use_template: max_tokens
    default: 256
    max: 4096
  - name: preamble_override
    label:
      zh_Hans: 前导文本
      en_US: Preamble
    type: string
    help:
      zh_Hans: 当指定时,将使用提供的前导文本替换默认的 Cohere 前导文本。
      en_US: When specified, the default Cohere preamble will be replaced with the provided one.
    required: false
  - name: prompt_truncation
    label:
      zh_Hans: 提示截断
      en_US: Prompt Truncation
    type: string
    help:
      zh_Hans: 指定如何构造 Prompt。当 prompt_truncation 设置为 "AUTO" 时,将会丢弃一些来自聊天记录的元素,以尝试构造一个符合模型上下文长度限制的 Prompt。
      en_US: Dictates how the prompt will be constructed. With prompt_truncation set to "AUTO", some elements from chat histories will be dropped in an attempt to construct a prompt that fits within the model's context length limit.
    required: true
    default: 'AUTO'
    options:
      - 'AUTO'
      - 'OFF'
pricing:
  input: '0.3'
  output: '0.6'
  unit: '0.000001'
  currency: USD
@@ -0,0 +1,44 @@
model: command-light-nightly
label:
  zh_Hans: command-light-nightly
  en_US: command-light-nightly
model_type: llm
features:
  - agent-thought
model_properties:
  mode: completion
  context_size: 4096
parameter_rules:
  - name: temperature
    use_template: temperature
    max: 5.0
  - name: p
    use_template: top_p
    default: 0.75
    min: 0.01
    max: 0.99
  - name: k
    label:
      zh_Hans: 取样数量
      en_US: Top k
    type: int
    help:
      zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
      en_US: Only sample from the top K options for each subsequent token.
    required: false
    default: 0
    min: 0
    max: 500
  - name: presence_penalty
    use_template: presence_penalty
  - name: frequency_penalty
    use_template: frequency_penalty
  - name: max_tokens
    use_template: max_tokens
    default: 256
    max: 4096
pricing:
  input: '0.3'
  output: '0.6'
  unit: '0.000001'
  currency: USD
@@ -0,0 +1,44 @@
model: command-light
label:
  zh_Hans: command-light
  en_US: command-light
model_type: llm
features:
  - agent-thought
model_properties:
  mode: completion
  context_size: 4096
parameter_rules:
  - name: temperature
    use_template: temperature
    max: 5.0
  - name: p
    use_template: top_p
    default: 0.75
    min: 0.01
    max: 0.99
  - name: k
    label:
      zh_Hans: 取样数量
      en_US: Top k
    type: int
    help:
      zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
      en_US: Only sample from the top K options for each subsequent token.
    required: false
    default: 0
    min: 0
    max: 500
  - name: presence_penalty
    use_template: presence_penalty
  - name: frequency_penalty
    use_template: frequency_penalty
  - name: max_tokens
    use_template: max_tokens
    default: 256
    max: 4096
pricing:
  input: '0.3'
  output: '0.6'
  unit: '0.000001'
  currency: USD
@@ -0,0 +1,62 @@
model: command-nightly-chat
label:
  zh_Hans: command-nightly-chat
  en_US: command-nightly-chat
model_type: llm
features:
  - agent-thought
model_properties:
  mode: chat
  context_size: 4096
parameter_rules:
  - name: temperature
    use_template: temperature
    max: 5.0
  - name: p
    use_template: top_p
    default: 0.75
    min: 0.01
    max: 0.99
  - name: k
    label:
      zh_Hans: 取样数量
      en_US: Top k
    type: int
    help:
      zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
      en_US: Only sample from the top K options for each subsequent token.
    required: false
    default: 0
    min: 0
    max: 500
  - name: max_tokens
    use_template: max_tokens
    default: 256
    max: 4096
  - name: preamble_override
    label:
      zh_Hans: 前导文本
      en_US: Preamble
    type: string
    help:
      zh_Hans: 当指定时,将使用提供的前导文本替换默认的 Cohere 前导文本。
      en_US: When specified, the default Cohere preamble will be replaced with the provided one.
    required: false
  - name: prompt_truncation
    label:
      zh_Hans: 提示截断
      en_US: Prompt Truncation
    type: string
    help:
      zh_Hans: 指定如何构造 Prompt。当 prompt_truncation 设置为 "AUTO" 时,将会丢弃一些来自聊天记录的元素,以尝试构造一个符合模型上下文长度限制的 Prompt。
      en_US: Dictates how the prompt will be constructed. With prompt_truncation set to "AUTO", some elements from chat histories will be dropped in an attempt to construct a prompt that fits within the model's context length limit.
    required: true
    default: 'AUTO'
    options:
      - 'AUTO'
      - 'OFF'
pricing:
  input: '1.0'
  output: '2.0'
  unit: '0.000001'
  currency: USD
@@ -0,0 +1,44 @@
model: command-nightly
label:
  zh_Hans: command-nightly
  en_US: command-nightly
model_type: llm
features:
  - agent-thought
model_properties:
  mode: completion
  context_size: 4096
parameter_rules:
  - name: temperature
    use_template: temperature
    max: 5.0
  - name: p
    use_template: top_p
    default: 0.75
    min: 0.01
    max: 0.99
  - name: k
    label:
      zh_Hans: 取样数量
      en_US: Top k
    type: int
    help:
      zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
      en_US: Only sample from the top K options for each subsequent token.
    required: false
    default: 0
    min: 0
    max: 500
  - name: presence_penalty
    use_template: presence_penalty
  - name: frequency_penalty
    use_template: frequency_penalty
  - name: max_tokens
    use_template: max_tokens
    default: 256
    max: 4096
pricing:
  input: '1.0'
  output: '2.0'
  unit: '0.000001'
  currency: USD
@@ -0,0 +1,44 @@
model: command
label:
  zh_Hans: command
  en_US: command
model_type: llm
features:
  - agent-thought
model_properties:
  mode: completion
  context_size: 4096
parameter_rules:
  - name: temperature
    use_template: temperature
    max: 5.0
  - name: p
    use_template: top_p
    default: 0.75
    min: 0.01
    max: 0.99
  - name: k
    label:
      zh_Hans: 取样数量
      en_US: Top k
    type: int
    help:
      zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
      en_US: Only sample from the top K options for each subsequent token.
    required: false
    default: 0
    min: 0
    max: 500
  - name: presence_penalty
    use_template: presence_penalty
  - name: frequency_penalty
    use_template: frequency_penalty
  - name: max_tokens
    use_template: max_tokens
    default: 256
    max: 4096
pricing:
  input: '1.0'
  output: '2.0'
  unit: '0.000001'
  currency: USD
565 api/core/model_runtime/model_providers/cohere/llm/llm.py (new file)
@ -0,0 +1,565 @@
|
|||
import logging
|
||||
from typing import Generator, List, Optional, Union, cast, Tuple
|
||||
|
||||
import cohere
|
||||
from cohere.responses import Chat, Generations
|
||||
from cohere.responses.chat import StreamingChat, StreamTextGeneration, StreamEnd
|
||||
from cohere.responses.generation import StreamingText, StreamingGenerations
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta
|
||||
from core.model_runtime.entities.message_entities import (AssistantPromptMessage, PromptMessage,
|
||||
PromptMessageContentType, SystemPromptMessage,
|
||||
TextPromptMessageContent, UserPromptMessage,
|
||||
PromptMessageTool)
|
||||
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, I18nObject, ModelType
|
||||
from core.model_runtime.errors.invoke import InvokeConnectionError, InvokeServerUnavailableError, InvokeError, \
|
||||
InvokeRateLimitError, InvokeAuthorizationError, InvokeBadRequestError
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CohereLargeLanguageModel(LargeLanguageModel):
|
||||
"""
|
||||
Model class for Cohere large language model.
|
||||
"""
|
||||
|
||||
def _invoke(self, model: str, credentials: dict,
|
||||
prompt_messages: list[PromptMessage], model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None,
|
||||
stream: bool = True, user: Optional[str] = None) \
|
||||
-> Union[LLMResult, Generator]:
|
||||
"""
|
||||
Invoke large language model
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param prompt_messages: prompt messages
|
||||
:param model_parameters: model parameters
|
||||
:param tools: tools for tool calling
|
||||
:param stop: stop words
|
||||
:param stream: is stream response
|
||||
:param user: unique user id
|
||||
:return: full response or stream response chunk generator result
|
||||
"""
|
||||
# get model mode
|
||||
model_mode = self.get_model_mode(model, credentials)
|
||||
|
||||
if model_mode == LLMMode.CHAT:
|
||||
return self._chat_generate(
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
prompt_messages=prompt_messages,
|
||||
model_parameters=model_parameters,
|
||||
stop=stop,
|
||||
stream=stream,
|
||||
user=user
|
||||
)
|
||||
else:
|
||||
return self._generate(
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
prompt_messages=prompt_messages,
|
||||
model_parameters=model_parameters,
|
||||
stop=stop,
|
||||
stream=stream,
|
||||
user=user
|
||||
)
|
||||
|
||||
def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
|
||||
tools: Optional[list[PromptMessageTool]] = None) -> int:
|
||||
"""
|
||||
Get number of tokens for given prompt messages
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param prompt_messages: prompt messages
|
||||
:param tools: tools for tool calling
|
||||
:return:
|
||||
"""
|
||||
# get model mode
|
||||
model_mode = self.get_model_mode(model)
|
||||
|
||||
try:
|
||||
if model_mode == LLMMode.CHAT:
|
||||
return self._num_tokens_from_messages(model, credentials, prompt_messages)
|
||||
else:
|
||||
return self._num_tokens_from_string(model, credentials, prompt_messages[0].content)
|
||||
except Exception as e:
|
||||
raise self._transform_invoke_error(e)
|
||||
|
||||
def validate_credentials(self, model: str, credentials: dict) -> None:
|
||||
"""
|
||||
Validate model credentials
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:return:
|
||||
"""
|
||||
try:
|
||||
# get model mode
|
||||
model_mode = self.get_model_mode(model)
|
||||
|
||||
if model_mode == LLMMode.CHAT:
|
||||
self._chat_generate(
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
prompt_messages=[UserPromptMessage(content='ping')],
|
||||
model_parameters={
|
||||
'max_tokens': 20,
|
||||
'temperature': 0,
|
||||
},
|
||||
stream=False
|
||||
)
|
||||
else:
|
||||
self._generate(
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
prompt_messages=[UserPromptMessage(content='ping')],
|
||||
model_parameters={
|
||||
'max_tokens': 20,
|
||||
'temperature': 0,
|
||||
},
|
||||
stream=False
|
||||
)
|
||||
except Exception as ex:
|
||||
raise CredentialsValidateFailedError(str(ex))
|
||||
|
||||
def _generate(self, model: str, credentials: dict,
|
||||
prompt_messages: list[PromptMessage], model_parameters: dict, stop: Optional[List[str]] = None,
|
||||
stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]:
|
||||
"""
|
||||
Invoke llm model
|
||||
|
||||
:param model: model name
|
||||
:param credentials: credentials
|
||||
:param prompt_messages: prompt messages
|
||||
:param model_parameters: model parameters
|
||||
:param stop: stop words
|
||||
:param stream: is stream response
|
||||
:param user: unique user id
|
||||
:return: full response or stream response chunk generator result
|
||||
"""
|
||||
# initialize client
|
||||
client = cohere.Client(credentials.get('api_key'))
|
||||
|
||||
if stop:
|
||||
model_parameters['end_sequences'] = stop
|
||||
|
||||
response = client.generate(
|
||||
prompt=prompt_messages[0].content,
|
||||
model=model,
|
||||
stream=stream,
|
||||
**model_parameters,
|
||||
)
|
||||
|
||||
if stream:
|
||||
return self._handle_generate_stream_response(model, credentials, response, prompt_messages)
|
||||
|
||||
return self._handle_generate_response(model, credentials, response, prompt_messages)
|
||||
|
||||
def _handle_generate_response(self, model: str, credentials: dict, response: Generations,
|
||||
prompt_messages: list[PromptMessage]) \
|
||||
-> LLMResult:
|
||||
"""
|
||||
Handle llm response
|
||||
|
||||
:param model: model name
|
||||
:param credentials: credentials
|
||||
:param response: response
|
||||
:param prompt_messages: prompt messages
|
||||
:return: llm response
|
||||
"""
|
||||
assistant_text = response.generations[0].text
|
||||
|
||||
# transform assistant message to prompt message
|
||||
assistant_prompt_message = AssistantPromptMessage(
|
||||
content=assistant_text
|
||||
)
|
||||
|
||||
# calculate num tokens
|
||||
prompt_tokens = response.meta['billed_units']['input_tokens']
|
||||
completion_tokens = response.meta['billed_units']['output_tokens']
|
||||
|
||||
# transform usage
|
||||
usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
|
||||
|
||||
# transform response
|
||||
response = LLMResult(
|
||||
model=model,
|
||||
prompt_messages=prompt_messages,
|
||||
message=assistant_prompt_message,
|
||||
usage=usage
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
def _handle_generate_stream_response(self, model: str, credentials: dict, response: StreamingGenerations,
|
||||
prompt_messages: list[PromptMessage]) -> Generator:
|
||||
"""
|
||||
Handle llm stream response
|
||||
|
||||
:param model: model name
|
||||
:param response: response
|
||||
:param prompt_messages: prompt messages
|
||||
:return: llm response chunk generator
|
||||
"""
|
||||
index = 1
|
||||
full_assistant_content = ''
|
||||
for chunk in response:
|
||||
if isinstance(chunk, StreamingText):
|
||||
chunk = cast(StreamingText, chunk)
|
||||
text = chunk.text
|
||||
|
||||
if text is None:
|
||||
continue
|
||||
|
||||
# transform assistant message to prompt message
|
||||
assistant_prompt_message = AssistantPromptMessage(
|
||||
content=text
|
||||
)
|
||||
|
||||
full_assistant_content += text
|
||||
|
||||
yield LLMResultChunk(
|
||||
model=model,
|
||||
prompt_messages=prompt_messages,
|
||||
delta=LLMResultChunkDelta(
|
||||
index=index,
|
||||
message=assistant_prompt_message,
|
||||
)
|
||||
)
|
||||
|
||||
index += 1
|
||||
elif chunk is None:
|
||||
# calculate num tokens
|
||||
prompt_tokens = response.meta['billed_units']['input_tokens']
|
||||
completion_tokens = response.meta['billed_units']['output_tokens']
|
||||
|
||||
# transform usage
|
||||
usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
|
||||
|
||||
yield LLMResultChunk(
|
||||
model=model,
|
||||
prompt_messages=prompt_messages,
|
||||
delta=LLMResultChunkDelta(
|
||||
index=index,
|
||||
message=AssistantPromptMessage(content=''),
|
||||
finish_reason=response.finish_reason,
|
||||
usage=usage
|
||||
)
|
||||
)
|
||||
break
|
||||
|
||||
def _chat_generate(self, model: str, credentials: dict,
|
||||
prompt_messages: list[PromptMessage], model_parameters: dict, stop: Optional[List[str]] = None,
|
||||
stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]:
|
||||
"""
|
||||
Invoke llm chat model
|
||||
|
||||
:param model: model name
|
||||
:param credentials: credentials
|
||||
:param prompt_messages: prompt messages
|
||||
:param model_parameters: model parameters
|
||||
:param stop: stop words
|
||||
:param stream: is stream response
|
||||
:param user: unique user id
|
||||
:return: full response or stream response chunk generator result
|
||||
"""
|
||||
# initialize client
|
||||
client = cohere.Client(credentials.get('api_key'))
|
||||
|
||||
if user:
|
||||
model_parameters['user_name'] = user
|
||||
|
||||
message, chat_histories = self._convert_prompt_messages_to_message_and_chat_histories(prompt_messages)
|
||||
|
||||
# chat model
|
||||
real_model = model
|
||||
if self.get_model_schema(model, credentials).fetch_from == FetchFrom.PREDEFINED_MODEL:
|
||||
real_model = model.removesuffix('-chat')
|
||||
|
||||
response = client.chat(
|
||||
message=message,
|
||||
chat_history=chat_histories,
|
||||
model=real_model,
|
||||
stream=stream,
|
||||
return_preamble=True,
|
||||
**model_parameters,
|
||||
)
|
||||
|
||||
if stream:
|
||||
return self._handle_chat_generate_stream_response(model, credentials, response, prompt_messages, stop)
|
||||
|
||||
return self._handle_chat_generate_response(model, credentials, response, prompt_messages, stop)
|
||||
|
||||
def _handle_chat_generate_response(self, model: str, credentials: dict, response: Chat,
|
||||
prompt_messages: list[PromptMessage], stop: Optional[List[str]] = None) \
|
||||
-> LLMResult:
|
||||
"""
|
||||
Handle llm chat response
|
||||
|
||||
:param model: model name
|
||||
:param credentials: credentials
|
||||
:param response: response
|
||||
:param prompt_messages: prompt messages
|
||||
:param stop: stop words
|
||||
:return: llm response
|
||||
"""
|
||||
assistant_text = response.text
|
||||
|
||||
# transform assistant message to prompt message
|
||||
assistant_prompt_message = AssistantPromptMessage(
|
||||
content=assistant_text
|
||||
)
|
||||
|
||||
# calculate num tokens
|
||||
prompt_tokens = self._num_tokens_from_messages(model, credentials, prompt_messages)
|
||||
completion_tokens = self._num_tokens_from_messages(model, credentials, [assistant_prompt_message])
|
||||
|
||||
# transform usage
|
||||
usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
|
||||
|
||||
if stop:
|
||||
# enforce stop tokens
|
||||
assistant_text = self.enforce_stop_tokens(assistant_text, stop)
|
||||
assistant_prompt_message = AssistantPromptMessage(
|
||||
content=assistant_text
|
||||
)
|
||||
|
||||
# transform response
|
||||
response = LLMResult(
|
||||
model=model,
|
||||
prompt_messages=prompt_messages,
|
||||
message=assistant_prompt_message,
|
||||
usage=usage,
|
||||
system_fingerprint=response.preamble
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
def _handle_chat_generate_stream_response(self, model: str, credentials: dict, response: StreamingChat,
|
||||
prompt_messages: list[PromptMessage],
|
||||
stop: Optional[List[str]] = None) -> Generator:
|
||||
"""
|
||||
Handle llm chat stream response
|
||||
|
||||
:param model: model name
|
||||
:param response: response
|
||||
:param prompt_messages: prompt messages
|
||||
:param stop: stop words
|
||||
:return: llm response chunk generator
|
||||
"""
|
||||
|
||||
def final_response(full_text: str, index: int, finish_reason: Optional[str] = None,
|
||||
preamble: Optional[str] = None) -> LLMResultChunk:
|
||||
# calculate num tokens
|
||||
prompt_tokens = self._num_tokens_from_messages(model, credentials, prompt_messages)
|
||||
|
||||
full_assistant_prompt_message = AssistantPromptMessage(
|
||||
content=full_text
|
||||
)
|
||||
completion_tokens = self._num_tokens_from_messages(model, credentials, [full_assistant_prompt_message])
|
||||
|
||||
# transform usage
|
||||
usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
|
||||
|
||||
return LLMResultChunk(
|
||||
model=model,
|
||||
prompt_messages=prompt_messages,
|
||||
system_fingerprint=preamble,
|
||||
delta=LLMResultChunkDelta(
|
||||
index=index,
|
||||
message=AssistantPromptMessage(content=''),
|
||||
finish_reason=finish_reason,
|
||||
usage=usage
|
||||
)
|
||||
)
|
||||
|
||||
index = 1
|
||||
full_assistant_content = ''
|
||||
for chunk in response:
|
||||
if isinstance(chunk, StreamTextGeneration):
|
||||
chunk = cast(StreamTextGeneration, chunk)
|
||||
text = chunk.text
|
||||
|
||||
if text is None:
|
||||
continue
|
||||
|
||||
# transform assistant message to prompt message
|
||||
assistant_prompt_message = AssistantPromptMessage(
|
||||
content=text
|
||||
)
|
||||
|
||||
# stop
|
||||
# notice: This logic can only cover few stop scenarios
|
||||
if stop and text in stop:
|
||||
yield final_response(full_assistant_content, index, 'stop')
|
||||
break
|
||||
|
||||
full_assistant_content += text
|
||||
|
||||
yield LLMResultChunk(
|
||||
model=model,
|
||||
prompt_messages=prompt_messages,
|
||||
delta=LLMResultChunkDelta(
|
||||
index=index,
|
||||
message=assistant_prompt_message,
|
||||
)
|
||||
)
|
||||
|
||||
index += 1
|
||||
elif isinstance(chunk, StreamEnd):
|
||||
chunk = cast(StreamEnd, chunk)
|
||||
yield final_response(full_assistant_content, index, chunk.finish_reason, response.preamble)
|
||||
index += 1
|
||||
|
||||
def _convert_prompt_messages_to_message_and_chat_histories(self, prompt_messages: list[PromptMessage]) \
|
||||
-> Tuple[str, list[dict]]:
|
||||
"""
|
||||
Convert prompt messages to message and chat histories
|
||||
:param prompt_messages: prompt messages
|
||||
:return:
|
||||
"""
|
||||
chat_histories = []
|
||||
for prompt_message in prompt_messages:
|
||||
chat_histories.append(self._convert_prompt_message_to_dict(prompt_message))
|
||||
|
||||
# get latest message from chat histories and pop it
|
||||
if len(chat_histories) > 0:
|
||||
latest_message = chat_histories.pop()
|
||||
message = latest_message['message']
|
||||
else:
|
||||
raise ValueError('Prompt messages is empty')
|
||||
|
||||
return message, chat_histories
|
||||
|
||||
def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict:
|
||||
"""
|
||||
Convert PromptMessage to dict for Cohere model
|
||||
"""
|
||||
if isinstance(message, UserPromptMessage):
|
||||
message = cast(UserPromptMessage, message)
|
||||
if isinstance(message.content, str):
|
||||
message_dict = {"role": "USER", "message": message.content}
|
||||
else:
|
||||
sub_message_text = ''
|
||||
for message_content in message.content:
|
||||
if message_content.type == PromptMessageContentType.TEXT:
|
||||
message_content = cast(TextPromptMessageContent, message_content)
|
||||
sub_message_text += message_content.data
|
||||
|
||||
message_dict = {"role": "USER", "message": sub_message_text}
|
||||
elif isinstance(message, AssistantPromptMessage):
|
||||
message = cast(AssistantPromptMessage, message)
|
||||
message_dict = {"role": "CHATBOT", "message": message.content}
|
||||
elif isinstance(message, SystemPromptMessage):
|
||||
message = cast(SystemPromptMessage, message)
|
||||
message_dict = {"role": "USER", "message": message.content}
|
||||
else:
|
||||
raise ValueError(f"Got unknown type {message}")
|
||||
|
||||
if message.name is not None:
|
||||
message_dict["user_name"] = message.name
|
||||
|
||||
return message_dict
|
||||
|
||||
def _num_tokens_from_string(self, model: str, credentials: dict, text: str) -> int:
|
||||
"""
|
||||
Calculate num tokens for text completion model.
|
||||
|
||||
:param model: model name
|
||||
:param credentials: credentials
|
||||
:param text: prompt text
|
||||
:return: number of tokens
|
||||
"""
|
||||
# initialize client
|
||||
client = cohere.Client(credentials.get('api_key'))
|
||||
|
||||
response = client.tokenize(
|
||||
text=text,
|
||||
model=model
|
||||
)
|
||||
|
||||
return response.length
|
||||
|
||||
def _num_tokens_from_messages(self, model: str, credentials: dict, messages: List[PromptMessage]) -> int:
|
||||
"""Calculate num tokens Cohere model."""
|
||||
messages = [self._convert_prompt_message_to_dict(m) for m in messages]
|
||||
message_strs = [f"{message['role']}: {message['message']}" for message in messages]
|
||||
message_str = "\n".join(message_strs)
|
||||
|
||||
real_model = model
|
||||
if self.get_model_schema(model, credentials).fetch_from == FetchFrom.PREDEFINED_MODEL:
|
||||
real_model = model.removesuffix('-chat')
|
||||
|
||||
return self._num_tokens_from_string(real_model, credentials, message_str)
|
||||
|
||||
def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity:
|
||||
"""
|
||||
Cohere supports fine-tuning of their models. This method returns the schema of the base model
|
||||
but renamed to the fine-tuned model name.
|
||||
|
||||
:param model: model name
|
||||
:param credentials: credentials
|
||||
|
||||
:return: model schema
|
||||
"""
|
||||
# get model schema
|
||||
models = self.predefined_models()
|
||||
model_map = {model.model: model for model in models}
|
||||
|
||||
mode = credentials.get('mode')
|
||||
|
||||
if mode == 'chat':
|
||||
base_model_schema = model_map['command-light-chat']
|
||||
else:
|
||||
base_model_schema = model_map['command-light']
|
||||
|
||||
base_model_schema = cast(AIModelEntity, base_model_schema)
|
||||
|
||||
base_model_schema_features = base_model_schema.features or []
|
||||
base_model_schema_model_properties = base_model_schema.model_properties or {}
|
||||
base_model_schema_parameters_rules = base_model_schema.parameter_rules or []
|
||||
|
||||
entity = AIModelEntity(
|
||||
model=model,
|
||||
label=I18nObject(
|
||||
zh_Hans=model,
|
||||
en_US=model
|
||||
),
|
||||
model_type=ModelType.LLM,
|
||||
features=[feature for feature in base_model_schema_features],
|
||||
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
|
||||
model_properties={
|
||||
key: property for key, property in base_model_schema_model_properties.items()
|
||||
},
|
||||
parameter_rules=[rule for rule in base_model_schema_parameters_rules],
|
||||
pricing=base_model_schema.pricing
|
||||
)
|
||||
|
||||
return entity
|
||||
|
||||
@property
|
||||
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
|
||||
"""
|
||||
Map model invoke error to unified error
|
||||
The key is the error type thrown to the caller
|
||||
The value is the error type thrown by the model,
|
||||
which needs to be converted into a unified error type for the caller.
|
||||
|
||||
:return: Invoke error mapping
|
||||
"""
|
||||
return {
|
||||
InvokeConnectionError: [
|
||||
cohere.CohereConnectionError
|
||||
],
|
||||
InvokeServerUnavailableError: [],
|
||||
InvokeRateLimitError: [],
|
||||
InvokeAuthorizationError: [],
|
||||
InvokeBadRequestError: [
|
||||
cohere.CohereAPIError,
|
||||
cohere.CohereError,
|
||||
]
|
||||
}
|
|
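A minimal sketch (illustrative values, not code from the diff) of what _convert_prompt_messages_to_message_and_chat_histories above produces: each prompt message becomes a {"role", "message"} dict, the last one is popped off as the current message, and the rest are passed to client.chat as chat_history.

converted = [
    {'role': 'USER', 'message': 'You are a helpful AI assistant.'},  # SystemPromptMessage is mapped to the USER role
    {'role': 'USER', 'message': 'Hello World!'},
]
latest = converted.pop()
message = latest['message']   # 'Hello World!'
chat_histories = converted    # remaining turns sent as chat_history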
@@ -0,0 +1,7 @@
- embed-multilingual-v3.0
- embed-multilingual-light-v3.0
- embed-english-v3.0
- embed-english-light-v3.0
- embed-multilingual-v2.0
- embed-english-v2.0
- embed-english-light-v2.0

@@ -0,0 +1,9 @@
model: embed-english-light-v2.0
model_type: text-embedding
model_properties:
  context_size: 1024
  max_chunks: 48
pricing:
  input: '0.1'
  unit: '0.000001'
  currency: USD

@@ -0,0 +1,9 @@
model: embed-english-light-v3.0
model_type: text-embedding
model_properties:
  context_size: 384
  max_chunks: 48
pricing:
  input: '0.1'
  unit: '0.000001'
  currency: USD

@@ -0,0 +1,9 @@
model: embed-english-v2.0
model_type: text-embedding
model_properties:
  context_size: 4096
  max_chunks: 48
pricing:
  input: '0.1'
  unit: '0.000001'
  currency: USD

@@ -0,0 +1,9 @@
model: embed-english-v3.0
model_type: text-embedding
model_properties:
  context_size: 1024
  max_chunks: 48
pricing:
  input: '0.1'
  unit: '0.000001'
  currency: USD

@@ -0,0 +1,9 @@
model: embed-multilingual-light-v3.0
model_type: text-embedding
model_properties:
  context_size: 384
  max_chunks: 48
pricing:
  input: '0.1'
  unit: '0.000001'
  currency: USD

@@ -0,0 +1,9 @@
model: embed-multilingual-v2.0
model_type: text-embedding
model_properties:
  context_size: 768
  max_chunks: 48
pricing:
  input: '0.1'
  unit: '0.000001'
  currency: USD

@@ -0,0 +1,9 @@
model: embed-multilingual-v3.0
model_type: text-embedding
model_properties:
  context_size: 1024
  max_chunks: 48
pricing:
  input: '0.1'
  unit: '0.000001'
  currency: USD
@ -0,0 +1,234 @@
|
|||
import time
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import cohere
|
||||
import numpy as np
|
||||
from cohere.responses import Tokens
|
||||
|
||||
from core.model_runtime.entities.model_entities import PriceType
|
||||
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
|
||||
from core.model_runtime.errors.invoke import InvokeConnectionError, InvokeServerUnavailableError, InvokeRateLimitError, \
|
||||
InvokeAuthorizationError, InvokeBadRequestError, InvokeError
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
|
||||
|
||||
|
||||
class CohereTextEmbeddingModel(TextEmbeddingModel):
|
||||
"""
|
||||
Model class for Cohere text embedding model.
|
||||
"""
|
||||
|
||||
def _invoke(self, model: str, credentials: dict,
|
||||
texts: list[str], user: Optional[str] = None) \
|
||||
-> TextEmbeddingResult:
|
||||
"""
|
||||
Invoke text embedding model
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param texts: texts to embed
|
||||
:param user: unique user id
|
||||
:return: embeddings result
|
||||
"""
|
||||
# get model properties
|
||||
context_size = self._get_context_size(model, credentials)
|
||||
max_chunks = self._get_max_chunks(model, credentials)
|
||||
|
||||
embeddings: list[list[float]] = [[] for _ in range(len(texts))]
|
||||
tokens = []
|
||||
indices = []
|
||||
used_tokens = 0
|
||||
|
||||
for i, text in enumerate(texts):
|
||||
tokenize_response = self._tokenize(
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
text=text
|
||||
)
|
||||
|
||||
for j in range(0, tokenize_response.length, context_size):
|
||||
tokens += [tokenize_response.token_strings[j: j + context_size]]
|
||||
indices += [i]
|
||||
|
||||
batched_embeddings = []
|
||||
_iter = range(0, len(tokens), max_chunks)
|
||||
|
||||
for i in _iter:
|
||||
# call embedding model
|
||||
embeddings_batch, embedding_used_tokens = self._embedding_invoke(
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
texts=["".join(token) for token in tokens[i: i + max_chunks]]
|
||||
)
|
||||
|
||||
used_tokens += embedding_used_tokens
|
||||
batched_embeddings += embeddings_batch
|
||||
|
||||
results: list[list[list[float]]] = [[] for _ in range(len(texts))]
|
||||
num_tokens_in_batch: list[list[int]] = [[] for _ in range(len(texts))]
|
||||
for i in range(len(indices)):
|
||||
results[indices[i]].append(batched_embeddings[i])
|
||||
num_tokens_in_batch[indices[i]].append(len(tokens[i]))
|
||||
|
||||
for i in range(len(texts)):
|
||||
_result = results[i]
|
||||
if len(_result) == 0:
|
||||
embeddings_batch, embedding_used_tokens = self._embedding_invoke(
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
texts=[""]
|
||||
)
|
||||
|
||||
used_tokens += embedding_used_tokens
|
||||
average = embeddings_batch[0]
|
||||
else:
|
||||
average = np.average(_result, axis=0, weights=num_tokens_in_batch[i])
|
||||
embeddings[i] = (average / np.linalg.norm(average)).tolist()
|
||||
|
||||
# calc usage
|
||||
usage = self._calc_response_usage(
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
tokens=used_tokens
|
||||
)
|
||||
|
||||
return TextEmbeddingResult(
|
||||
embeddings=embeddings,
|
||||
usage=usage,
|
||||
model=model
|
||||
)
|
||||
|
||||
def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
|
||||
"""
|
||||
Get number of tokens for given prompt messages
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param texts: texts to embed
|
||||
:return:
|
||||
"""
|
||||
if len(texts) == 0:
|
||||
return 0
|
||||
|
||||
full_text = ' '.join(texts)
|
||||
|
||||
try:
|
||||
response = self._tokenize(
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
text=full_text
|
||||
)
|
||||
except Exception as e:
|
||||
raise self._transform_invoke_error(e)
|
||||
|
||||
return response.length
|
||||
|
||||
def _tokenize(self, model: str, credentials: dict, text: str) -> Tokens:
|
||||
"""
|
||||
Tokenize text
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param text: text to tokenize
|
||||
:return:
|
||||
"""
|
||||
# initialize client
|
||||
client = cohere.Client(credentials.get('api_key'))
|
||||
|
||||
response = client.tokenize(
|
||||
text=text,
|
||||
model=model
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
def validate_credentials(self, model: str, credentials: dict) -> None:
|
||||
"""
|
||||
Validate model credentials
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:return:
|
||||
"""
|
||||
try:
|
||||
# call embedding model
|
||||
self._embedding_invoke(
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
texts=['ping']
|
||||
)
|
||||
except Exception as ex:
|
||||
raise CredentialsValidateFailedError(str(ex))
|
||||
|
||||
def _embedding_invoke(self, model: str, credentials: dict, texts: list[str]) -> Tuple[list[list[float]], int]:
|
||||
"""
|
||||
Invoke embedding model
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param texts: texts to embed
|
||||
:return: embeddings and used tokens
|
||||
"""
|
||||
# initialize client
|
||||
client = cohere.Client(credentials.get('api_key'))
|
||||
|
||||
# call embedding model
|
||||
response = client.embed(
|
||||
texts=texts,
|
||||
model=model,
|
||||
input_type='search_document' if len(texts) > 1 else 'search_query'
|
||||
)
|
||||
|
||||
return response.embeddings, response.meta['billed_units']['input_tokens']
|
||||
|
||||
def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage:
|
||||
"""
|
||||
Calculate response usage
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param tokens: input tokens
|
||||
:return: usage
|
||||
"""
|
||||
# get input price info
|
||||
input_price_info = self.get_price(
|
||||
model=model,
|
||||
credentials=credentials,
|
||||
price_type=PriceType.INPUT,
|
||||
tokens=tokens
|
||||
)
|
||||
|
||||
# transform usage
|
||||
usage = EmbeddingUsage(
|
||||
tokens=tokens,
|
||||
total_tokens=tokens,
|
||||
unit_price=input_price_info.unit_price,
|
||||
price_unit=input_price_info.unit,
|
||||
total_price=input_price_info.total_amount,
|
||||
currency=input_price_info.currency,
|
||||
latency=time.perf_counter() - self.started_at
|
||||
)
|
||||
|
||||
return usage
|
||||
|
||||
@property
|
||||
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
|
||||
"""
|
||||
Map model invoke error to unified error
|
||||
The key is the error type thrown to the caller
|
||||
The value is the error type thrown by the model,
|
||||
which needs to be converted into a unified error type for the caller.
|
||||
|
||||
:return: Invoke error mapping
|
||||
"""
|
||||
return {
|
||||
InvokeConnectionError: [
|
||||
cohere.CohereConnectionError
|
||||
],
|
||||
InvokeServerUnavailableError: [],
|
||||
InvokeRateLimitError: [],
|
||||
InvokeAuthorizationError: [],
|
||||
InvokeBadRequestError: [
|
||||
cohere.CohereAPIError,
|
||||
cohere.CohereError,
|
||||
]
|
||||
}
|
|
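To clarify the recombination step in the embedding code above (a standalone sketch with made-up numbers, not part of the commit): texts longer than context_size are embedded chunk by chunk, then the chunk embeddings are averaged with token-count weights and L2-normalized.

import numpy as np

chunk_embeddings = [np.array([1.0, 0.0]), np.array([0.0, 1.0])]  # one embedding per chunk of a long text
chunk_token_counts = [300, 100]                                  # tokens in each chunk

average = np.average(chunk_embeddings, axis=0, weights=chunk_token_counts)
embedding = (average / np.linalg.norm(average)).tolist()         # final embedding for the text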
@@ -24,6 +24,9 @@ class EnhanceRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter):
            **kwargs: Any,
    ):
        def _token_encoder(text: str) -> int:
            if not text:
                return 0

            if embedding_model_instance:
                embedding_model_type_instance = embedding_model_instance.model_type_instance
                embedding_model_type_instance = cast(TextEmbeddingModel, embedding_model_type_instance)
@@ -54,7 +54,7 @@ zhipuai==1.0.7
werkzeug==2.3.8
pymilvus==2.3.0
qdrant-client==1.6.4
cohere~=4.32
cohere~=4.44
pyyaml~=6.0.1
numpy~=1.25.2
unstructured[docx,pptx,msg,md,ppt]~=0.10.27
272 api/tests/integration_tests/model_runtime/cohere/test_llm.py (new file)
@ -0,0 +1,272 @@
|
|||
import os
|
||||
from typing import Generator
|
||||
|
||||
import pytest
|
||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
|
||||
from core.model_runtime.entities.message_entities import (AssistantPromptMessage, SystemPromptMessage,
|
||||
UserPromptMessage)
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.cohere.llm.llm import CohereLargeLanguageModel
|
||||
|
||||
|
||||
def test_validate_credentials_for_chat_model():
|
||||
model = CohereLargeLanguageModel()
|
||||
|
||||
with pytest.raises(CredentialsValidateFailedError):
|
||||
model.validate_credentials(
|
||||
model='command-light-chat',
|
||||
credentials={
|
||||
'api_key': 'invalid_key'
|
||||
}
|
||||
)
|
||||
|
||||
model.validate_credentials(
|
||||
model='command-light-chat',
|
||||
credentials={
|
||||
'api_key': os.environ.get('COHERE_API_KEY')
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def test_validate_credentials_for_completion_model():
|
||||
model = CohereLargeLanguageModel()
|
||||
|
||||
with pytest.raises(CredentialsValidateFailedError):
|
||||
model.validate_credentials(
|
||||
model='command-light',
|
||||
credentials={
|
||||
'api_key': 'invalid_key'
|
||||
}
|
||||
)
|
||||
|
||||
model.validate_credentials(
|
||||
model='command-light',
|
||||
credentials={
|
||||
'api_key': os.environ.get('COHERE_API_KEY')
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def test_invoke_completion_model():
|
||||
model = CohereLargeLanguageModel()
|
||||
|
||||
credentials = {
|
||||
'api_key': os.environ.get('COHERE_API_KEY')
|
||||
}
|
||||
|
||||
result = model.invoke(
|
||||
model='command-light',
|
||||
credentials=credentials,
|
||||
prompt_messages=[
|
||||
UserPromptMessage(
|
||||
content='Hello World!'
|
||||
)
|
||||
],
|
||||
model_parameters={
|
||||
'temperature': 0.0,
|
||||
'max_tokens': 1
|
||||
},
|
||||
stream=False,
|
||||
user="abc-123"
|
||||
)
|
||||
|
||||
assert isinstance(result, LLMResult)
|
||||
assert len(result.message.content) > 0
|
||||
assert model._num_tokens_from_string('command-light', credentials, result.message.content) == 1
|
||||
|
||||
|
||||
def test_invoke_stream_completion_model():
|
||||
model = CohereLargeLanguageModel()
|
||||
|
||||
result = model.invoke(
|
||||
model='command-light',
|
||||
credentials={
|
||||
'api_key': os.environ.get('COHERE_API_KEY')
|
||||
},
|
||||
prompt_messages=[
|
||||
UserPromptMessage(
|
||||
content='Hello World!'
|
||||
)
|
||||
],
|
||||
model_parameters={
|
||||
'temperature': 0.0,
|
||||
'max_tokens': 100
|
||||
},
|
||||
stream=True,
|
||||
user="abc-123"
|
||||
)
|
||||
|
||||
assert isinstance(result, Generator)
|
||||
|
||||
for chunk in result:
|
||||
assert isinstance(chunk, LLMResultChunk)
|
||||
assert isinstance(chunk.delta, LLMResultChunkDelta)
|
||||
assert isinstance(chunk.delta.message, AssistantPromptMessage)
|
||||
assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True
|
||||
|
||||
|
||||
def test_invoke_chat_model():
|
||||
model = CohereLargeLanguageModel()
|
||||
|
||||
result = model.invoke(
|
||||
model='command-light-chat',
|
||||
credentials={
|
||||
'api_key': os.environ.get('COHERE_API_KEY')
|
||||
},
|
||||
prompt_messages=[
|
||||
SystemPromptMessage(
|
||||
content='You are a helpful AI assistant.',
|
||||
),
|
||||
UserPromptMessage(
|
||||
content='Hello World!'
|
||||
)
|
||||
],
|
||||
model_parameters={
|
||||
'temperature': 0.0,
|
||||
'p': 0.99,
|
||||
'presence_penalty': 0.0,
|
||||
'frequency_penalty': 0.0,
|
||||
'max_tokens': 10
|
||||
},
|
||||
stop=['How'],
|
||||
stream=False,
|
||||
user="abc-123"
|
||||
)
|
||||
|
||||
assert isinstance(result, LLMResult)
|
||||
assert len(result.message.content) > 0
|
||||
|
||||
for chunk in model._llm_result_to_stream(result):
|
||||
assert isinstance(chunk, LLMResultChunk)
|
||||
assert isinstance(chunk.delta, LLMResultChunkDelta)
|
||||
assert isinstance(chunk.delta.message, AssistantPromptMessage)
|
||||
assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True
|
||||
|
||||
|
||||
def test_invoke_stream_chat_model():
|
||||
model = CohereLargeLanguageModel()
|
||||
|
||||
result = model.invoke(
|
||||
model='command-light-chat',
|
||||
credentials={
|
||||
'api_key': os.environ.get('COHERE_API_KEY')
|
||||
},
|
||||
prompt_messages=[
|
||||
SystemPromptMessage(
|
||||
content='You are a helpful AI assistant.',
|
||||
),
|
||||
UserPromptMessage(
|
||||
content='Hello World!'
|
||||
)
|
||||
],
|
||||
model_parameters={
|
||||
'temperature': 0.0,
|
||||
'max_tokens': 100
|
||||
},
|
||||
stream=True,
|
||||
user="abc-123"
|
||||
)
|
||||
|
||||
assert isinstance(result, Generator)
|
||||
|
||||
for chunk in result:
|
||||
assert isinstance(chunk, LLMResultChunk)
|
||||
assert isinstance(chunk.delta, LLMResultChunkDelta)
|
||||
assert isinstance(chunk.delta.message, AssistantPromptMessage)
|
||||
assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True
|
||||
if chunk.delta.finish_reason is not None:
|
||||
assert chunk.delta.usage is not None
|
||||
assert chunk.delta.usage.completion_tokens > 0
|
||||
|
||||
|
||||
def test_get_num_tokens():
|
||||
model = CohereLargeLanguageModel()
|
||||
|
||||
num_tokens = model.get_num_tokens(
|
||||
model='command-light',
|
||||
credentials={
|
||||
'api_key': os.environ.get('COHERE_API_KEY')
|
||||
},
|
||||
prompt_messages=[
|
||||
UserPromptMessage(
|
||||
content='Hello World!'
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
assert num_tokens == 3
|
||||
|
||||
num_tokens = model.get_num_tokens(
|
||||
model='command-light-chat',
|
||||
credentials={
|
||||
'api_key': os.environ.get('COHERE_API_KEY')
|
||||
},
|
||||
prompt_messages=[
|
||||
SystemPromptMessage(
|
||||
content='You are a helpful AI assistant.',
|
||||
),
|
||||
UserPromptMessage(
|
||||
content='Hello World!'
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
assert num_tokens == 15
|
||||
|
||||
|
||||
def test_fine_tuned_model():
|
||||
model = CohereLargeLanguageModel()
|
||||
|
||||
# test invoke
|
||||
result = model.invoke(
|
||||
model='85ec47be-6139-4f75-a4be-0f0ec1ef115c-ft',
|
||||
credentials={
|
||||
'api_key': os.environ.get('COHERE_API_KEY'),
|
||||
'mode': 'completion'
|
||||
},
|
||||
prompt_messages=[
|
||||
SystemPromptMessage(
|
||||
content='You are a helpful AI assistant.',
|
||||
),
|
||||
UserPromptMessage(
|
||||
content='Hello World!'
|
||||
)
|
||||
],
|
||||
model_parameters={
|
||||
'temperature': 0.0,
|
||||
'max_tokens': 100
|
||||
},
|
||||
stream=False,
|
||||
user="abc-123"
|
||||
)
|
||||
|
||||
assert isinstance(result, LLMResult)
|
||||
|
||||
|
||||
def test_fine_tuned_chat_model():
|
||||
model = CohereLargeLanguageModel()
|
||||
|
||||
# test invoke
|
||||
result = model.invoke(
|
||||
model='94f2d55a-4c79-4c00-bde4-23962e74b170-ft',
|
||||
credentials={
|
||||
'api_key': os.environ.get('COHERE_API_KEY'),
|
||||
'mode': 'chat'
|
||||
},
|
||||
prompt_messages=[
|
||||
SystemPromptMessage(
|
||||
content='You are a helpful AI assistant.',
|
||||
),
|
||||
UserPromptMessage(
|
||||
content='Hello World!'
|
||||
)
|
||||
],
|
||||
model_parameters={
|
||||
'temperature': 0.0,
|
||||
'max_tokens': 100
|
||||
},
|
||||
stream=False,
|
||||
user="abc-123"
|
||||
)
|
||||
|
||||
assert isinstance(result, LLMResult)
|
|
@ -0,0 +1,64 @@
|
|||
import os
|
||||
|
||||
import pytest
|
||||
from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.cohere.text_embedding.text_embedding import CohereTextEmbeddingModel
|
||||
|
||||
|
||||
def test_validate_credentials():
|
||||
model = CohereTextEmbeddingModel()
|
||||
|
||||
with pytest.raises(CredentialsValidateFailedError):
|
||||
model.validate_credentials(
|
||||
model='embed-multilingual-v3.0',
|
||||
credentials={
|
||||
'api_key': 'invalid_key'
|
||||
}
|
||||
)
|
||||
|
||||
model.validate_credentials(
|
||||
model='embed-multilingual-v3.0',
|
||||
credentials={
|
||||
'api_key': os.environ.get('COHERE_API_KEY')
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def test_invoke_model():
|
||||
model = CohereTextEmbeddingModel()
|
||||
|
||||
result = model.invoke(
|
||||
model='embed-multilingual-v3.0',
|
||||
credentials={
|
||||
'api_key': os.environ.get('COHERE_API_KEY')
|
||||
},
|
||||
texts=[
|
||||
"hello",
|
||||
"world",
|
||||
" ".join(["long_text"] * 100),
|
||||
" ".join(["another_long_text"] * 100)
|
||||
],
|
||||
user="abc-123"
|
||||
)
|
||||
|
||||
assert isinstance(result, TextEmbeddingResult)
|
||||
assert len(result.embeddings) == 4
|
||||
assert result.usage.total_tokens == 811
|
||||
|
||||
|
||||
def test_get_num_tokens():
|
||||
model = CohereTextEmbeddingModel()
|
||||
|
||||
num_tokens = model.get_num_tokens(
|
||||
model='embed-multilingual-v3.0',
|
||||
credentials={
|
||||
'api_key': os.environ.get('COHERE_API_KEY')
|
||||
},
|
||||
texts=[
|
||||
"hello",
|
||||
"world"
|
||||
]
|
||||
)
|
||||
|
||||
assert num_tokens == 3
|