feat: Add Cohere Command R / R+ model support (#3333)
This commit is contained in:
parent bf63a43bda
commit 826c422ac4
@@ -1,3 +1,5 @@
+- command-r
+- command-r-plus
 - command-chat
 - command-light-chat
 - command-nightly-chat
@@ -31,7 +31,7 @@ parameter_rules:
     max: 500
   - name: max_tokens
     use_template: max_tokens
-    default: 256
+    default: 1024
     max: 4096
   - name: preamble_override
     label:
@@ -31,7 +31,7 @@ parameter_rules:
     max: 500
   - name: max_tokens
     use_template: max_tokens
-    default: 256
+    default: 1024
     max: 4096
   - name: preamble_override
     label:
@@ -31,7 +31,7 @@ parameter_rules:
     max: 500
   - name: max_tokens
     use_template: max_tokens
-    default: 256
+    default: 1024
     max: 4096
   - name: preamble_override
     label:
@@ -35,7 +35,7 @@ parameter_rules:
     use_template: frequency_penalty
   - name: max_tokens
     use_template: max_tokens
-    default: 256
+    default: 1024
     max: 4096
 pricing:
   input: '0.3'
@@ -35,7 +35,7 @@ parameter_rules:
     use_template: frequency_penalty
   - name: max_tokens
     use_template: max_tokens
-    default: 256
+    default: 1024
     max: 4096
 pricing:
   input: '0.3'
@@ -31,7 +31,7 @@ parameter_rules:
     max: 500
   - name: max_tokens
     use_template: max_tokens
-    default: 256
+    default: 1024
     max: 4096
   - name: preamble_override
     label:
@@ -35,7 +35,7 @@ parameter_rules:
     use_template: frequency_penalty
   - name: max_tokens
     use_template: max_tokens
-    default: 256
+    default: 1024
     max: 4096
 pricing:
   input: '1.0'
@@ -0,0 +1,45 @@
+model: command-r-plus
+label:
+  en_US: command-r-plus
+model_type: llm
+features:
+  - multi-tool-call
+  - agent-thought
+  - stream-tool-call
+model_properties:
+  mode: chat
+  context_size: 128000
+parameter_rules:
+  - name: temperature
+    use_template: temperature
+    max: 5.0
+  - name: p
+    use_template: top_p
+    default: 0.75
+    min: 0.01
+    max: 0.99
+  - name: k
+    label:
+      zh_Hans: 取样数量
+      en_US: Top k
+    type: int
+    help:
+      zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
+      en_US: Only sample from the top K options for each subsequent token.
+    required: false
+    default: 0
+    min: 0
+    max: 500
+  - name: presence_penalty
+    use_template: presence_penalty
+  - name: frequency_penalty
+    use_template: frequency_penalty
+  - name: max_tokens
+    use_template: max_tokens
+    default: 1024
+    max: 4096
+pricing:
+  input: '3'
+  output: '15'
+  unit: '0.000001'
+  currency: USD
@@ -0,0 +1,45 @@
+model: command-r
+label:
+  en_US: command-r
+model_type: llm
+features:
+  - multi-tool-call
+  - agent-thought
+  - stream-tool-call
+model_properties:
+  mode: chat
+  context_size: 128000
+parameter_rules:
+  - name: temperature
+    use_template: temperature
+    max: 5.0
+  - name: p
+    use_template: top_p
+    default: 0.75
+    min: 0.01
+    max: 0.99
+  - name: k
+    label:
+      zh_Hans: 取样数量
+      en_US: Top k
+    type: int
+    help:
+      zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
+      en_US: Only sample from the top K options for each subsequent token.
+    required: false
+    default: 0
+    min: 0
+    max: 500
+  - name: presence_penalty
+    use_template: presence_penalty
+  - name: frequency_penalty
+    use_template: frequency_penalty
+  - name: max_tokens
+    use_template: max_tokens
+    default: 1024
+    max: 4096
+pricing:
+  input: '0.5'
+  output: '1.5'
+  unit: '0.000001'
+  currency: USD
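The two new model cards above also carry pricing blocks. Read together with unit: '0.000001', the quoted input/output prices appear to be USD per million tokens (command-r at $0.50 / $1.50 per 1M prompt/completion tokens, command-r-plus at $3 / $15). A minimal sketch of that arithmetic, using hypothetical token counts that are not taken from the commit:

from decimal import Decimal

# Values copied from the command-r pricing block above.
input_price = Decimal('0.5')
output_price = Decimal('1.5')
unit = Decimal('0.000001')   # scales the quoted price down to a per-token price

prompt_tokens, completion_tokens = 2000, 500   # hypothetical usage
cost = (prompt_tokens * input_price + completion_tokens * output_price) * unit
print(f"{cost} USD")   # ~0.00175 USD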
@@ -35,7 +35,7 @@ parameter_rules:
     use_template: frequency_penalty
   - name: max_tokens
     use_template: max_tokens
-    default: 256
+    default: 1024
     max: 4096
 pricing:
   input: '1.0'
@@ -1,20 +1,38 @@
 import json
 import logging
-from collections.abc import Generator
+from collections.abc import Generator, Iterator
 from typing import Optional, Union, cast
 
 import cohere
-from cohere.responses import Chat, Generations
-from cohere.responses.chat import StreamEnd, StreamingChat, StreamTextGeneration
-from cohere.responses.generation import StreamingGenerations, StreamingText
+from cohere import (
+    ChatMessage,
+    ChatStreamRequestToolResultsItem,
+    GenerateStreamedResponse,
+    GenerateStreamedResponse_StreamEnd,
+    GenerateStreamedResponse_StreamError,
+    GenerateStreamedResponse_TextGeneration,
+    Generation,
+    NonStreamedChatResponse,
+    StreamedChatResponse,
+    StreamedChatResponse_StreamEnd,
+    StreamedChatResponse_TextGeneration,
+    StreamedChatResponse_ToolCallsGeneration,
+    Tool,
+    ToolCall,
+    ToolParameterDefinitionsValue,
+)
+from cohere.core import RequestOptions
 
 from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta
 from core.model_runtime.entities.message_entities import (
     AssistantPromptMessage,
     PromptMessage,
     PromptMessageContentType,
     PromptMessageRole,
     PromptMessageTool,
     SystemPromptMessage,
     TextPromptMessageContent,
     ToolPromptMessage,
     UserPromptMessage,
 )
 from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, I18nObject, ModelType
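The rewritten import block reflects the Cohere SDK v5 surface used throughout this file: responses are typed classes rather than dicts, and per-request behaviour such as retries is controlled through cohere.core.RequestOptions. A minimal sketch of that surface, assuming a cohere>=5.x install and a placeholder API key:

import cohere
from cohere.core import RequestOptions

# Placeholder credentials; the runtime reads the real key from provider credentials.
client = cohere.Client('YOUR_COHERE_API_KEY')

# max_retries=0 turns off the SDK's own retries so that the runtime's
# _invoke_error_mapping (further below) decides how failures are surfaced.
response = client.chat(
    message='Hello',
    model='command-r',
    request_options=RequestOptions(max_retries=0),
)
print(response.text)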
@@ -64,6 +82,7 @@ class CohereLargeLanguageModel(LargeLanguageModel):
             credentials=credentials,
             prompt_messages=prompt_messages,
             model_parameters=model_parameters,
+            tools=tools,
             stop=stop,
             stream=stream,
             user=user
@@ -159,19 +178,26 @@ class CohereLargeLanguageModel(LargeLanguageModel):
         if stop:
             model_parameters['end_sequences'] = stop
 
-        response = client.generate(
-            prompt=prompt_messages[0].content,
-            model=model,
-            stream=stream,
-            **model_parameters,
-        )
-
         if stream:
+            response = client.generate_stream(
+                prompt=prompt_messages[0].content,
+                model=model,
+                **model_parameters,
+                request_options=RequestOptions(max_retries=0)
+            )
+
             return self._handle_generate_stream_response(model, credentials, response, prompt_messages)
+        else:
+            response = client.generate(
+                prompt=prompt_messages[0].content,
+                model=model,
+                **model_parameters,
+                request_options=RequestOptions(max_retries=0)
+            )
 
-        return self._handle_generate_response(model, credentials, response, prompt_messages)
+            return self._handle_generate_response(model, credentials, response, prompt_messages)
 
-    def _handle_generate_response(self, model: str, credentials: dict, response: Generations,
+    def _handle_generate_response(self, model: str, credentials: dict, response: Generation,
                                   prompt_messages: list[PromptMessage]) \
             -> LLMResult:
         """
@@ -191,8 +217,8 @@ class CohereLargeLanguageModel(LargeLanguageModel):
         )
 
         # calculate num tokens
-        prompt_tokens = response.meta['billed_units']['input_tokens']
-        completion_tokens = response.meta['billed_units']['output_tokens']
+        prompt_tokens = int(response.meta.billed_units.input_tokens)
+        completion_tokens = int(response.meta.billed_units.output_tokens)
 
         # transform usage
         usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
@@ -207,7 +233,7 @@ class CohereLargeLanguageModel(LargeLanguageModel):
 
         return response
 
-    def _handle_generate_stream_response(self, model: str, credentials: dict, response: StreamingGenerations,
+    def _handle_generate_stream_response(self, model: str, credentials: dict, response: Iterator[GenerateStreamedResponse],
                                          prompt_messages: list[PromptMessage]) -> Generator:
         """
         Handle llm stream response
@@ -220,8 +246,8 @@ class CohereLargeLanguageModel(LargeLanguageModel):
         index = 1
         full_assistant_content = ''
         for chunk in response:
-            if isinstance(chunk, StreamingText):
-                chunk = cast(StreamingText, chunk)
+            if isinstance(chunk, GenerateStreamedResponse_TextGeneration):
+                chunk = cast(GenerateStreamedResponse_TextGeneration, chunk)
                 text = chunk.text
 
                 if text is None:
@@ -244,10 +270,16 @@ class CohereLargeLanguageModel(LargeLanguageModel):
                 )
 
                 index += 1
-            elif chunk is None:
+            elif isinstance(chunk, GenerateStreamedResponse_StreamEnd):
+                chunk = cast(GenerateStreamedResponse_StreamEnd, chunk)
+
                 # calculate num tokens
-                prompt_tokens = response.meta['billed_units']['input_tokens']
-                completion_tokens = response.meta['billed_units']['output_tokens']
+                prompt_tokens = self._num_tokens_from_messages(model, credentials, prompt_messages)
+                completion_tokens = self._num_tokens_from_messages(
+                    model,
+                    credentials,
+                    [AssistantPromptMessage(content=full_assistant_content)]
+                )
 
                 # transform usage
                 usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
@@ -258,14 +290,18 @@ class CohereLargeLanguageModel(LargeLanguageModel):
                     delta=LLMResultChunkDelta(
                         index=index,
                         message=AssistantPromptMessage(content=''),
-                        finish_reason=response.finish_reason,
+                        finish_reason=chunk.finish_reason,
                         usage=usage
                     )
                 )
+                break
+            elif isinstance(chunk, GenerateStreamedResponse_StreamError):
+                chunk = cast(GenerateStreamedResponse_StreamError, chunk)
+                raise InvokeBadRequestError(chunk.err)
 
     def _chat_generate(self, model: str, credentials: dict,
-                       prompt_messages: list[PromptMessage], model_parameters: dict, stop: Optional[list[str]] = None,
+                       prompt_messages: list[PromptMessage], model_parameters: dict,
+                       tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
                        stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]:
         """
         Invoke llm chat model
@@ -274,6 +310,7 @@ class CohereLargeLanguageModel(LargeLanguageModel):
         :param credentials: credentials
         :param prompt_messages: prompt messages
         :param model_parameters: model parameters
+        :param tools: tools for tool calling
         :param stop: stop words
         :param stream: is stream response
         :param user: unique user id
@@ -282,31 +319,46 @@ class CohereLargeLanguageModel(LargeLanguageModel):
         # initialize client
         client = cohere.Client(credentials.get('api_key'))
 
         if user:
             model_parameters['user_name'] = user
         if stop:
             model_parameters['stop_sequences'] = stop
 
-        message, chat_histories = self._convert_prompt_messages_to_message_and_chat_histories(prompt_messages)
+        if tools:
+            model_parameters['tools'] = self._convert_tools(tools)
+
+        message, chat_histories, tool_results \
+            = self._convert_prompt_messages_to_message_and_chat_histories(prompt_messages)
+
+        if tool_results:
+            model_parameters['tool_results'] = tool_results
 
         # chat model
         real_model = model
         if self.get_model_schema(model, credentials).fetch_from == FetchFrom.PREDEFINED_MODEL:
             real_model = model.removesuffix('-chat')
 
-        response = client.chat(
-            message=message,
-            chat_history=chat_histories,
-            model=real_model,
-            stream=stream,
-            **model_parameters,
-        )
-
         if stream:
-            return self._handle_chat_generate_stream_response(model, credentials, response, prompt_messages, stop)
+            response = client.chat_stream(
+                message=message,
+                chat_history=chat_histories,
+                model=real_model,
+                **model_parameters,
+                request_options=RequestOptions(max_retries=0)
+            )
 
-        return self._handle_chat_generate_response(model, credentials, response, prompt_messages, stop)
+            return self._handle_chat_generate_stream_response(model, credentials, response, prompt_messages)
+        else:
+            response = client.chat(
+                message=message,
+                chat_history=chat_histories,
+                model=real_model,
+                **model_parameters,
+                request_options=RequestOptions(max_retries=0)
+            )
 
-    def _handle_chat_generate_response(self, model: str, credentials: dict, response: Chat,
-                                       prompt_messages: list[PromptMessage], stop: Optional[list[str]] = None) \
+            return self._handle_chat_generate_response(model, credentials, response, prompt_messages)
+
+    def _handle_chat_generate_response(self, model: str, credentials: dict, response: NonStreamedChatResponse,
+                                       prompt_messages: list[PromptMessage]) \
             -> LLMResult:
         """
         Handle llm chat response
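Chat requests follow the same split: client.chat_stream when streaming, client.chat otherwise, with history passed as ChatMessage objects and tool definitions/results forwarded through the tools and tool_results parameters. A small sketch of the streaming path under the same assumptions (placeholder key, no tools):

import cohere
from cohere import ChatMessage, StreamedChatResponse_StreamEnd, StreamedChatResponse_TextGeneration
from cohere.core import RequestOptions

client = cohere.Client('YOUR_COHERE_API_KEY')  # placeholder key

history = [
    ChatMessage(role="USER", message="What is Command R?"),
    ChatMessage(role="CHATBOT", message="Command R is a Cohere chat model."),
]

stream = client.chat_stream(
    message='Summarise that in one sentence.',
    chat_history=history,
    model='command-r',
    request_options=RequestOptions(max_retries=0),
)

for event in stream:
    if isinstance(event, StreamedChatResponse_TextGeneration):
        print(event.text, end='', flush=True)
    elif isinstance(event, StreamedChatResponse_StreamEnd):
        print()   # event.finish_reason carries the stop reason
        break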
@@ -315,14 +367,27 @@ class CohereLargeLanguageModel(LargeLanguageModel):
         :param credentials: credentials
         :param response: response
         :param prompt_messages: prompt messages
-        :param stop: stop words
         :return: llm response
         """
         assistant_text = response.text
 
+        tool_calls = []
+        if response.tool_calls:
+            for cohere_tool_call in response.tool_calls:
+                tool_call = AssistantPromptMessage.ToolCall(
+                    id=cohere_tool_call.name,
+                    type='function',
+                    function=AssistantPromptMessage.ToolCall.ToolCallFunction(
+                        name=cohere_tool_call.name,
+                        arguments=json.dumps(cohere_tool_call.parameters)
+                    )
+                )
+                tool_calls.append(tool_call)
+
         # transform assistant message to prompt message
         assistant_prompt_message = AssistantPromptMessage(
-            content=assistant_text
+            content=assistant_text,
+            tool_calls=tool_calls
         )
 
         # calculate num tokens
@@ -332,44 +397,38 @@ class CohereLargeLanguageModel(LargeLanguageModel):
         # transform usage
         usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
 
-        if stop:
-            # enforce stop tokens
-            assistant_text = self.enforce_stop_tokens(assistant_text, stop)
-            assistant_prompt_message = AssistantPromptMessage(
-                content=assistant_text
-            )
-
         # transform response
         response = LLMResult(
             model=model,
             prompt_messages=prompt_messages,
             message=assistant_prompt_message,
-            usage=usage,
-            system_fingerprint=response.preamble
+            usage=usage
         )
 
         return response
 
-    def _handle_chat_generate_stream_response(self, model: str, credentials: dict, response: StreamingChat,
-                                              prompt_messages: list[PromptMessage],
-                                              stop: Optional[list[str]] = None) -> Generator:
+    def _handle_chat_generate_stream_response(self, model: str, credentials: dict,
+                                              response: Iterator[StreamedChatResponse],
+                                              prompt_messages: list[PromptMessage]) -> Generator:
         """
         Handle llm chat stream response
 
         :param model: model name
         :param response: response
         :param prompt_messages: prompt messages
-        :param stop: stop words
         :return: llm response chunk generator
         """
 
-        def final_response(full_text: str, index: int, finish_reason: Optional[str] = None,
-                           preamble: Optional[str] = None) -> LLMResultChunk:
+        def final_response(full_text: str,
+                           tool_calls: list[AssistantPromptMessage.ToolCall],
+                           index: int,
+                           finish_reason: Optional[str] = None) -> LLMResultChunk:
             # calculate num tokens
             prompt_tokens = self._num_tokens_from_messages(model, credentials, prompt_messages)
 
             full_assistant_prompt_message = AssistantPromptMessage(
-                content=full_text
+                content=full_text,
+                tool_calls=tool_calls
             )
             completion_tokens = self._num_tokens_from_messages(model, credentials, [full_assistant_prompt_message])
 
@@ -379,10 +438,9 @@ class CohereLargeLanguageModel(LargeLanguageModel):
             return LLMResultChunk(
                 model=model,
                 prompt_messages=prompt_messages,
-                system_fingerprint=preamble,
                 delta=LLMResultChunkDelta(
                     index=index,
-                    message=AssistantPromptMessage(content=''),
+                    message=AssistantPromptMessage(content='', tool_calls=tool_calls),
                     finish_reason=finish_reason,
                     usage=usage
                 )
@@ -390,9 +448,10 @@ class CohereLargeLanguageModel(LargeLanguageModel):
 
         index = 1
         full_assistant_content = ''
+        tool_calls = []
         for chunk in response:
-            if isinstance(chunk, StreamTextGeneration):
-                chunk = cast(StreamTextGeneration, chunk)
+            if isinstance(chunk, StreamedChatResponse_TextGeneration):
+                chunk = cast(StreamedChatResponse_TextGeneration, chunk)
                 text = chunk.text
 
                 if text is None:
@@ -403,12 +462,6 @@ class CohereLargeLanguageModel(LargeLanguageModel):
                     content=text
                 )
 
-                # stop
-                # notice: This logic can only cover few stop scenarios
-                if stop and text in stop:
-                    yield final_response(full_assistant_content, index, 'stop')
-                    break
-
                 full_assistant_content += text
 
                 yield LLMResultChunk(
@@ -421,39 +474,98 @@ class CohereLargeLanguageModel(LargeLanguageModel):
                 )
 
                 index += 1
-            elif isinstance(chunk, StreamEnd):
-                chunk = cast(StreamEnd, chunk)
-                yield final_response(full_assistant_content, index, chunk.finish_reason, response.preamble)
+            elif isinstance(chunk, StreamedChatResponse_ToolCallsGeneration):
+                chunk = cast(StreamedChatResponse_ToolCallsGeneration, chunk)
+
+                tool_calls = []
+                if chunk.tool_calls:
+                    for cohere_tool_call in chunk.tool_calls:
+                        tool_call = AssistantPromptMessage.ToolCall(
+                            id=cohere_tool_call.name,
+                            type='function',
+                            function=AssistantPromptMessage.ToolCall.ToolCallFunction(
+                                name=cohere_tool_call.name,
+                                arguments=json.dumps(cohere_tool_call.parameters)
+                            )
+                        )
+                        tool_calls.append(tool_call)
+            elif isinstance(chunk, StreamedChatResponse_StreamEnd):
+                chunk = cast(StreamedChatResponse_StreamEnd, chunk)
+                yield final_response(full_assistant_content, tool_calls, index, chunk.finish_reason)
                 index += 1
 
     def _convert_prompt_messages_to_message_and_chat_histories(self, prompt_messages: list[PromptMessage]) \
-            -> tuple[str, list[dict]]:
+            -> tuple[str, list[ChatMessage], list[ChatStreamRequestToolResultsItem]]:
         """
         Convert prompt messages to message and chat histories
         :param prompt_messages: prompt messages
         :return:
         """
         chat_histories = []
+        latest_tool_call_n_outputs = []
         for prompt_message in prompt_messages:
-            chat_histories.append(self._convert_prompt_message_to_dict(prompt_message))
+            if prompt_message.role == PromptMessageRole.ASSISTANT:
+                prompt_message = cast(AssistantPromptMessage, prompt_message)
+                if prompt_message.tool_calls:
+                    for tool_call in prompt_message.tool_calls:
+                        latest_tool_call_n_outputs.append(ChatStreamRequestToolResultsItem(
+                            call=ToolCall(
+                                name=tool_call.function.name,
+                                parameters=json.loads(tool_call.function.arguments)
+                            ),
+                            outputs=[]
+                        ))
+                else:
+                    cohere_prompt_message = self._convert_prompt_message_to_dict(prompt_message)
+                    if cohere_prompt_message:
+                        chat_histories.append(cohere_prompt_message)
+            elif prompt_message.role == PromptMessageRole.TOOL:
+                prompt_message = cast(ToolPromptMessage, prompt_message)
+                if latest_tool_call_n_outputs:
+                    i = 0
+                    for tool_call_n_outputs in latest_tool_call_n_outputs:
+                        if tool_call_n_outputs.call.name == prompt_message.tool_call_id:
+                            latest_tool_call_n_outputs[i] = ChatStreamRequestToolResultsItem(
+                                call=ToolCall(
+                                    name=tool_call_n_outputs.call.name,
+                                    parameters=tool_call_n_outputs.call.parameters
+                                ),
+                                outputs=[{
+                                    "result": prompt_message.content
+                                }]
+                            )
+                            break
+                        i += 1
+            else:
+                cohere_prompt_message = self._convert_prompt_message_to_dict(prompt_message)
+                if cohere_prompt_message:
+                    chat_histories.append(cohere_prompt_message)
+
+        if latest_tool_call_n_outputs:
+            new_latest_tool_call_n_outputs = []
+            for tool_call_n_outputs in latest_tool_call_n_outputs:
+                if tool_call_n_outputs.outputs:
+                    new_latest_tool_call_n_outputs.append(tool_call_n_outputs)
+
+            latest_tool_call_n_outputs = new_latest_tool_call_n_outputs
 
         # get latest message from chat histories and pop it
         if len(chat_histories) > 0:
             latest_message = chat_histories.pop()
-            message = latest_message['message']
+            message = latest_message.message
         else:
             raise ValueError('Prompt messages is empty')
 
-        return message, chat_histories
+        return message, chat_histories, latest_tool_call_n_outputs
 
-    def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict:
+    def _convert_prompt_message_to_dict(self, message: PromptMessage) -> Optional[ChatMessage]:
        """
         Convert PromptMessage to dict for Cohere model
         """
         if isinstance(message, UserPromptMessage):
             message = cast(UserPromptMessage, message)
             if isinstance(message.content, str):
-                message_dict = {"role": "USER", "message": message.content}
+                chat_message = ChatMessage(role="USER", message=message.content)
             else:
                 sub_message_text = ''
                 for message_content in message.content:
@@ -461,20 +573,57 @@ class CohereLargeLanguageModel(LargeLanguageModel):
                     message_content = cast(TextPromptMessageContent, message_content)
                     sub_message_text += message_content.data
 
-                message_dict = {"role": "USER", "message": sub_message_text}
+                chat_message = ChatMessage(role="USER", message=sub_message_text)
         elif isinstance(message, AssistantPromptMessage):
             message = cast(AssistantPromptMessage, message)
-            message_dict = {"role": "CHATBOT", "message": message.content}
+            if not message.content:
+                return None
+            chat_message = ChatMessage(role="CHATBOT", message=message.content)
         elif isinstance(message, SystemPromptMessage):
             message = cast(SystemPromptMessage, message)
-            message_dict = {"role": "USER", "message": message.content}
+            chat_message = ChatMessage(role="USER", message=message.content)
+        elif isinstance(message, ToolPromptMessage):
+            return None
         else:
             raise ValueError(f"Got unknown type {message}")
 
-        if message.name:
-            message_dict["user_name"] = message.name
+        return chat_message
 
-        return message_dict
+    def _convert_tools(self, tools: list[PromptMessageTool]) -> list[Tool]:
+        """
+        Convert tools to Cohere model
+        """
+        cohere_tools = []
+        for tool in tools:
+            properties = tool.parameters['properties']
+            required_properties = tool.parameters['required']
+
+            parameter_definitions = {}
+            for p_key, p_val in properties.items():
+                required = False
+                if p_key in required_properties:
+                    required = True
+
+                desc = p_val['description']
+                if 'enum' in p_val:
+                    desc += (f"; Only accepts one of the following predefined options: "
+                             f"[{', '.join(p_val['enum'])}]")
+
+                parameter_definitions[p_key] = ToolParameterDefinitionsValue(
+                    description=desc,
+                    type=p_val['type'],
+                    required=required
+                )
+
+            cohere_tool = Tool(
+                name=tool.name,
+                description=tool.description,
+                parameter_definitions=parameter_definitions
+            )
+
+            cohere_tools.append(cohere_tool)
+
+        return cohere_tools
 
     def _num_tokens_from_string(self, model: str, credentials: dict, text: str) -> int:
         """
@@ -493,12 +642,16 @@ class CohereLargeLanguageModel(LargeLanguageModel):
             model=model
         )
 
-        return response.length
+        return len(response.tokens)
 
     def _num_tokens_from_messages(self, model: str, credentials: dict, messages: list[PromptMessage]) -> int:
         """Calculate num tokens Cohere model."""
-        messages = [self._convert_prompt_message_to_dict(m) for m in messages]
-        message_strs = [f"{message['role']}: {message['message']}" for message in messages]
+        calc_messages = []
+        for message in messages:
+            cohere_message = self._convert_prompt_message_to_dict(message)
+            if cohere_message:
+                calc_messages.append(cohere_message)
+        message_strs = [f"{message.role}: {message.message}" for message in calc_messages]
         message_str = "\n".join(message_strs)
 
         real_model = model
@@ -564,13 +717,21 @@ class CohereLargeLanguageModel(LargeLanguageModel):
         """
         return {
             InvokeConnectionError: [
-                cohere.CohereConnectionError
+                cohere.errors.service_unavailable_error.ServiceUnavailableError
             ],
-            InvokeServerUnavailableError: [],
-            InvokeRateLimitError: [],
-            InvokeAuthorizationError: [],
+            InvokeServerUnavailableError: [
+                cohere.errors.internal_server_error.InternalServerError
+            ],
+            InvokeRateLimitError: [
+                cohere.errors.too_many_requests_error.TooManyRequestsError
+            ],
+            InvokeAuthorizationError: [
+                cohere.errors.unauthorized_error.UnauthorizedError,
+                cohere.errors.forbidden_error.ForbiddenError
+            ],
             InvokeBadRequestError: [
-                cohere.CohereAPIError,
-                cohere.CohereError,
+                cohere.core.api_error.ApiError,
+                cohere.errors.bad_request_error.BadRequestError,
+                cohere.errors.not_found_error.NotFoundError,
             ]
         }
@@ -1,6 +1,7 @@
 from typing import Optional
 
 import cohere
+from cohere.core import RequestOptions
 
 from core.model_runtime.entities.rerank_entities import RerankDocument, RerankResult
 from core.model_runtime.errors.invoke import (
@@ -44,19 +45,21 @@ class CohereRerankModel(RerankModel):
 
         # initialize client
         client = cohere.Client(credentials.get('api_key'))
-        results = client.rerank(
+        response = client.rerank(
             query=query,
             documents=docs,
             model=model,
-            top_n=top_n
+            top_n=top_n,
+            return_documents=True,
+            request_options=RequestOptions(max_retries=0)
         )
 
         rerank_documents = []
-        for idx, result in enumerate(results):
+        for idx, result in enumerate(response.results):
             # format document
             rerank_document = RerankDocument(
                 index=result.index,
-                text=result.document['text'],
+                text=result.document.text,
                 score=result.relevance_score,
             )
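With return_documents=True the v5 rerank response carries each document back on the result, so the code above reads result.document.text instead of indexing into the input list. A minimal sketch (placeholder key, toy documents, and an assumed rerank model name):

import cohere
from cohere.core import RequestOptions

client = cohere.Client('YOUR_COHERE_API_KEY')  # placeholder key

docs = [
    "Command R is optimised for retrieval-augmented generation.",
    "The rerank endpoint scores documents against a query.",
]

response = client.rerank(
    query="Which model targets RAG workloads?",
    documents=docs,
    model="rerank-english-v2.0",   # assumed model name
    top_n=2,
    return_documents=True,
    request_options=RequestOptions(max_retries=0),
)

for result in response.results:
    print(result.index, round(result.relevance_score, 3), result.document.text)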
@@ -108,13 +111,21 @@ class CohereRerankModel(RerankModel):
         """
         return {
             InvokeConnectionError: [
-                cohere.CohereConnectionError,
+                cohere.errors.service_unavailable_error.ServiceUnavailableError
             ],
-            InvokeServerUnavailableError: [],
-            InvokeRateLimitError: [],
-            InvokeAuthorizationError: [],
+            InvokeServerUnavailableError: [
+                cohere.errors.internal_server_error.InternalServerError
+            ],
+            InvokeRateLimitError: [
+                cohere.errors.too_many_requests_error.TooManyRequestsError
+            ],
+            InvokeAuthorizationError: [
+                cohere.errors.unauthorized_error.UnauthorizedError,
+                cohere.errors.forbidden_error.ForbiddenError
+            ],
             InvokeBadRequestError: [
-                cohere.CohereAPIError,
-                cohere.CohereError,
+                cohere.core.api_error.ApiError,
+                cohere.errors.bad_request_error.BadRequestError,
+                cohere.errors.not_found_error.NotFoundError,
             ]
         }
@@ -3,7 +3,7 @@ from typing import Optional
 
 import cohere
 import numpy as np
-from cohere.responses import Tokens
+from cohere.core import RequestOptions
 
 from core.model_runtime.entities.model_entities import PriceType
 from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
@@ -52,8 +52,8 @@ class CohereTextEmbeddingModel(TextEmbeddingModel):
                 text=text
             )
 
-            for j in range(0, tokenize_response.length, context_size):
-                tokens += [tokenize_response.token_strings[j: j + context_size]]
+            for j in range(0, len(tokenize_response), context_size):
+                tokens += [tokenize_response[j: j + context_size]]
             indices += [i]
 
         batched_embeddings = []
@@ -127,9 +127,9 @@ class CohereTextEmbeddingModel(TextEmbeddingModel):
         except Exception as e:
             raise self._transform_invoke_error(e)
 
-        return response.length
+        return len(response)
 
-    def _tokenize(self, model: str, credentials: dict, text: str) -> Tokens:
+    def _tokenize(self, model: str, credentials: dict, text: str) -> list[str]:
         """
         Tokenize text
         :param model: model name
@@ -138,17 +138,19 @@ class CohereTextEmbeddingModel(TextEmbeddingModel):
         :return:
         """
         if not text:
-            return Tokens([], [], {})
+            return []
 
         # initialize client
         client = cohere.Client(credentials.get('api_key'))
 
         response = client.tokenize(
             text=text,
-            model=model
+            model=model,
+            offline=False,
+            request_options=RequestOptions(max_retries=0)
         )
 
-        return response
+        return response.token_strings
 
     def validate_credentials(self, model: str, credentials: dict) -> None:
         """
@@ -184,10 +186,11 @@ class CohereTextEmbeddingModel(TextEmbeddingModel):
         response = client.embed(
             texts=texts,
             model=model,
-            input_type='search_document' if len(texts) > 1 else 'search_query'
+            input_type='search_document' if len(texts) > 1 else 'search_query',
+            request_options=RequestOptions(max_retries=1)
         )
 
-        return response.embeddings, response.meta['billed_units']['input_tokens']
+        return response.embeddings, int(response.meta.billed_units.input_tokens)
 
     def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage:
         """
@@ -231,13 +234,21 @@ class CohereTextEmbeddingModel(TextEmbeddingModel):
         """
         return {
             InvokeConnectionError: [
-                cohere.CohereConnectionError
+                cohere.errors.service_unavailable_error.ServiceUnavailableError
            ],
-            InvokeServerUnavailableError: [],
-            InvokeRateLimitError: [],
-            InvokeAuthorizationError: [],
+            InvokeServerUnavailableError: [
+                cohere.errors.internal_server_error.InternalServerError
+            ],
+            InvokeRateLimitError: [
+                cohere.errors.too_many_requests_error.TooManyRequestsError
+            ],
+            InvokeAuthorizationError: [
+                cohere.errors.unauthorized_error.UnauthorizedError,
+                cohere.errors.forbidden_error.ForbiddenError
+            ],
             InvokeBadRequestError: [
-                cohere.CohereAPIError,
-                cohere.CohereError,
+                cohere.core.api_error.ApiError,
+                cohere.errors.bad_request_error.BadRequestError,
+                cohere.errors.not_found_error.NotFoundError,
             ]
         }
@@ -232,8 +232,8 @@ class SimplePromptTransform(PromptTransform):
                 )
             ),
             max_token_limit=rest_tokens,
-            ai_prefix=prompt_rules['human_prefix'] if 'human_prefix' in prompt_rules else 'Human',
-            human_prefix=prompt_rules['assistant_prefix'] if 'assistant_prefix' in prompt_rules else 'Assistant'
+            human_prefix=prompt_rules['human_prefix'] if 'human_prefix' in prompt_rules else 'Human',
+            ai_prefix=prompt_rules['assistant_prefix'] if 'assistant_prefix' in prompt_rules else 'Assistant'
         )
 
         # get prompt
@@ -47,7 +47,8 @@ replicate~=0.22.0
 websocket-client~=1.7.0
 dashscope[tokenizer]~=1.14.0
 huggingface_hub~=0.16.4
-transformers~=4.31.0
+transformers~=4.35.0
+tokenizers~=0.15.0
 pandas==1.5.3
 xinference-client==0.9.4
 safetensors==0.3.2
@@ -55,7 +56,7 @@ zhipuai==1.0.7
 werkzeug~=3.0.1
 pymilvus==2.3.0
 qdrant-client==1.7.3
-cohere~=4.44
+cohere~=5.2.4
 pyyaml~=6.0.1
 numpy~=1.25.2
 unstructured[docx,pptx,msg,md,ppt]~=0.10.27