Add bedrock command r models (#4521)
Co-authored-by: Justin Wu <justin.wu@ringcentral.com>
Co-authored-by: Chenhe Gu <guchenhe@gmail.com>
parent 07387e9586
commit 61f4f08744
@@ -17,6 +17,10 @@ class CotAgentOutputParser:
                 action_name = None
                 action_input = None
 
+                # cohere always returns a list
+                if isinstance(action, list) and len(action) == 1:
+                    action = action[0]
+
                 for key, value in action.items():
                     if 'input' in key.lower():
                         action_input = value

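As illustration of the parser change above: Cohere models wrap the ReAct action JSON in a single-element list, so the parser unwraps it before scanning keys. A minimal standalone sketch, with a hypothetical payload shape (not taken from the codebase):

# Hypothetical Cohere-style action payload: a one-element list wrapping the action dict.
action = [{"action": "get_weather", "action_input": {"city": "Berlin"}}]

# cohere always returns a list
if isinstance(action, list) and len(action) == 1:
    action = action[0]

action_input = None
for key, value in action.items():
    if 'input' in key.lower():
        action_input = value

assert action_input == {"city": "Berlin"}
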
@@ -8,6 +8,8 @@
 - anthropic.claude-3-haiku-v1:0
 - cohere.command-light-text-v14
 - cohere.command-text-v14
+- cohere.command-r-plus-v1.0
+- cohere.command-r-v1.0
 - meta.llama3-8b-instruct-v1:0
 - meta.llama3-70b-instruct-v1:0
 - meta.llama2-13b-chat-v1

@@ -0,0 +1,45 @@
+model: cohere.command-r-plus-v1:0
+label:
+  en_US: Command R+
+model_type: llm
+features:
+  #- multi-tool-call
+  - agent-thought
+  #- stream-tool-call
+model_properties:
+  mode: chat
+  context_size: 128000
+parameter_rules:
+  - name: temperature
+    use_template: temperature
+    max: 5.0
+  - name: p
+    use_template: top_p
+    default: 0.75
+    min: 0.01
+    max: 0.99
+  - name: k
+    label:
+      zh_Hans: 取样数量
+      en_US: Top k
+    type: int
+    help:
+      zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
+      en_US: Only sample from the top K options for each subsequent token.
+    required: false
+    default: 0
+    min: 0
+    max: 500
+  - name: presence_penalty
+    use_template: presence_penalty
+  - name: frequency_penalty
+    use_template: frequency_penalty
+  - name: max_tokens
+    use_template: max_tokens
+    default: 1024
+    max: 4096
+pricing:
+  input: '3'
+  output: '15'
+  unit: '0.000001'
+  currency: USD

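On the pricing block: dify's pricing schema, as used across these model YAMLs, multiplies token counts by the listed price and by `unit`, so `unit: '0.000001'` makes `input: '3'` and `output: '15'` read as $3 per million input tokens and $15 per million output tokens. A small sketch of the arithmetic (the helper name is illustrative, not part of the commit):

from decimal import Decimal

# Figures copied from the cohere.command-r-plus-v1:0 pricing block above.
INPUT_PRICE = Decimal('3')    # USD
OUTPUT_PRICE = Decimal('15')  # USD
UNIT = Decimal('0.000001')    # scale factor applied per token

def estimate_cost_usd(input_tokens: int, output_tokens: int) -> Decimal:
    return (input_tokens * INPUT_PRICE + output_tokens * OUTPUT_PRICE) * UNIT

# 10,000 prompt tokens + 2,000 completion tokens -> 0.060000 USD
print(estimate_cost_usd(10_000, 2_000))
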
@@ -0,0 +1,45 @@
+model: cohere.command-r-v1:0
+label:
+  en_US: Command R
+model_type: llm
+features:
+  #- multi-tool-call
+  - agent-thought
+  #- stream-tool-call
+model_properties:
+  mode: chat
+  context_size: 128000
+parameter_rules:
+  - name: temperature
+    use_template: temperature
+    max: 5.0
+  - name: p
+    use_template: top_p
+    default: 0.75
+    min: 0.01
+    max: 0.99
+  - name: k
+    label:
+      zh_Hans: 取样数量
+      en_US: Top k
+    type: int
+    help:
+      zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
+      en_US: Only sample from the top K options for each subsequent token.
+    required: false
+    default: 0
+    min: 0
+    max: 500
+  - name: presence_penalty
+    use_template: presence_penalty
+  - name: frequency_penalty
+    use_template: frequency_penalty
+  - name: max_tokens
+    use_template: max_tokens
+    default: 1024
+    max: 4096
+pricing:
+  input: '0.5'
+  output: '1.5'
+  unit: '0.000001'
+  currency: USD

@@ -25,6 +25,7 @@ from botocore.exceptions import (
     ServiceNotInRegionError,
     UnknownServiceError,
 )
+from cohere import ChatMessage
 
 from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
 from core.model_runtime.entities.message_entities import (

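The new `from cohere import ChatMessage` import exists to support JSON serialization further down: chat history entries produced by the Cohere helper are `ChatMessage` objects, which `json.dumps` cannot encode without a `default` hook. A minimal sketch of the pattern, assuming the cohere v4-era SDK where `ChatMessage(role=..., message=...)` is a valid constructor:

import json

from cohere import ChatMessage  # requires the cohere SDK

def serialize(obj):
    # fall back to the object's attribute dict for ChatMessage instances
    if isinstance(obj, ChatMessage):
        return obj.__dict__
    raise TypeError(f"Type {type(obj)} not serializable")

history = [ChatMessage(role="USER", message="Hello")]
print(json.dumps({"chat_history": history}, default=serialize))
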
@@ -48,6 +49,7 @@ from core.model_runtime.errors.invoke import (
 )
 from core.model_runtime.errors.validate import CredentialsValidateFailedError
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
+from core.model_runtime.model_providers.cohere.llm.llm import CohereLargeLanguageModel
 
 logger = logging.getLogger(__name__)
 

@@ -75,9 +77,90 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
         # invoke anthropic models via anthropic official SDK
         if "anthropic" in model:
             return self._generate_anthropic(model, credentials, prompt_messages, model_parameters, stop, stream, user)
+        # invoke Cohere models via boto3 client
+        if "cohere.command-r" in model:
+            return self._generate_cohere_chat(model, credentials, prompt_messages, model_parameters, stop, stream, user, tools)
         # invoke other models via boto3 client
         return self._generate(model, credentials, prompt_messages, model_parameters, stop, stream, user)
 
+    def _generate_cohere_chat(
+            self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict,
+            stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None,
+            tools: Optional[list[PromptMessageTool]] = None) -> Union[LLMResult, Generator]:
+        # reuse the Cohere provider's prompt/tool conversion helpers
+        cohere_llm = CohereLargeLanguageModel()
+        client_config = Config(
+            region_name=credentials["aws_region"]
+        )
+
+        runtime_client = boto3.client(
+            service_name='bedrock-runtime',
+            config=client_config,
+            aws_access_key_id=credentials["aws_access_key_id"],
+            aws_secret_access_key=credentials["aws_secret_access_key"]
+        )
+
+        extra_model_kwargs = {}
+        if stop:
+            extra_model_kwargs['stop_sequences'] = stop
+
+        if tools:
+            tools = cohere_llm._convert_tools(tools)
+            model_parameters['tools'] = tools
+
+        message, chat_histories, tool_results \
+            = cohere_llm._convert_prompt_messages_to_message_and_chat_histories(prompt_messages)
+
+        if tool_results:
+            model_parameters['tool_results'] = tool_results
+
+        payload = {
+            **model_parameters,
+            **extra_model_kwargs,  # forwards stop_sequences when set
+            "message": message,
+            "chat_history": chat_histories,
+        }
+
+        # choose the streaming or blocking Bedrock invoke entrypoint
+        if stream:
+            invoke = runtime_client.invoke_model_with_response_stream
+        else:
+            invoke = runtime_client.invoke_model
+
+        def serialize(obj):
+            # cohere ChatMessage objects are not JSON-serializable by default
+            if isinstance(obj, ChatMessage):
+                return obj.__dict__
+            raise TypeError(f"Type {type(obj)} not serializable")
+
+        try:
+            body_jsonstr = json.dumps(payload, default=serialize)
+            response = invoke(
+                modelId=model,
+                contentType="application/json",
+                accept="*/*",
+                body=body_jsonstr
+            )
+        except ClientError as ex:
+            error_code = ex.response['Error']['Code']
+            full_error_msg = f"{error_code}: {ex.response['Error']['Message']}"
+            raise self._map_client_to_invoke_error(error_code, full_error_msg)
+
+        except (EndpointConnectionError, NoRegionError, ServiceNotInRegionError) as ex:
+            raise InvokeConnectionError(str(ex))
+
+        except UnknownServiceError as ex:
+            raise InvokeServerUnavailableError(str(ex))
+
+        except Exception as ex:
+            raise InvokeError(str(ex))
+
+        if stream:
+            return self._handle_generate_stream_response(model, credentials, response, prompt_messages)
+
+        return self._handle_generate_response(model, credentials, response, prompt_messages)
+
+
     def _generate_anthropic(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict,
                             stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]:
         """
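For context, a standalone sketch of the request `_generate_cohere_chat` assembles, calling the Bedrock runtime directly with boto3. Region, credentials, and the message are placeholders, the parameter values are assumptions, and error handling is omitted:

import json

import boto3

# Placeholder region; boto3 can also read credentials from the environment.
client = boto3.client('bedrock-runtime', region_name='us-east-1')

payload = {
    "message": "What is the capital of France?",
    "chat_history": [],   # prior turns, e.g. {"role": "USER", "message": "..."} dicts
    "temperature": 0.3,
    "max_tokens": 256,
}

response = client.invoke_model(
    modelId='cohere.command-r-v1:0',
    contentType='application/json',
    accept='*/*',
    body=json.dumps(payload),
)
# Cohere chat responses on Bedrock return the generation in a 'text' field.
print(json.loads(response['body'].read())['text'])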