feat: Add Cohere Command R / R+ model support (#3333)

takatost 2024-04-11 01:22:55 +08:00 committed by GitHub
parent bf63a43bda
commit 826c422ac4
16 changed files with 404 additions and 128 deletions

View File

@ -1,3 +1,5 @@
- command-r
- command-r-plus
- command-chat
- command-light-chat
- command-nightly-chat

View File

@ -31,7 +31,7 @@ parameter_rules:
max: 500
- name: max_tokens
use_template: max_tokens
default: 256
default: 1024
max: 4096
- name: preamble_override
label:

View File

@ -31,7 +31,7 @@ parameter_rules:
max: 500
- name: max_tokens
use_template: max_tokens
default: 256
default: 1024
max: 4096
- name: preamble_override
label:

View File

@ -31,7 +31,7 @@ parameter_rules:
max: 500
- name: max_tokens
use_template: max_tokens
default: 256
default: 1024
max: 4096
- name: preamble_override
label:

View File

@ -35,7 +35,7 @@ parameter_rules:
use_template: frequency_penalty
- name: max_tokens
use_template: max_tokens
default: 256
default: 1024
max: 4096
pricing:
input: '0.3'

View File

@ -35,7 +35,7 @@ parameter_rules:
use_template: frequency_penalty
- name: max_tokens
use_template: max_tokens
default: 256
default: 1024
max: 4096
pricing:
input: '0.3'

View File

@ -31,7 +31,7 @@ parameter_rules:
max: 500
- name: max_tokens
use_template: max_tokens
default: 256
default: 1024
max: 4096
- name: preamble_override
label:

View File

@ -35,7 +35,7 @@ parameter_rules:
use_template: frequency_penalty
- name: max_tokens
use_template: max_tokens
default: 256
default: 1024
max: 4096
pricing:
input: '1.0'

View File

@ -0,0 +1,45 @@
model: command-r-plus
label:
en_US: command-r-plus
model_type: llm
features:
- multi-tool-call
- agent-thought
- stream-tool-call
model_properties:
mode: chat
context_size: 128000
parameter_rules:
- name: temperature
use_template: temperature
max: 5.0
- name: p
use_template: top_p
default: 0.75
min: 0.01
max: 0.99
- name: k
label:
zh_Hans: 取样数量
en_US: Top k
type: int
help:
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
en_US: Only sample from the top K options for each subsequent token.
required: false
default: 0
min: 0
max: 500
- name: presence_penalty
use_template: presence_penalty
- name: frequency_penalty
use_template: frequency_penalty
- name: max_tokens
use_template: max_tokens
default: 1024
max: 4096
pricing:
input: '3'
output: '15'
unit: '0.000001'
currency: USD

View File

@ -0,0 +1,45 @@
model: command-r
label:
en_US: command-r
model_type: llm
features:
- multi-tool-call
- agent-thought
- stream-tool-call
model_properties:
mode: chat
context_size: 128000
parameter_rules:
- name: temperature
use_template: temperature
max: 5.0
- name: p
use_template: top_p
default: 0.75
min: 0.01
max: 0.99
- name: k
label:
zh_Hans: 取样数量
en_US: Top k
type: int
help:
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
en_US: Only sample from the top K options for each subsequent token.
required: false
default: 0
min: 0
max: 500
- name: presence_penalty
use_template: presence_penalty
- name: frequency_penalty
use_template: frequency_penalty
- name: max_tokens
use_template: max_tokens
default: 1024
max: 4096
pricing:
input: '0.5'
output: '1.5'
unit: '0.000001'
currency: USD
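
The two new model cards above (command-r-plus and command-r) share the same 128k context window, tool-calling features, and parameter rules; only the pricing block differs. As a rough illustration of how those pricing fields combine (assuming the runtime charges tokens × price × unit, so unit '0.000001' makes the listed prices per-million-token rates), a minimal sketch:

from decimal import Decimal

# Illustrative only, not Dify's billing code: assumes cost = tokens * price * unit.
def estimate_cost(prompt_tokens: int, completion_tokens: int,
                  input_price: str, output_price: str, unit: str) -> Decimal:
    u = Decimal(unit)
    return (Decimal(prompt_tokens) * Decimal(input_price) * u
            + Decimal(completion_tokens) * Decimal(output_price) * u)

# command-r pricing from the card above: input '0.5', output '1.5', unit '0.000001' (USD)
print(estimate_cost(100_000, 2_000, '0.5', '1.5', '0.000001'))  # ~0.053 USD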

View File

@ -35,7 +35,7 @@ parameter_rules:
use_template: frequency_penalty
- name: max_tokens
use_template: max_tokens
default: 256
default: 1024
max: 4096
pricing:
input: '1.0'

View File

@ -1,20 +1,38 @@
import json
import logging
from collections.abc import Generator
from collections.abc import Generator, Iterator
from typing import Optional, Union, cast
import cohere
from cohere.responses import Chat, Generations
from cohere.responses.chat import StreamEnd, StreamingChat, StreamTextGeneration
from cohere.responses.generation import StreamingGenerations, StreamingText
from cohere import (
ChatMessage,
ChatStreamRequestToolResultsItem,
GenerateStreamedResponse,
GenerateStreamedResponse_StreamEnd,
GenerateStreamedResponse_StreamError,
GenerateStreamedResponse_TextGeneration,
Generation,
NonStreamedChatResponse,
StreamedChatResponse,
StreamedChatResponse_StreamEnd,
StreamedChatResponse_TextGeneration,
StreamedChatResponse_ToolCallsGeneration,
Tool,
ToolCall,
ToolParameterDefinitionsValue,
)
from cohere.core import RequestOptions
from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
PromptMessage,
PromptMessageContentType,
PromptMessageRole,
PromptMessageTool,
SystemPromptMessage,
TextPromptMessageContent,
ToolPromptMessage,
UserPromptMessage,
)
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, I18nObject, ModelType
@ -64,6 +82,7 @@ class CohereLargeLanguageModel(LargeLanguageModel):
credentials=credentials,
prompt_messages=prompt_messages,
model_parameters=model_parameters,
tools=tools,
stop=stop,
stream=stream,
user=user
@ -159,19 +178,26 @@ class CohereLargeLanguageModel(LargeLanguageModel):
if stop:
model_parameters['end_sequences'] = stop
response = client.generate(
prompt=prompt_messages[0].content,
model=model,
stream=stream,
**model_parameters,
)
if stream:
response = client.generate_stream(
prompt=prompt_messages[0].content,
model=model,
**model_parameters,
request_options=RequestOptions(max_retries=0)
)
return self._handle_generate_stream_response(model, credentials, response, prompt_messages)
else:
response = client.generate(
prompt=prompt_messages[0].content,
model=model,
**model_parameters,
request_options=RequestOptions(max_retries=0)
)
return self._handle_generate_response(model, credentials, response, prompt_messages)
return self._handle_generate_response(model, credentials, response, prompt_messages)
def _handle_generate_response(self, model: str, credentials: dict, response: Generations,
def _handle_generate_response(self, model: str, credentials: dict, response: Generation,
prompt_messages: list[PromptMessage]) \
-> LLMResult:
"""
@ -191,8 +217,8 @@ class CohereLargeLanguageModel(LargeLanguageModel):
)
# calculate num tokens
prompt_tokens = response.meta['billed_units']['input_tokens']
completion_tokens = response.meta['billed_units']['output_tokens']
prompt_tokens = int(response.meta.billed_units.input_tokens)
completion_tokens = int(response.meta.billed_units.output_tokens)
# transform usage
usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
@ -207,7 +233,7 @@ class CohereLargeLanguageModel(LargeLanguageModel):
return response
def _handle_generate_stream_response(self, model: str, credentials: dict, response: StreamingGenerations,
def _handle_generate_stream_response(self, model: str, credentials: dict, response: Iterator[GenerateStreamedResponse],
prompt_messages: list[PromptMessage]) -> Generator:
"""
Handle llm stream response
@ -220,8 +246,8 @@ class CohereLargeLanguageModel(LargeLanguageModel):
index = 1
full_assistant_content = ''
for chunk in response:
if isinstance(chunk, StreamingText):
chunk = cast(StreamingText, chunk)
if isinstance(chunk, GenerateStreamedResponse_TextGeneration):
chunk = cast(GenerateStreamedResponse_TextGeneration, chunk)
text = chunk.text
if text is None:
@ -244,10 +270,16 @@ class CohereLargeLanguageModel(LargeLanguageModel):
)
index += 1
elif chunk is None:
elif isinstance(chunk, GenerateStreamedResponse_StreamEnd):
chunk = cast(GenerateStreamedResponse_StreamEnd, chunk)
# calculate num tokens
prompt_tokens = response.meta['billed_units']['input_tokens']
completion_tokens = response.meta['billed_units']['output_tokens']
prompt_tokens = self._num_tokens_from_messages(model, credentials, prompt_messages)
completion_tokens = self._num_tokens_from_messages(
model,
credentials,
[AssistantPromptMessage(content=full_assistant_content)]
)
# transform usage
usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
@ -258,14 +290,18 @@ class CohereLargeLanguageModel(LargeLanguageModel):
delta=LLMResultChunkDelta(
index=index,
message=AssistantPromptMessage(content=''),
finish_reason=response.finish_reason,
finish_reason=chunk.finish_reason,
usage=usage
)
)
break
elif isinstance(chunk, GenerateStreamedResponse_StreamError):
chunk = cast(GenerateStreamedResponse_StreamError, chunk)
raise InvokeBadRequestError(chunk.err)
def _chat_generate(self, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict, stop: Optional[list[str]] = None,
prompt_messages: list[PromptMessage], model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]:
"""
Invoke llm chat model
@ -274,6 +310,7 @@ class CohereLargeLanguageModel(LargeLanguageModel):
:param credentials: credentials
:param prompt_messages: prompt messages
:param model_parameters: model parameters
:param tools: tools for tool calling
:param stop: stop words
:param stream: is stream response
:param user: unique user id
@ -282,31 +319,46 @@ class CohereLargeLanguageModel(LargeLanguageModel):
# initialize client
client = cohere.Client(credentials.get('api_key'))
if user:
model_parameters['user_name'] = user
if stop:
model_parameters['stop_sequences'] = stop
message, chat_histories = self._convert_prompt_messages_to_message_and_chat_histories(prompt_messages)
if tools:
model_parameters['tools'] = self._convert_tools(tools)
message, chat_histories, tool_results \
= self._convert_prompt_messages_to_message_and_chat_histories(prompt_messages)
if tool_results:
model_parameters['tool_results'] = tool_results
# chat model
real_model = model
if self.get_model_schema(model, credentials).fetch_from == FetchFrom.PREDEFINED_MODEL:
real_model = model.removesuffix('-chat')
response = client.chat(
message=message,
chat_history=chat_histories,
model=real_model,
stream=stream,
**model_parameters,
)
if stream:
return self._handle_chat_generate_stream_response(model, credentials, response, prompt_messages, stop)
response = client.chat_stream(
message=message,
chat_history=chat_histories,
model=real_model,
**model_parameters,
request_options=RequestOptions(max_retries=0)
)
return self._handle_chat_generate_response(model, credentials, response, prompt_messages, stop)
return self._handle_chat_generate_stream_response(model, credentials, response, prompt_messages)
else:
response = client.chat(
message=message,
chat_history=chat_histories,
model=real_model,
**model_parameters,
request_options=RequestOptions(max_retries=0)
)
def _handle_chat_generate_response(self, model: str, credentials: dict, response: Chat,
prompt_messages: list[PromptMessage], stop: Optional[list[str]] = None) \
return self._handle_chat_generate_response(model, credentials, response, prompt_messages)
def _handle_chat_generate_response(self, model: str, credentials: dict, response: NonStreamedChatResponse,
prompt_messages: list[PromptMessage]) \
-> LLMResult:
"""
Handle llm chat response
@ -315,14 +367,27 @@ class CohereLargeLanguageModel(LargeLanguageModel):
:param credentials: credentials
:param response: response
:param prompt_messages: prompt messages
:param stop: stop words
:return: llm response
"""
assistant_text = response.text
tool_calls = []
if response.tool_calls:
for cohere_tool_call in response.tool_calls:
tool_call = AssistantPromptMessage.ToolCall(
id=cohere_tool_call.name,
type='function',
function=AssistantPromptMessage.ToolCall.ToolCallFunction(
name=cohere_tool_call.name,
arguments=json.dumps(cohere_tool_call.parameters)
)
)
tool_calls.append(tool_call)
# transform assistant message to prompt message
assistant_prompt_message = AssistantPromptMessage(
content=assistant_text
content=assistant_text,
tool_calls=tool_calls
)
# calculate num tokens
@ -332,44 +397,38 @@ class CohereLargeLanguageModel(LargeLanguageModel):
# transform usage
usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
if stop:
# enforce stop tokens
assistant_text = self.enforce_stop_tokens(assistant_text, stop)
assistant_prompt_message = AssistantPromptMessage(
content=assistant_text
)
# transform response
response = LLMResult(
model=model,
prompt_messages=prompt_messages,
message=assistant_prompt_message,
usage=usage,
system_fingerprint=response.preamble
usage=usage
)
return response
def _handle_chat_generate_stream_response(self, model: str, credentials: dict, response: StreamingChat,
prompt_messages: list[PromptMessage],
stop: Optional[list[str]] = None) -> Generator:
def _handle_chat_generate_stream_response(self, model: str, credentials: dict,
response: Iterator[StreamedChatResponse],
prompt_messages: list[PromptMessage]) -> Generator:
"""
Handle llm chat stream response
:param model: model name
:param response: response
:param prompt_messages: prompt messages
:param stop: stop words
:return: llm response chunk generator
"""
def final_response(full_text: str, index: int, finish_reason: Optional[str] = None,
preamble: Optional[str] = None) -> LLMResultChunk:
def final_response(full_text: str,
tool_calls: list[AssistantPromptMessage.ToolCall],
index: int,
finish_reason: Optional[str] = None) -> LLMResultChunk:
# calculate num tokens
prompt_tokens = self._num_tokens_from_messages(model, credentials, prompt_messages)
full_assistant_prompt_message = AssistantPromptMessage(
content=full_text
content=full_text,
tool_calls=tool_calls
)
completion_tokens = self._num_tokens_from_messages(model, credentials, [full_assistant_prompt_message])
@ -379,10 +438,9 @@ class CohereLargeLanguageModel(LargeLanguageModel):
return LLMResultChunk(
model=model,
prompt_messages=prompt_messages,
system_fingerprint=preamble,
delta=LLMResultChunkDelta(
index=index,
message=AssistantPromptMessage(content=''),
message=AssistantPromptMessage(content='', tool_calls=tool_calls),
finish_reason=finish_reason,
usage=usage
)
@ -390,9 +448,10 @@ class CohereLargeLanguageModel(LargeLanguageModel):
index = 1
full_assistant_content = ''
tool_calls = []
for chunk in response:
if isinstance(chunk, StreamTextGeneration):
chunk = cast(StreamTextGeneration, chunk)
if isinstance(chunk, StreamedChatResponse_TextGeneration):
chunk = cast(StreamedChatResponse_TextGeneration, chunk)
text = chunk.text
if text is None:
@ -403,12 +462,6 @@ class CohereLargeLanguageModel(LargeLanguageModel):
content=text
)
# stop
# notice: This logic can only cover few stop scenarios
if stop and text in stop:
yield final_response(full_assistant_content, index, 'stop')
break
full_assistant_content += text
yield LLMResultChunk(
@ -421,39 +474,98 @@ class CohereLargeLanguageModel(LargeLanguageModel):
)
index += 1
elif isinstance(chunk, StreamEnd):
chunk = cast(StreamEnd, chunk)
yield final_response(full_assistant_content, index, chunk.finish_reason, response.preamble)
elif isinstance(chunk, StreamedChatResponse_ToolCallsGeneration):
chunk = cast(StreamedChatResponse_ToolCallsGeneration, chunk)
tool_calls = []
if chunk.tool_calls:
for cohere_tool_call in chunk.tool_calls:
tool_call = AssistantPromptMessage.ToolCall(
id=cohere_tool_call.name,
type='function',
function=AssistantPromptMessage.ToolCall.ToolCallFunction(
name=cohere_tool_call.name,
arguments=json.dumps(cohere_tool_call.parameters)
)
)
tool_calls.append(tool_call)
elif isinstance(chunk, StreamedChatResponse_StreamEnd):
chunk = cast(StreamedChatResponse_StreamEnd, chunk)
yield final_response(full_assistant_content, tool_calls, index, chunk.finish_reason)
index += 1
def _convert_prompt_messages_to_message_and_chat_histories(self, prompt_messages: list[PromptMessage]) \
-> tuple[str, list[dict]]:
-> tuple[str, list[ChatMessage], list[ChatStreamRequestToolResultsItem]]:
"""
Convert prompt messages to message and chat histories
:param prompt_messages: prompt messages
:return:
"""
chat_histories = []
latest_tool_call_n_outputs = []
for prompt_message in prompt_messages:
chat_histories.append(self._convert_prompt_message_to_dict(prompt_message))
if prompt_message.role == PromptMessageRole.ASSISTANT:
prompt_message = cast(AssistantPromptMessage, prompt_message)
if prompt_message.tool_calls:
for tool_call in prompt_message.tool_calls:
latest_tool_call_n_outputs.append(ChatStreamRequestToolResultsItem(
call=ToolCall(
name=tool_call.function.name,
parameters=json.loads(tool_call.function.arguments)
),
outputs=[]
))
else:
cohere_prompt_message = self._convert_prompt_message_to_dict(prompt_message)
if cohere_prompt_message:
chat_histories.append(cohere_prompt_message)
elif prompt_message.role == PromptMessageRole.TOOL:
prompt_message = cast(ToolPromptMessage, prompt_message)
if latest_tool_call_n_outputs:
i = 0
for tool_call_n_outputs in latest_tool_call_n_outputs:
if tool_call_n_outputs.call.name == prompt_message.tool_call_id:
latest_tool_call_n_outputs[i] = ChatStreamRequestToolResultsItem(
call=ToolCall(
name=tool_call_n_outputs.call.name,
parameters=tool_call_n_outputs.call.parameters
),
outputs=[{
"result": prompt_message.content
}]
)
break
i += 1
else:
cohere_prompt_message = self._convert_prompt_message_to_dict(prompt_message)
if cohere_prompt_message:
chat_histories.append(cohere_prompt_message)
if latest_tool_call_n_outputs:
new_latest_tool_call_n_outputs = []
for tool_call_n_outputs in latest_tool_call_n_outputs:
if tool_call_n_outputs.outputs:
new_latest_tool_call_n_outputs.append(tool_call_n_outputs)
latest_tool_call_n_outputs = new_latest_tool_call_n_outputs
# get latest message from chat histories and pop it
if len(chat_histories) > 0:
latest_message = chat_histories.pop()
message = latest_message['message']
message = latest_message.message
else:
raise ValueError('Prompt messages is empty')
return message, chat_histories
return message, chat_histories, latest_tool_call_n_outputs
def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict:
def _convert_prompt_message_to_dict(self, message: PromptMessage) -> Optional[ChatMessage]:
"""
Convert PromptMessage to dict for Cohere model
"""
if isinstance(message, UserPromptMessage):
message = cast(UserPromptMessage, message)
if isinstance(message.content, str):
message_dict = {"role": "USER", "message": message.content}
chat_message = ChatMessage(role="USER", message=message.content)
else:
sub_message_text = ''
for message_content in message.content:
@ -461,20 +573,57 @@ class CohereLargeLanguageModel(LargeLanguageModel):
message_content = cast(TextPromptMessageContent, message_content)
sub_message_text += message_content.data
message_dict = {"role": "USER", "message": sub_message_text}
chat_message = ChatMessage(role="USER", message=sub_message_text)
elif isinstance(message, AssistantPromptMessage):
message = cast(AssistantPromptMessage, message)
message_dict = {"role": "CHATBOT", "message": message.content}
if not message.content:
return None
chat_message = ChatMessage(role="CHATBOT", message=message.content)
elif isinstance(message, SystemPromptMessage):
message = cast(SystemPromptMessage, message)
message_dict = {"role": "USER", "message": message.content}
chat_message = ChatMessage(role="USER", message=message.content)
elif isinstance(message, ToolPromptMessage):
return None
else:
raise ValueError(f"Got unknown type {message}")
if message.name:
message_dict["user_name"] = message.name
return chat_message
return message_dict
def _convert_tools(self, tools: list[PromptMessageTool]) -> list[Tool]:
"""
Convert tools to Cohere model
"""
cohere_tools = []
for tool in tools:
properties = tool.parameters['properties']
required_properties = tool.parameters['required']
parameter_definitions = {}
for p_key, p_val in properties.items():
required = False
if p_key in required_properties:
required = True
desc = p_val['description']
if 'enum' in p_val:
desc += (f"; Only accepts one of the following predefined options: "
f"[{', '.join(p_val['enum'])}]")
parameter_definitions[p_key] = ToolParameterDefinitionsValue(
description=desc,
type=p_val['type'],
required=required
)
cohere_tool = Tool(
name=tool.name,
description=tool.description,
parameter_definitions=parameter_definitions
)
cohere_tools.append(cohere_tool)
return cohere_tools
def _num_tokens_from_string(self, model: str, credentials: dict, text: str) -> int:
"""
@ -493,12 +642,16 @@ class CohereLargeLanguageModel(LargeLanguageModel):
model=model
)
return response.length
return len(response.tokens)
def _num_tokens_from_messages(self, model: str, credentials: dict, messages: list[PromptMessage]) -> int:
"""Calculate num tokens Cohere model."""
messages = [self._convert_prompt_message_to_dict(m) for m in messages]
message_strs = [f"{message['role']}: {message['message']}" for message in messages]
calc_messages = []
for message in messages:
cohere_message = self._convert_prompt_message_to_dict(message)
if cohere_message:
calc_messages.append(cohere_message)
message_strs = [f"{message.role}: {message.message}" for message in calc_messages]
message_str = "\n".join(message_strs)
real_model = model
@ -564,13 +717,21 @@ class CohereLargeLanguageModel(LargeLanguageModel):
"""
return {
InvokeConnectionError: [
cohere.CohereConnectionError
cohere.errors.service_unavailable_error.ServiceUnavailableError
],
InvokeServerUnavailableError: [
cohere.errors.internal_server_error.InternalServerError
],
InvokeRateLimitError: [
cohere.errors.too_many_requests_error.TooManyRequestsError
],
InvokeAuthorizationError: [
cohere.errors.unauthorized_error.UnauthorizedError,
cohere.errors.forbidden_error.ForbiddenError
],
InvokeServerUnavailableError: [],
InvokeRateLimitError: [],
InvokeAuthorizationError: [],
InvokeBadRequestError: [
cohere.CohereAPIError,
cohere.CohereError,
cohere.core.api_error.ApiError,
cohere.errors.bad_request_error.BadRequestError,
cohere.errors.not_found_error.NotFoundError,
]
}
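
For orientation, a small self-contained sketch (not part of the commit) of the v5 client.chat call shape this file now assembles: a tool defined via Tool / ToolParameterDefinitionsValue, prior turns via ChatMessage, and a returned tool output via ChatStreamRequestToolResultsItem. The API key, tool name, and weather payload are hypothetical placeholders; every class and keyword argument comes from the imports and calls in the diff above.

import json

import cohere
from cohere import (
    ChatMessage,
    ChatStreamRequestToolResultsItem,
    Tool,
    ToolCall,
    ToolParameterDefinitionsValue,
)
from cohere.core import RequestOptions

client = cohere.Client("<api-key>")  # placeholder credential

# A hypothetical tool, built the same way _convert_tools builds them above.
weather_tool = Tool(
    name="get_weather",
    description="Look up current weather for a city",
    parameter_definitions={
        "location": ToolParameterDefinitionsValue(
            description="City to query; Only accepts one of the following predefined options: [Tokyo, Osaka]",
            type="str",
            required=True,
        )
    },
)

# A tool result from an earlier tool call, shaped like the
# ChatStreamRequestToolResultsItem entries produced in
# _convert_prompt_messages_to_message_and_chat_histories.
tool_results = [
    ChatStreamRequestToolResultsItem(
        call=ToolCall(name="get_weather", parameters={"location": "Tokyo"}),
        outputs=[{"result": json.dumps({"temp_c": 21})}],
    )
]

response = client.chat(
    message="Should I bring a jacket?",
    chat_history=[ChatMessage(role="USER", message="What's the weather in Tokyo?")],
    model="command-r",
    tools=[weather_tool],
    tool_results=tool_results,
    request_options=RequestOptions(max_retries=0),
)
print(response.text, response.tool_calls)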

View File

@ -1,6 +1,7 @@
from typing import Optional
import cohere
from cohere.core import RequestOptions
from core.model_runtime.entities.rerank_entities import RerankDocument, RerankResult
from core.model_runtime.errors.invoke import (
@ -44,19 +45,21 @@ class CohereRerankModel(RerankModel):
# initialize client
client = cohere.Client(credentials.get('api_key'))
results = client.rerank(
response = client.rerank(
query=query,
documents=docs,
model=model,
top_n=top_n
top_n=top_n,
return_documents=True,
request_options=RequestOptions(max_retries=0)
)
rerank_documents = []
for idx, result in enumerate(results):
for idx, result in enumerate(response.results):
# format document
rerank_document = RerankDocument(
index=result.index,
text=result.document['text'],
text=result.document.text,
score=result.relevance_score,
)
@ -108,13 +111,21 @@ class CohereRerankModel(RerankModel):
"""
return {
InvokeConnectionError: [
cohere.CohereConnectionError,
cohere.errors.service_unavailable_error.ServiceUnavailableError
],
InvokeServerUnavailableError: [
cohere.errors.internal_server_error.InternalServerError
],
InvokeRateLimitError: [
cohere.errors.too_many_requests_error.TooManyRequestsError
],
InvokeAuthorizationError: [
cohere.errors.unauthorized_error.UnauthorizedError,
cohere.errors.forbidden_error.ForbiddenError
],
InvokeServerUnavailableError: [],
InvokeRateLimitError: [],
InvokeAuthorizationError: [],
InvokeBadRequestError: [
cohere.CohereAPIError,
cohere.CohereError,
cohere.core.api_error.ApiError,
cohere.errors.bad_request_error.BadRequestError,
cohere.errors.not_found_error.NotFoundError,
]
}
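
A corresponding sketch for the rerank path (not part of the commit): with the v5 SDK the response object's results carry a document field when return_documents=True, which is why the code above switches from result.document['text'] to result.document.text. The API key, documents, and rerank model id are placeholders.

import cohere
from cohere.core import RequestOptions

client = cohere.Client("<api-key>")  # placeholder credential

response = client.rerank(
    query="capital of japan",
    documents=["Kyoto was the old capital.", "Tokyo is the capital of Japan."],
    model="rerank-english-v2.0",      # placeholder model id
    top_n=2,
    return_documents=True,            # required so result.document.text is populated
    request_options=RequestOptions(max_retries=0),
)
for result in response.results:
    print(result.index, result.relevance_score, result.document.text)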

View File

@ -3,7 +3,7 @@ from typing import Optional
import cohere
import numpy as np
from cohere.responses import Tokens
from cohere.core import RequestOptions
from core.model_runtime.entities.model_entities import PriceType
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
@ -52,8 +52,8 @@ class CohereTextEmbeddingModel(TextEmbeddingModel):
text=text
)
for j in range(0, tokenize_response.length, context_size):
tokens += [tokenize_response.token_strings[j: j + context_size]]
for j in range(0, len(tokenize_response), context_size):
tokens += [tokenize_response[j: j + context_size]]
indices += [i]
batched_embeddings = []
@ -127,9 +127,9 @@ class CohereTextEmbeddingModel(TextEmbeddingModel):
except Exception as e:
raise self._transform_invoke_error(e)
return response.length
return len(response)
def _tokenize(self, model: str, credentials: dict, text: str) -> Tokens:
def _tokenize(self, model: str, credentials: dict, text: str) -> list[str]:
"""
Tokenize text
:param model: model name
@ -138,17 +138,19 @@ class CohereTextEmbeddingModel(TextEmbeddingModel):
:return:
"""
if not text:
return Tokens([], [], {})
return []
# initialize client
client = cohere.Client(credentials.get('api_key'))
response = client.tokenize(
text=text,
model=model
model=model,
offline=False,
request_options=RequestOptions(max_retries=0)
)
return response
return response.token_strings
def validate_credentials(self, model: str, credentials: dict) -> None:
"""
@ -184,10 +186,11 @@ class CohereTextEmbeddingModel(TextEmbeddingModel):
response = client.embed(
texts=texts,
model=model,
input_type='search_document' if len(texts) > 1 else 'search_query'
input_type='search_document' if len(texts) > 1 else 'search_query',
request_options=RequestOptions(max_retries=1)
)
return response.embeddings, response.meta['billed_units']['input_tokens']
return response.embeddings, int(response.meta.billed_units.input_tokens)
def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage:
"""
@ -231,13 +234,21 @@ class CohereTextEmbeddingModel(TextEmbeddingModel):
"""
return {
InvokeConnectionError: [
cohere.CohereConnectionError
cohere.errors.service_unavailable_error.ServiceUnavailableError
],
InvokeServerUnavailableError: [
cohere.errors.internal_server_error.InternalServerError
],
InvokeRateLimitError: [
cohere.errors.too_many_requests_error.TooManyRequestsError
],
InvokeAuthorizationError: [
cohere.errors.unauthorized_error.UnauthorizedError,
cohere.errors.forbidden_error.ForbiddenError
],
InvokeServerUnavailableError: [],
InvokeRateLimitError: [],
InvokeAuthorizationError: [],
InvokeBadRequestError: [
cohere.CohereAPIError,
cohere.CohereError,
cohere.core.api_error.ApiError,
cohere.errors.bad_request_error.BadRequestError,
cohere.errors.not_found_error.NotFoundError,
]
}
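
And a sketch of the embedding path (not part of the commit): tokenize now returns token_strings (hence counting with len()) and embed exposes billed units as attributes rather than a dict, matching the changes above. The API key and model id are placeholders.

import cohere
from cohere.core import RequestOptions

client = cohere.Client("<api-key>")  # placeholder credential

tok = client.tokenize(
    text="hello world",
    model="embed-english-v3.0",       # placeholder model id
    offline=False,
    request_options=RequestOptions(max_retries=0),
)
print(len(tok.token_strings))         # token count via token_strings, as in _tokenize above

emb = client.embed(
    texts=["hello world"],
    model="embed-english-v3.0",
    input_type="search_query",        # single text -> 'search_query', per the diff
    request_options=RequestOptions(max_retries=1),
)
print(len(emb.embeddings[0]), int(emb.meta.billed_units.input_tokens))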

View File

@ -232,8 +232,8 @@ class SimplePromptTransform(PromptTransform):
)
),
max_token_limit=rest_tokens,
ai_prefix=prompt_rules['human_prefix'] if 'human_prefix' in prompt_rules else 'Human',
human_prefix=prompt_rules['assistant_prefix'] if 'assistant_prefix' in prompt_rules else 'Assistant'
human_prefix=prompt_rules['human_prefix'] if 'human_prefix' in prompt_rules else 'Human',
ai_prefix=prompt_rules['assistant_prefix'] if 'assistant_prefix' in prompt_rules else 'Assistant'
)
# get prompt

View File

@ -47,7 +47,8 @@ replicate~=0.22.0
websocket-client~=1.7.0
dashscope[tokenizer]~=1.14.0
huggingface_hub~=0.16.4
transformers~=4.31.0
transformers~=4.35.0
tokenizers~=0.15.0
pandas==1.5.3
xinference-client==0.9.4
safetensors==0.3.2
@ -55,7 +56,7 @@ zhipuai==1.0.7
werkzeug~=3.0.1
pymilvus==2.3.0
qdrant-client==1.7.3
cohere~=4.44
cohere~=5.2.4
pyyaml~=6.0.1
numpy~=1.25.2
unstructured[docx,pptx,msg,md,ppt]~=0.10.27
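
The bump from cohere~=4.44 to cohere~=5.2.4 is what drives the import and response-shape changes throughout this commit; the transformers/tokenizers bumps ride along in the same change. A quick post-upgrade sanity check, assuming the pins above are installed:

from importlib.metadata import version

from cohere import ChatMessage, Tool  # v5-style imports this commit relies on

print(version("cohere"))              # expect a 5.2.x release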