feat: bedrock model runtime enhancement (#6299)
This commit is contained in:
parent cc0c826f36
commit ed9e692263

@@ -48,6 +48,28 @@ logger = logging.getLogger(__name__)

 class BedrockLargeLanguageModel(LargeLanguageModel):

+    # please refer to the documentation: https://docs.aws.amazon.com/bedrock/latest/userguide/conversation-inference.html
+    # TODO There is invoke issue: context limit on Cohere Model, will add them after fixed.
+    CONVERSE_API_ENABLED_MODEL_INFO=[
+        {'prefix': 'anthropic.claude-v2', 'support_system_prompts': True, 'support_tool_use': False},
+        {'prefix': 'anthropic.claude-v1', 'support_system_prompts': True, 'support_tool_use': False},
+        {'prefix': 'anthropic.claude-3', 'support_system_prompts': True, 'support_tool_use': True},
+        {'prefix': 'meta.llama', 'support_system_prompts': True, 'support_tool_use': False},
+        {'prefix': 'mistral.mistral-7b-instruct', 'support_system_prompts': False, 'support_tool_use': False},
+        {'prefix': 'mistral.mixtral-8x7b-instruct', 'support_system_prompts': False, 'support_tool_use': False},
+        {'prefix': 'mistral.mistral-large', 'support_system_prompts': True, 'support_tool_use': True},
+        {'prefix': 'mistral.mistral-small', 'support_system_prompts': True, 'support_tool_use': True},
+        {'prefix': 'amazon.titan', 'support_system_prompts': False, 'support_tool_use': False}
+    ]
+
+    @staticmethod
+    def _find_model_info(model_id):
+        for model in BedrockLargeLanguageModel.CONVERSE_API_ENABLED_MODEL_INFO:
+            if model_id.startswith(model['prefix']):
+                return model
+        logger.info(f"current model id: {model_id} did not support by Converse API")
+        return None
+
     def _invoke(self, model: str, credentials: dict,
                 prompt_messages: list[PromptMessage], model_parameters: dict,
                 tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
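
Note (illustrative, not part of the commit): the new CONVERSE_API_ENABLED_MODEL_INFO table is keyed by model-ID prefix, and _find_model_info returns the first matching entry, so the capability flags travel with the model ID. A minimal standalone sketch of the same lookup; MODEL_INFO and find_model_info are stand-in names, and the table is a trimmed copy of the one above:

    # Prefix-based capability lookup, mirroring _find_model_info above.
    # MODEL_INFO and find_model_info are illustrative stand-ins, not Dify symbols.
    MODEL_INFO = [
        {'prefix': 'anthropic.claude-3', 'support_system_prompts': True, 'support_tool_use': True},
        {'prefix': 'amazon.titan', 'support_system_prompts': False, 'support_tool_use': False},
    ]

    def find_model_info(model_id: str):
        for info in MODEL_INFO:
            if model_id.startswith(info['prefix']):  # first prefix match wins
                return info
        return None  # caller falls back to the provider-specific invoke path

    print(find_model_info('anthropic.claude-3-sonnet-20240229-v1:0'))
    # {'prefix': 'anthropic.claude-3', 'support_system_prompts': True, 'support_tool_use': True}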

@@ -66,10 +88,12 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
         :param user: unique user id
         :return: full response or stream response chunk generator result
         """
-        # TODO: consolidate different invocation methods for models based on base model capabilities
-        # invoke anthropic models via boto3 client
-        if "anthropic" in model:
-            return self._generate_anthropic(model, credentials, prompt_messages, model_parameters, stop, stream, user, tools)
+        model_info= BedrockLargeLanguageModel._find_model_info(model)
+
+        if model_info:
+            model_info['model'] = model
+            # invoke models via boto3 converse API
+            return self._generate_with_converse(model_info, credentials, prompt_messages, model_parameters, stop, stream, user, tools)
         # invoke Cohere models via boto3 client
         if "cohere.command-r" in model:
             return self._generate_cohere_chat(model, credentials, prompt_messages, model_parameters, stop, stream, user, tools)

@@ -151,12 +175,12 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
         return self._handle_generate_response(model, credentials, response, prompt_messages)


-    def _generate_anthropic(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict,
+    def _generate_with_converse(self, model_info: dict, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict,
                 stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None, tools: Optional[list[PromptMessageTool]] = None,) -> Union[LLMResult, Generator]:
         """
-        Invoke Anthropic large language model
+        Invoke large language model with converse API

-        :param model: model name
+        :param model_info: model information
         :param credentials: model credentials
         :param prompt_messages: prompt messages
         :param model_parameters: model parameters

@@ -173,24 +197,24 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
         inference_config, additional_model_fields = self._convert_converse_api_model_parameters(model_parameters, stop)

         parameters = {
-            'modelId': model,
+            'modelId': model_info['model'],
             'messages': prompt_message_dicts,
             'inferenceConfig': inference_config,
             'additionalModelRequestFields': additional_model_fields,
         }

-        if system and len(system) > 0:
+        if model_info['support_system_prompts'] and system and len(system) > 0:
             parameters['system'] = system

-        if tools:
+        if model_info['support_tool_use'] and tools:
             parameters['toolConfig'] = self._convert_converse_tool_config(tools=tools)

         if stream:
             response = bedrock_client.converse_stream(**parameters)
-            return self._handle_converse_stream_response(model, credentials, response, prompt_messages)
+            return self._handle_converse_stream_response(model_info['model'], credentials, response, prompt_messages)
         else:
             response = bedrock_client.converse(**parameters)
-            return self._handle_converse_response(model, credentials, response, prompt_messages)
+            return self._handle_converse_response(model_info['model'], credentials, response, prompt_messages)

     def _handle_converse_response(self, model: str, credentials: dict, response: dict,
                                   prompt_messages: list[PromptMessage]) -> LLMResult:
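
Note (illustrative): the parameters dict assembled above is passed straight to the boto3 Bedrock Runtime Converse API. A minimal sketch of that call outside Dify, assuming boto3 is configured with AWS credentials and the model is enabled in the chosen region; the model ID, prompt, and limits are placeholders:

    import boto3

    # Bedrock Runtime client; region and credentials come from the environment here.
    bedrock_client = boto3.client('bedrock-runtime', region_name='us-east-1')

    parameters = {
        'modelId': 'anthropic.claude-3-sonnet-20240229-v1:0',  # placeholder model ID
        'messages': [{'role': 'user', 'content': [{'text': 'Hello'}]}],
        'inferenceConfig': {'maxTokens': 512, 'temperature': 0.5},
    }
    # system and toolConfig are only attached when the capability table allows them
    parameters['system'] = [{'text': 'You are a helpful assistant.'}]

    response = bedrock_client.converse(**parameters)  # or bedrock_client.converse_stream(**parameters)
    print(response['output']['message']['content'][0]['text'])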

@@ -203,10 +227,30 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
         :param prompt_messages: prompt messages
         :return: full response chunk generator result
         """
+        response_content = response['output']['message']['content']
         # transform assistant message to prompt message
-        assistant_prompt_message = AssistantPromptMessage(
-            content=response['output']['message']['content'][0]['text']
-        )
+        if response['stopReason'] == 'tool_use':
+            tool_calls = []
+            text, tool_use = self._extract_tool_use(response_content)
+
+            tool_call = AssistantPromptMessage.ToolCall(
+                id=tool_use['toolUseId'],
+                type='function',
+                function=AssistantPromptMessage.ToolCall.ToolCallFunction(
+                    name=tool_use['name'],
+                    arguments=json.dumps(tool_use['input'])
+                )
+            )
+            tool_calls.append(tool_call)
+
+            assistant_prompt_message = AssistantPromptMessage(
+                content=text,
+                tool_calls=tool_calls
+            )
+        else:
+            assistant_prompt_message = AssistantPromptMessage(
+                content=response_content[0]['text']
+            )

         # calculate num tokens
         if response['usage']:

@@ -229,6 +273,18 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
         )
         return result

+    def _extract_tool_use(self, content:dict)-> tuple[str, dict]:
+        tool_use = {}
+        text = ''
+        for item in content:
+            if 'toolUse' in item:
+                tool_use = item['toolUse']
+            elif 'text' in item:
+                text = item['text']
+            else:
+                raise ValueError(f"Got unknown item: {item}")
+        return text, tool_use
+
     def _handle_converse_stream_response(self, model: str, credentials: dict, response: dict,
                                          prompt_messages: list[PromptMessage], ) -> Generator:
         """
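
Note (illustrative): when the Converse API stops with stopReason == 'tool_use', the content list mixes text blocks and a toolUse block, which is what the new _extract_tool_use separates. A small example with fabricated payload values:

    # Fabricated example of the content blocks _extract_tool_use walks through.
    content = [
        {'text': 'Let me look up the weather first.'},
        {'toolUse': {'toolUseId': 'tooluse_abc123', 'name': 'get_weather',
                     'input': {'city': 'Seattle'}}},
    ]

    text, tool_use = '', {}
    for item in content:
        if 'toolUse' in item:
            tool_use = item['toolUse']
        elif 'text' in item:
            text = item['text']

    print(text)              # Let me look up the weather first.
    print(tool_use['name'])  # get_weather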

@@ -340,14 +396,12 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
         """

         system = []
+        prompt_message_dicts = []
         for message in prompt_messages:
             if isinstance(message, SystemPromptMessage):
                 message.content=message.content.strip()
                 system.append({"text": message.content})
-
-        prompt_message_dicts = []
-        for message in prompt_messages:
-            if not isinstance(message, SystemPromptMessage):
+            else:
                 prompt_message_dicts.append(self._convert_prompt_message_to_dict(message))

         return system, prompt_message_dicts
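
Note (illustrative): the rewritten loop above builds the two structures the Converse API expects, a system list and a messages list, in a single pass. A simplified sketch of the resulting shapes, using stand-in classes instead of Dify's prompt message types and a hard-coded user role instead of _convert_prompt_message_to_dict:

    # Stand-in message classes; the real code uses Dify's PromptMessage types.
    class SystemMsg:
        def __init__(self, content): self.content = content

    class UserMsg:
        def __init__(self, content): self.content = content

    prompts = [SystemMsg('You are concise.'), UserMsg('Summarize this text.')]

    system, message_dicts = [], []
    for m in prompts:
        if isinstance(m, SystemMsg):
            system.append({'text': m.content.strip()})
        else:
            # simplified: assume a plain user message
            message_dicts.append({'role': 'user', 'content': [{'text': m.content}]})

    print(system)         # [{'text': 'You are concise.'}]
    print(message_dicts)  # [{'role': 'user', 'content': [{'text': 'Summarize this text.'}]}]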

@@ -448,7 +502,6 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
             }
         else:
             raise ValueError(f"Got unknown type {message}")
-
         return message_dict

     def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage] | str,

@@ -2,6 +2,9 @@ model: mistral.mistral-large-2402-v1:0
 label:
   en_US: Mistral Large
 model_type: llm
+features:
+  - tool-call
+  - agent-thought
 model_properties:
   mode: completion
   context_size: 32000

@@ -2,6 +2,8 @@ model: mistral.mistral-small-2402-v1:0
 label:
   en_US: Mistral Small
 model_type: llm
+features:
+  - tool-call
 model_properties:
   mode: completion
   context_size: 32000