From 3e63abd3352832c0494a07a7660800672f5fd1bd Mon Sep 17 00:00:00 2001 From: Yeuoly <45712896+Yeuoly@users.noreply.github.com> Date: Mon, 26 Feb 2024 23:34:40 +0800 Subject: [PATCH] Feat/json mode (#2563) --- api/core/model_runtime/entities/defaults.py | 13 + .../model_runtime/entities/model_entities.py | 1 + .../model_providers/__base/ai_model.py | 16 +- .../__base/large_language_model.py | 258 +++++++++++++++++- .../anthropic/llm/claude-2.1.yaml | 2 + .../anthropic/llm/claude-2.yaml | 2 + .../anthropic/llm/claude-instant-1.yaml | 2 + .../model_providers/anthropic/llm/llm.py | 57 +++- .../google/llm/gemini-pro.yaml | 2 + .../model_providers/google/llm/llm.py | 12 +- .../openai/llm/gpt-3.5-turbo-0125.yaml | 12 + .../openai/llm/gpt-3.5-turbo-0613.yaml | 2 + .../openai/llm/gpt-3.5-turbo-1106.yaml | 12 + .../openai/llm/gpt-3.5-turbo-16k-0613.yaml | 2 + .../openai/llm/gpt-3.5-turbo-16k.yaml | 2 + .../openai/llm/gpt-3.5-turbo-instruct.yaml | 2 + .../openai/llm/gpt-3.5-turbo.yaml | 12 + .../model_providers/openai/llm/llm.py | 134 +++++++++ .../model_providers/tongyi/llm/llm.py | 88 +++++- .../tongyi/llm/qwen-max-1201.yaml | 2 + .../tongyi/llm/qwen-max-longcontext.yaml | 2 + .../model_providers/tongyi/llm/qwen-max.yaml | 2 + .../model_providers/tongyi/llm/qwen-plus.yaml | 2 + .../tongyi/llm/qwen-turbo.yaml | 2 + .../wenxin/llm/ernie-bot-4.yaml | 2 + .../wenxin/llm/ernie-bot-8k.yaml | 2 + .../wenxin/llm/ernie-bot-turbo.yaml | 2 + .../model_providers/wenxin/llm/ernie-bot.yaml | 2 + .../model_providers/wenxin/llm/llm.py | 71 ++++- .../model_providers/zhipuai/llm/llm.py | 51 +++- .../model_runtime/wenxin/test_llm.py | 22 +- 31 files changed, 762 insertions(+), 31 deletions(-) diff --git a/api/core/model_runtime/entities/defaults.py b/api/core/model_runtime/entities/defaults.py index 856f4ce7d1..776f6802e6 100644 --- a/api/core/model_runtime/entities/defaults.py +++ b/api/core/model_runtime/entities/defaults.py @@ -81,5 +81,18 @@ PARAMETER_RULE_TEMPLATE: dict[DefaultParameterName, dict] = { 'min': 1, 'max': 2048, 'precision': 0, + }, + DefaultParameterName.RESPONSE_FORMAT: { + 'label': { + 'en_US': 'Response Format', + 'zh_Hans': '回复格式', + }, + 'type': 'string', + 'help': { + 'en_US': 'Set a response format, ensure the output from llm is a valid code block as possible, such as JSON, XML, etc.', + 'zh_Hans': '设置一个返回格式,确保llm的输出尽可能是有效的代码块,如JSON、XML等', + }, + 'required': False, + 'options': ['JSON', 'XML'], } } \ No newline at end of file diff --git a/api/core/model_runtime/entities/model_entities.py b/api/core/model_runtime/entities/model_entities.py index e35be27f86..52c2d66f9f 100644 --- a/api/core/model_runtime/entities/model_entities.py +++ b/api/core/model_runtime/entities/model_entities.py @@ -91,6 +91,7 @@ class DefaultParameterName(Enum): PRESENCE_PENALTY = "presence_penalty" FREQUENCY_PENALTY = "frequency_penalty" MAX_TOKENS = "max_tokens" + RESPONSE_FORMAT = "response_format" @classmethod def value_of(cls, value: Any) -> 'DefaultParameterName': diff --git a/api/core/model_runtime/model_providers/__base/ai_model.py b/api/core/model_runtime/model_providers/__base/ai_model.py index a9f7a539e2..026e6eca21 100644 --- a/api/core/model_runtime/model_providers/__base/ai_model.py +++ b/api/core/model_runtime/model_providers/__base/ai_model.py @@ -262,23 +262,23 @@ class AIModel(ABC): try: default_parameter_name = DefaultParameterName.value_of(parameter_rule.use_template) default_parameter_rule = self._get_default_parameter_rule_variable_map(default_parameter_name) - if not parameter_rule.max: + if not 
parameter_rule.max and 'max' in default_parameter_rule: parameter_rule.max = default_parameter_rule['max'] - if not parameter_rule.min: + if not parameter_rule.min and 'min' in default_parameter_rule: parameter_rule.min = default_parameter_rule['min'] - if not parameter_rule.precision: + if not parameter_rule.default and 'default' in default_parameter_rule: parameter_rule.default = default_parameter_rule['default'] - if not parameter_rule.precision: + if not parameter_rule.precision and 'precision' in default_parameter_rule: parameter_rule.precision = default_parameter_rule['precision'] - if not parameter_rule.required: + if not parameter_rule.required and 'required' in default_parameter_rule: parameter_rule.required = default_parameter_rule['required'] - if not parameter_rule.help: + if not parameter_rule.help and 'help' in default_parameter_rule: parameter_rule.help = I18nObject( en_US=default_parameter_rule['help']['en_US'], ) - if not parameter_rule.help.en_US: + if not parameter_rule.help.en_US and ('help' in default_parameter_rule and 'en_US' in default_parameter_rule['help']): parameter_rule.help.en_US = default_parameter_rule['help']['en_US'] - if not parameter_rule.help.zh_Hans: + if not parameter_rule.help.zh_Hans and ('help' in default_parameter_rule and 'zh_Hans' in default_parameter_rule['help']): parameter_rule.help.zh_Hans = default_parameter_rule['help'].get('zh_Hans', default_parameter_rule['help']['en_US']) except ValueError: pass diff --git a/api/core/model_runtime/model_providers/__base/large_language_model.py b/api/core/model_runtime/model_providers/__base/large_language_model.py index 1f7edd245f..4b546a5356 100644 --- a/api/core/model_runtime/model_providers/__base/large_language_model.py +++ b/api/core/model_runtime/model_providers/__base/large_language_model.py @@ -9,7 +9,13 @@ from typing import Optional, Union from core.model_runtime.callbacks.base_callback import Callback from core.model_runtime.callbacks.logging_callback import LoggingCallback from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage -from core.model_runtime.entities.message_entities import AssistantPromptMessage, PromptMessage, PromptMessageTool +from core.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + PromptMessage, + PromptMessageTool, + SystemPromptMessage, + UserPromptMessage, +) from core.model_runtime.entities.model_entities import ( ModelPropertyKey, ModelType, @@ -74,7 +80,20 @@ class LargeLanguageModel(AIModel): ) try: - result = self._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user) + if "response_format" in model_parameters: + result = self._code_block_mode_wrapper( + model=model, + credentials=credentials, + prompt_messages=prompt_messages, + model_parameters=model_parameters, + tools=tools, + stop=stop, + stream=stream, + user=user, + callbacks=callbacks + ) + else: + result = self._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user) except Exception as e: self._trigger_invoke_error_callbacks( model=model, @@ -120,6 +139,239 @@ class LargeLanguageModel(AIModel): return result + def _code_block_mode_wrapper(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], + model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None, + callbacks: list[Callback] = None) -> Union[LLMResult, Generator]: + """ + Code 
block mode wrapper, ensure the response is a code block with output markdown quote + + :param model: model name + :param credentials: model credentials + :param prompt_messages: prompt messages + :param model_parameters: model parameters + :param tools: tools for tool calling + :param stop: stop words + :param stream: is stream response + :param user: unique user id + :param callbacks: callbacks + :return: full response or stream response chunk generator result + """ + + block_prompts = """You should always follow the instructions and output a valid {{block}} object. +The structure of the {{block}} object you can found in the instructions, use {"answer": "$your_answer"} as the default structure +if you are not sure about the structure. + + +{{instructions}} + +""" + + code_block = model_parameters.get("response_format", "") + if not code_block: + return self._invoke( + model=model, + credentials=credentials, + prompt_messages=prompt_messages, + model_parameters=model_parameters, + tools=tools, + stop=stop, + stream=stream, + user=user + ) + + model_parameters.pop("response_format") + stop = stop or [] + stop.extend(["\n```", "```\n"]) + block_prompts = block_prompts.replace("{{block}}", code_block) + + # check if there is a system message + if len(prompt_messages) > 0 and isinstance(prompt_messages[0], SystemPromptMessage): + # override the system message + prompt_messages[0] = SystemPromptMessage( + content=block_prompts + .replace("{{instructions}}", prompt_messages[0].content) + ) + else: + # insert the system message + prompt_messages.insert(0, SystemPromptMessage( + content=block_prompts + .replace("{{instructions}}", f"Please output a valid {code_block} object.") + )) + + if len(prompt_messages) > 0 and isinstance(prompt_messages[-1], UserPromptMessage): + # add ```JSON\n to the last message + prompt_messages[-1].content += f"\n```{code_block}\n" + else: + # append a user message + prompt_messages.append(UserPromptMessage( + content=f"```{code_block}\n" + )) + + response = self._invoke( + model=model, + credentials=credentials, + prompt_messages=prompt_messages, + model_parameters=model_parameters, + tools=tools, + stop=stop, + stream=stream, + user=user + ) + + if isinstance(response, Generator): + first_chunk = next(response) + def new_generator(): + yield first_chunk + yield from response + + if first_chunk.delta.message.content and first_chunk.delta.message.content.startswith("`"): + return self._code_block_mode_stream_processor_with_backtick( + model=model, + prompt_messages=prompt_messages, + input_generator=new_generator() + ) + else: + return self._code_block_mode_stream_processor( + model=model, + prompt_messages=prompt_messages, + input_generator=new_generator() + ) + + return response + + def _code_block_mode_stream_processor(self, model: str, prompt_messages: list[PromptMessage], + input_generator: Generator[LLMResultChunk, None, None] + ) -> Generator[LLMResultChunk, None, None]: + """ + Code block mode stream processor, ensure the response is a code block with output markdown quote + + :param model: model name + :param prompt_messages: prompt messages + :param input_generator: input generator + :return: output generator + """ + state = "normal" + backtick_count = 0 + for piece in input_generator: + if piece.delta.message.content: + content = piece.delta.message.content + piece.delta.message.content = "" + yield piece + piece = content + else: + yield piece + continue + new_piece = "" + for char in piece: + if state == "normal": + if char == "`": + state = "in_backticks" 
+ backtick_count = 1 + else: + new_piece += char + elif state == "in_backticks": + if char == "`": + backtick_count += 1 + if backtick_count == 3: + state = "skip_content" + backtick_count = 0 + else: + new_piece += "`" * backtick_count + char + state = "normal" + backtick_count = 0 + elif state == "skip_content": + if char.isspace(): + state = "normal" + + if new_piece: + yield LLMResultChunk( + model=model, + prompt_messages=prompt_messages, + delta=LLMResultChunkDelta( + index=0, + message=AssistantPromptMessage( + content=new_piece, + tool_calls=[] + ), + ) + ) + + def _code_block_mode_stream_processor_with_backtick(self, model: str, prompt_messages: list, + input_generator: Generator[LLMResultChunk, None, None]) \ + -> Generator[LLMResultChunk, None, None]: + """ + Code block mode stream processor, ensure the response is a code block with output markdown quote. + This version skips the language identifier that follows the opening triple backticks. + + :param model: model name + :param prompt_messages: prompt messages + :param input_generator: input generator + :return: output generator + """ + state = "search_start" + backtick_count = 0 + + for piece in input_generator: + if piece.delta.message.content: + content = piece.delta.message.content + # Reset content to ensure we're only processing and yielding the relevant parts + piece.delta.message.content = "" + # Yield a piece with cleared content before processing it to maintain the generator structure + yield piece + piece = content + else: + # Yield pieces without content directly + yield piece + continue + + if state == "done": + continue + + new_piece = "" + for char in piece: + if state == "search_start": + if char == "`": + backtick_count += 1 + if backtick_count == 3: + state = "skip_language" + backtick_count = 0 + else: + backtick_count = 0 + elif state == "skip_language": + # Skip everything until the first newline, marking the end of the language identifier + if char == "\n": + state = "in_code_block" + elif state == "in_code_block": + if char == "`": + backtick_count += 1 + if backtick_count == 3: + state = "done" + break + else: + if backtick_count > 0: + # If backticks were counted but we're still collecting content, it was a false start + new_piece += "`" * backtick_count + backtick_count = 0 + new_piece += char + + elif state == "done": + break + + if new_piece: + # Only yield content collected within the code block + yield LLMResultChunk( + model=model, + prompt_messages=prompt_messages, + delta=LLMResultChunkDelta( + index=0, + message=AssistantPromptMessage( + content=new_piece, + tool_calls=[] + ), + ) + ) + def _invoke_result_generator(self, model: str, result: Generator, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None, @@ -204,7 +456,7 @@ class LargeLanguageModel(AIModel): :return: full response or stream response chunk generator result """ raise NotImplementedError - + @abstractmethod def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], tools: Optional[list[PromptMessageTool]] = None) -> int: diff --git a/api/core/model_runtime/model_providers/anthropic/llm/claude-2.1.yaml b/api/core/model_runtime/model_providers/anthropic/llm/claude-2.1.yaml index 08beef3caa..6707c34594 100644 --- a/api/core/model_runtime/model_providers/anthropic/llm/claude-2.1.yaml +++ b/api/core/model_runtime/model_providers/anthropic/llm/claude-2.1.yaml @@ -27,6 +27,8 @@ parameter_rules: default: 4096 min: 1 max: 4096 + - 
name: response_format + use_template: response_format pricing: input: '8.00' output: '24.00' diff --git a/api/core/model_runtime/model_providers/anthropic/llm/claude-2.yaml b/api/core/model_runtime/model_providers/anthropic/llm/claude-2.yaml index 3c49067630..12faf60bc9 100644 --- a/api/core/model_runtime/model_providers/anthropic/llm/claude-2.yaml +++ b/api/core/model_runtime/model_providers/anthropic/llm/claude-2.yaml @@ -27,6 +27,8 @@ parameter_rules: default: 4096 min: 1 max: 4096 + - name: response_format + use_template: response_format pricing: input: '8.00' output: '24.00' diff --git a/api/core/model_runtime/model_providers/anthropic/llm/claude-instant-1.yaml b/api/core/model_runtime/model_providers/anthropic/llm/claude-instant-1.yaml index d44859faa3..25d32a09af 100644 --- a/api/core/model_runtime/model_providers/anthropic/llm/claude-instant-1.yaml +++ b/api/core/model_runtime/model_providers/anthropic/llm/claude-instant-1.yaml @@ -26,6 +26,8 @@ parameter_rules: default: 4096 min: 1 max: 4096 + - name: response_format + use_template: response_format pricing: input: '1.63' output: '5.51' diff --git a/api/core/model_runtime/model_providers/anthropic/llm/llm.py b/api/core/model_runtime/model_providers/anthropic/llm/llm.py index c743708896..00e5ef6fda 100644 --- a/api/core/model_runtime/model_providers/anthropic/llm/llm.py +++ b/api/core/model_runtime/model_providers/anthropic/llm/llm.py @@ -6,6 +6,7 @@ from anthropic import Anthropic, Stream from anthropic.types import Completion, completion_create_params from httpx import Timeout +from core.model_runtime.callbacks.base_callback import Callback from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta from core.model_runtime.entities.message_entities import ( AssistantPromptMessage, @@ -25,9 +26,16 @@ from core.model_runtime.errors.invoke import ( from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +ANTHROPIC_BLOCK_MODE_PROMPT = """You should always follow the instructions and output a valid {{block}} object. +The structure of the {{block}} object you can found in the instructions, use {"answer": "$your_answer"} as the default structure +if you are not sure about the structure. 
+ + +{{instructions}} + +""" class AnthropicLargeLanguageModel(LargeLanguageModel): - def _invoke(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, @@ -48,6 +56,53 @@ class AnthropicLargeLanguageModel(LargeLanguageModel): """ # invoke model return self._generate(model, credentials, prompt_messages, model_parameters, stop, stream, user) + + def _code_block_mode_wrapper(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], + model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None, + callbacks: list[Callback] = None) -> Union[LLMResult, Generator]: + """ + Code block mode wrapper for invoking large language model + """ + if 'response_format' in model_parameters and model_parameters['response_format']: + stop = stop or [] + self._transform_json_prompts( + model, credentials, prompt_messages, model_parameters, tools, stop, stream, user, model_parameters['response_format'] + ) + model_parameters.pop('response_format') + + return self._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user) + + def _transform_json_prompts(self, model: str, credentials: dict, + prompt_messages: list[PromptMessage], model_parameters: dict, + tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None, + stream: bool = True, user: str | None = None, response_format: str = 'JSON') \ + -> None: + """ + Transform json prompts + """ + if "```\n" not in stop: + stop.append("```\n") + + # check if there is a system message + if len(prompt_messages) > 0 and isinstance(prompt_messages[0], SystemPromptMessage): + # override the system message + prompt_messages[0] = SystemPromptMessage( + content=ANTHROPIC_BLOCK_MODE_PROMPT + .replace("{{instructions}}", prompt_messages[0].content) + .replace("{{block}}", response_format) + ) + else: + # insert the system message + prompt_messages.insert(0, SystemPromptMessage( + content=ANTHROPIC_BLOCK_MODE_PROMPT + .replace("{{instructions}}", f"Please output a valid {response_format} object.") + .replace("{{block}}", response_format) + )) + + prompt_messages.append(AssistantPromptMessage( + content=f"```{response_format}\n" + )) def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], tools: Optional[list[PromptMessageTool]] = None) -> int: diff --git a/api/core/model_runtime/model_providers/google/llm/gemini-pro.yaml b/api/core/model_runtime/model_providers/google/llm/gemini-pro.yaml index 3b98e615e6..ffdc9c3659 100644 --- a/api/core/model_runtime/model_providers/google/llm/gemini-pro.yaml +++ b/api/core/model_runtime/model_providers/google/llm/gemini-pro.yaml @@ -27,6 +27,8 @@ parameter_rules: default: 2048 min: 1 max: 2048 + - name: response_format + use_template: response_format pricing: input: '0.00' output: '0.00' diff --git a/api/core/model_runtime/model_providers/google/llm/llm.py b/api/core/model_runtime/model_providers/google/llm/llm.py index 686761ab5f..2feff8ebe9 100644 --- a/api/core/model_runtime/model_providers/google/llm/llm.py +++ b/api/core/model_runtime/model_providers/google/llm/llm.py @@ -31,6 +31,16 @@ from core.model_runtime.model_providers.__base.large_language_model import Large logger = logging.getLogger(__name__) +GEMINI_BLOCK_MODE_PROMPT = """You should always follow the instructions and output a valid {{block}} object. 
+The structure of the {{block}} object you can found in the instructions, use {"answer": "$your_answer"} as the default structure +if you are not sure about the structure. + + +{{instructions}} + +""" + + class GoogleLargeLanguageModel(LargeLanguageModel): def _invoke(self, model: str, credentials: dict, @@ -53,7 +63,7 @@ class GoogleLargeLanguageModel(LargeLanguageModel): """ # invoke model return self._generate(model, credentials, prompt_messages, model_parameters, stop, stream, user) - + def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], tools: Optional[list[PromptMessageTool]] = None) -> int: """ diff --git a/api/core/model_runtime/model_providers/openai/llm/gpt-3.5-turbo-0125.yaml b/api/core/model_runtime/model_providers/openai/llm/gpt-3.5-turbo-0125.yaml index 3e40db01f9..c1602b2efc 100644 --- a/api/core/model_runtime/model_providers/openai/llm/gpt-3.5-turbo-0125.yaml +++ b/api/core/model_runtime/model_providers/openai/llm/gpt-3.5-turbo-0125.yaml @@ -24,6 +24,18 @@ parameter_rules: default: 512 min: 1 max: 4096 + - name: response_format + label: + zh_Hans: 回复格式 + en_US: response_format + type: string + help: + zh_Hans: 指定模型必须输出的格式 + en_US: specifying the format that the model must output + required: false + options: + - text + - json_object pricing: input: '0.0005' output: '0.0015' diff --git a/api/core/model_runtime/model_providers/openai/llm/gpt-3.5-turbo-0613.yaml b/api/core/model_runtime/model_providers/openai/llm/gpt-3.5-turbo-0613.yaml index 6d519cbee6..31dc53e89f 100644 --- a/api/core/model_runtime/model_providers/openai/llm/gpt-3.5-turbo-0613.yaml +++ b/api/core/model_runtime/model_providers/openai/llm/gpt-3.5-turbo-0613.yaml @@ -24,6 +24,8 @@ parameter_rules: default: 512 min: 1 max: 4096 + - name: response_format + use_template: response_format pricing: input: '0.0015' output: '0.002' diff --git a/api/core/model_runtime/model_providers/openai/llm/gpt-3.5-turbo-1106.yaml b/api/core/model_runtime/model_providers/openai/llm/gpt-3.5-turbo-1106.yaml index 499792e39d..56ab965c39 100644 --- a/api/core/model_runtime/model_providers/openai/llm/gpt-3.5-turbo-1106.yaml +++ b/api/core/model_runtime/model_providers/openai/llm/gpt-3.5-turbo-1106.yaml @@ -24,6 +24,18 @@ parameter_rules: default: 512 min: 1 max: 4096 + - name: response_format + label: + zh_Hans: 回复格式 + en_US: response_format + type: string + help: + zh_Hans: 指定模型必须输出的格式 + en_US: specifying the format that the model must output + required: false + options: + - text + - json_object pricing: input: '0.001' output: '0.002' diff --git a/api/core/model_runtime/model_providers/openai/llm/gpt-3.5-turbo-16k-0613.yaml b/api/core/model_runtime/model_providers/openai/llm/gpt-3.5-turbo-16k-0613.yaml index a86bacb34f..4a0e2ef191 100644 --- a/api/core/model_runtime/model_providers/openai/llm/gpt-3.5-turbo-16k-0613.yaml +++ b/api/core/model_runtime/model_providers/openai/llm/gpt-3.5-turbo-16k-0613.yaml @@ -24,6 +24,8 @@ parameter_rules: default: 512 min: 1 max: 16385 + - name: response_format + use_template: response_format pricing: input: '0.003' output: '0.004' diff --git a/api/core/model_runtime/model_providers/openai/llm/gpt-3.5-turbo-16k.yaml b/api/core/model_runtime/model_providers/openai/llm/gpt-3.5-turbo-16k.yaml index 467041e842..3684c1945c 100644 --- a/api/core/model_runtime/model_providers/openai/llm/gpt-3.5-turbo-16k.yaml +++ b/api/core/model_runtime/model_providers/openai/llm/gpt-3.5-turbo-16k.yaml @@ -24,6 +24,8 @@ parameter_rules: default: 512 min: 1 max: 16385 + - name: 
response_format + use_template: response_format pricing: input: '0.003' output: '0.004' diff --git a/api/core/model_runtime/model_providers/openai/llm/gpt-3.5-turbo-instruct.yaml b/api/core/model_runtime/model_providers/openai/llm/gpt-3.5-turbo-instruct.yaml index 926ee05d97..ad831539e0 100644 --- a/api/core/model_runtime/model_providers/openai/llm/gpt-3.5-turbo-instruct.yaml +++ b/api/core/model_runtime/model_providers/openai/llm/gpt-3.5-turbo-instruct.yaml @@ -21,6 +21,8 @@ parameter_rules: default: 512 min: 1 max: 4096 + - name: response_format + use_template: response_format pricing: input: '0.0015' output: '0.002' diff --git a/api/core/model_runtime/model_providers/openai/llm/gpt-3.5-turbo.yaml b/api/core/model_runtime/model_providers/openai/llm/gpt-3.5-turbo.yaml index fddf1836c4..4ffd31a814 100644 --- a/api/core/model_runtime/model_providers/openai/llm/gpt-3.5-turbo.yaml +++ b/api/core/model_runtime/model_providers/openai/llm/gpt-3.5-turbo.yaml @@ -24,6 +24,18 @@ parameter_rules: default: 512 min: 1 max: 4096 + - name: response_format + label: + zh_Hans: 回复格式 + en_US: response_format + type: string + help: + zh_Hans: 指定模型必须输出的格式 + en_US: specifying the format that the model must output + required: false + options: + - text + - json_object pricing: input: '0.001' output: '0.002' diff --git a/api/core/model_runtime/model_providers/openai/llm/llm.py b/api/core/model_runtime/model_providers/openai/llm/llm.py index 2a1137d443..2ea65780f1 100644 --- a/api/core/model_runtime/model_providers/openai/llm/llm.py +++ b/api/core/model_runtime/model_providers/openai/llm/llm.py @@ -9,6 +9,7 @@ from openai.types.chat import ChatCompletion, ChatCompletionChunk, ChatCompletio from openai.types.chat.chat_completion_chunk import ChoiceDeltaFunctionCall, ChoiceDeltaToolCall from openai.types.chat.chat_completion_message import FunctionCall +from core.model_runtime.callbacks.base_callback import Callback from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta from core.model_runtime.entities.message_entities import ( AssistantPromptMessage, @@ -28,6 +29,14 @@ from core.model_runtime.model_providers.openai._common import _CommonOpenAI logger = logging.getLogger(__name__) +OPENAI_BLOCK_MODE_PROMPT = """You should always follow the instructions and output a valid {{block}} object. +The structure of the {{block}} object you can found in the instructions, use {"answer": "$your_answer"} as the default structure +if you are not sure about the structure. 
+ + +{{instructions}} + +""" class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): """ @@ -84,6 +93,131 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): user=user ) + def _code_block_mode_wrapper(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], + model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None, + callbacks: list[Callback] = None) -> Union[LLMResult, Generator]: + """ + Code block mode wrapper for invoking large language model + """ + # handle fine tune remote models + base_model = model + if model.startswith('ft:'): + base_model = model.split(':')[1] + + # get model mode + model_mode = self.get_model_mode(base_model, credentials) + + # transform response format + if 'response_format' in model_parameters and model_parameters['response_format'] in ['JSON', 'XML']: + stop = stop or [] + if model_mode == LLMMode.CHAT: + # chat model + self._transform_chat_json_prompts( + model=base_model, + credentials=credentials, + prompt_messages=prompt_messages, + model_parameters=model_parameters, + tools=tools, + stop=stop, + stream=stream, + user=user, + response_format=model_parameters['response_format'] + ) + else: + self._transform_completion_json_prompts( + model=base_model, + credentials=credentials, + prompt_messages=prompt_messages, + model_parameters=model_parameters, + tools=tools, + stop=stop, + stream=stream, + user=user, + response_format=model_parameters['response_format'] + ) + model_parameters.pop('response_format') + + return self._invoke( + model=model, + credentials=credentials, + prompt_messages=prompt_messages, + model_parameters=model_parameters, + tools=tools, + stop=stop, + stream=stream, + user=user + ) + + def _transform_chat_json_prompts(self, model: str, credentials: dict, + prompt_messages: list[PromptMessage], model_parameters: dict, + tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None, + stream: bool = True, user: str | None = None, response_format: str = 'JSON') \ + -> None: + """ + Transform json prompts + """ + if "```\n" not in stop: + stop.append("```\n") + if "\n```" not in stop: + stop.append("\n```") + + # check if there is a system message + if len(prompt_messages) > 0 and isinstance(prompt_messages[0], SystemPromptMessage): + # override the system message + prompt_messages[0] = SystemPromptMessage( + content=OPENAI_BLOCK_MODE_PROMPT + .replace("{{instructions}}", prompt_messages[0].content) + .replace("{{block}}", response_format) + ) + prompt_messages.append(AssistantPromptMessage(content=f"\n```{response_format}\n")) + else: + # insert the system message + prompt_messages.insert(0, SystemPromptMessage( + content=OPENAI_BLOCK_MODE_PROMPT + .replace("{{instructions}}", f"Please output a valid {response_format} object.") + .replace("{{block}}", response_format) + )) + prompt_messages.append(AssistantPromptMessage(content=f"\n```{response_format}")) + + def _transform_completion_json_prompts(self, model: str, credentials: dict, + prompt_messages: list[PromptMessage], model_parameters: dict, + tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None, + stream: bool = True, user: str | None = None, response_format: str = 'JSON') \ + -> None: + """ + Transform json prompts + """ + if "```\n" not in stop: + stop.append("```\n") + if "\n```" not in stop: + stop.append("\n```") + + # override the last user message + user_message = None + for i in 
range(len(prompt_messages) - 1, -1, -1): + if isinstance(prompt_messages[i], UserPromptMessage): + user_message = prompt_messages[i] + break + + if user_message: + if prompt_messages[i].content[-11:] == 'Assistant: ': + # now we are in the chat app, remove the last assistant message + prompt_messages[i].content = prompt_messages[i].content[:-11] + prompt_messages[i] = UserPromptMessage( + content=OPENAI_BLOCK_MODE_PROMPT + .replace("{{instructions}}", user_message.content) + .replace("{{block}}", response_format) + ) + prompt_messages[i].content += f"Assistant:\n```{response_format}\n" + else: + prompt_messages[i] = UserPromptMessage( + content=OPENAI_BLOCK_MODE_PROMPT + .replace("{{instructions}}", user_message.content) + .replace("{{block}}", response_format) + ) + prompt_messages[i].content += f"\n```{response_format}\n" + def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], tools: Optional[list[PromptMessageTool]] = None) -> int: """ diff --git a/api/core/model_runtime/model_providers/tongyi/llm/llm.py b/api/core/model_runtime/model_providers/tongyi/llm/llm.py index 7ae8b87764..405f93498e 100644 --- a/api/core/model_runtime/model_providers/tongyi/llm/llm.py +++ b/api/core/model_runtime/model_providers/tongyi/llm/llm.py @@ -13,6 +13,7 @@ from dashscope.common.error import ( ) from langchain.llms.tongyi import generate_with_retry, stream_generate_with_retry +from core.model_runtime.callbacks.base_callback import Callback from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta from core.model_runtime.entities.message_entities import ( AssistantPromptMessage, @@ -57,6 +58,88 @@ class TongyiLargeLanguageModel(LargeLanguageModel): """ # invoke model return self._generate(model, credentials, prompt_messages, model_parameters, stop, stream, user) + + def _code_block_mode_wrapper(self, model: str, credentials: dict, + prompt_messages: list[PromptMessage], model_parameters: dict, + tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None, + stream: bool = True, user: str | None = None, callbacks: list[Callback] = None) \ + -> LLMResult | Generator: + """ + Wrapper for code block mode + """ + block_prompts = """You should always follow the instructions and output a valid {{block}} object. +The structure of the {{block}} object you can found in the instructions, use {"answer": "$your_answer"} as the default structure +if you are not sure about the structure. 
+ + +{{instructions}} + +""" + + code_block = model_parameters.get("response_format", "") + if not code_block: + return self._invoke( + model=model, + credentials=credentials, + prompt_messages=prompt_messages, + model_parameters=model_parameters, + tools=tools, + stop=stop, + stream=stream, + user=user + ) + + model_parameters.pop("response_format") + stop = stop or [] + stop.extend(["\n```", "```\n"]) + block_prompts = block_prompts.replace("{{block}}", code_block) + + # check if there is a system message + if len(prompt_messages) > 0 and isinstance(prompt_messages[0], SystemPromptMessage): + # override the system message + prompt_messages[0] = SystemPromptMessage( + content=block_prompts + .replace("{{instructions}}", prompt_messages[0].content) + ) + else: + # insert the system message + prompt_messages.insert(0, SystemPromptMessage( + content=block_prompts + .replace("{{instructions}}", f"Please output a valid {code_block} object.") + )) + + mode = self.get_model_mode(model, credentials) + if mode == LLMMode.CHAT: + if len(prompt_messages) > 0 and isinstance(prompt_messages[-1], UserPromptMessage): + # add ```JSON\n to the last message + prompt_messages[-1].content += f"\n```{code_block}\n" + else: + # append a user message + prompt_messages.append(UserPromptMessage( + content=f"```{code_block}\n" + )) + else: + prompt_messages.append(AssistantPromptMessage(content=f"```{code_block}\n")) + + response = self._invoke( + model=model, + credentials=credentials, + prompt_messages=prompt_messages, + model_parameters=model_parameters, + tools=tools, + stop=stop, + stream=stream, + user=user + ) + + if isinstance(response, Generator): + return self._code_block_mode_stream_processor_with_backtick( + model=model, + prompt_messages=prompt_messages, + input_generator=response + ) + + return response def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], tools: Optional[list[PromptMessageTool]] = None) -> int: @@ -117,7 +200,7 @@ class TongyiLargeLanguageModel(LargeLanguageModel): """ extra_model_kwargs = {} if stop: - extra_model_kwargs['stop_sequences'] = stop + extra_model_kwargs['stop'] = stop # transform credentials to kwargs for model instance credentials_kwargs = self._to_credential_kwargs(credentials) @@ -131,7 +214,8 @@ class TongyiLargeLanguageModel(LargeLanguageModel): params = { 'model': model, **model_parameters, - **credentials_kwargs + **credentials_kwargs, + **extra_model_kwargs, } mode = self.get_model_mode(model, credentials) diff --git a/api/core/model_runtime/model_providers/tongyi/llm/qwen-max-1201.yaml b/api/core/model_runtime/model_providers/tongyi/llm/qwen-max-1201.yaml index 11eca82736..3461863e67 100644 --- a/api/core/model_runtime/model_providers/tongyi/llm/qwen-max-1201.yaml +++ b/api/core/model_runtime/model_providers/tongyi/llm/qwen-max-1201.yaml @@ -57,3 +57,5 @@ parameter_rules: zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。 en_US: Used to control the repetition of model generation. Increasing the repetition_penalty can reduce the repetition of model generation. 1.0 means no punishment. 
required: false + - name: response_format + use_template: response_format diff --git a/api/core/model_runtime/model_providers/tongyi/llm/qwen-max-longcontext.yaml b/api/core/model_runtime/model_providers/tongyi/llm/qwen-max-longcontext.yaml index 58aab20004..9089c5904a 100644 --- a/api/core/model_runtime/model_providers/tongyi/llm/qwen-max-longcontext.yaml +++ b/api/core/model_runtime/model_providers/tongyi/llm/qwen-max-longcontext.yaml @@ -57,3 +57,5 @@ parameter_rules: zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。 en_US: Used to control the repetition of model generation. Increasing the repetition_penalty can reduce the repetition of model generation. 1.0 means no punishment. required: false + - name: response_format + use_template: response_format diff --git a/api/core/model_runtime/model_providers/tongyi/llm/qwen-max.yaml b/api/core/model_runtime/model_providers/tongyi/llm/qwen-max.yaml index ccfa2356c3..eb1e8ac09b 100644 --- a/api/core/model_runtime/model_providers/tongyi/llm/qwen-max.yaml +++ b/api/core/model_runtime/model_providers/tongyi/llm/qwen-max.yaml @@ -57,3 +57,5 @@ parameter_rules: zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。 en_US: Used to control the repetition of model generation. Increasing the repetition_penalty can reduce the repetition of model generation. 1.0 means no punishment. required: false + - name: response_format + use_template: response_format diff --git a/api/core/model_runtime/model_providers/tongyi/llm/qwen-plus.yaml b/api/core/model_runtime/model_providers/tongyi/llm/qwen-plus.yaml index 1dd13a1a26..83640371f9 100644 --- a/api/core/model_runtime/model_providers/tongyi/llm/qwen-plus.yaml +++ b/api/core/model_runtime/model_providers/tongyi/llm/qwen-plus.yaml @@ -56,6 +56,8 @@ parameter_rules: help: zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。 en_US: Used to control the repetition of model generation. Increasing the repetition_penalty can reduce the repetition of model generation. 1.0 means no punishment. + - name: response_format + use_template: response_format pricing: input: '0.02' output: '0.02' diff --git a/api/core/model_runtime/model_providers/tongyi/llm/qwen-turbo.yaml b/api/core/model_runtime/model_providers/tongyi/llm/qwen-turbo.yaml index 8da184ec9e..5455555bbd 100644 --- a/api/core/model_runtime/model_providers/tongyi/llm/qwen-turbo.yaml +++ b/api/core/model_runtime/model_providers/tongyi/llm/qwen-turbo.yaml @@ -57,6 +57,8 @@ parameter_rules: zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。 en_US: Used to control the repetition of model generation. Increasing the repetition_penalty can reduce the repetition of model generation. 1.0 means no punishment. 
required: false + - name: response_format + use_template: response_format pricing: input: '0.008' output: '0.008' diff --git a/api/core/model_runtime/model_providers/wenxin/llm/ernie-bot-4.yaml b/api/core/model_runtime/model_providers/wenxin/llm/ernie-bot-4.yaml index 0439506817..de9249ea34 100644 --- a/api/core/model_runtime/model_providers/wenxin/llm/ernie-bot-4.yaml +++ b/api/core/model_runtime/model_providers/wenxin/llm/ernie-bot-4.yaml @@ -25,6 +25,8 @@ parameter_rules: use_template: presence_penalty - name: frequency_penalty use_template: frequency_penalty + - name: response_format + use_template: response_format - name: disable_search label: zh_Hans: 禁用搜索 diff --git a/api/core/model_runtime/model_providers/wenxin/llm/ernie-bot-8k.yaml b/api/core/model_runtime/model_providers/wenxin/llm/ernie-bot-8k.yaml index fe06eb9975..b709644628 100644 --- a/api/core/model_runtime/model_providers/wenxin/llm/ernie-bot-8k.yaml +++ b/api/core/model_runtime/model_providers/wenxin/llm/ernie-bot-8k.yaml @@ -25,6 +25,8 @@ parameter_rules: use_template: presence_penalty - name: frequency_penalty use_template: frequency_penalty + - name: response_format + use_template: response_format - name: disable_search label: zh_Hans: 禁用搜索 diff --git a/api/core/model_runtime/model_providers/wenxin/llm/ernie-bot-turbo.yaml b/api/core/model_runtime/model_providers/wenxin/llm/ernie-bot-turbo.yaml index bcd9d1235b..2769c214e0 100644 --- a/api/core/model_runtime/model_providers/wenxin/llm/ernie-bot-turbo.yaml +++ b/api/core/model_runtime/model_providers/wenxin/llm/ernie-bot-turbo.yaml @@ -25,3 +25,5 @@ parameter_rules: use_template: presence_penalty - name: frequency_penalty use_template: frequency_penalty + - name: response_format + use_template: response_format diff --git a/api/core/model_runtime/model_providers/wenxin/llm/ernie-bot.yaml b/api/core/model_runtime/model_providers/wenxin/llm/ernie-bot.yaml index 75fb3b1942..5b1237b243 100644 --- a/api/core/model_runtime/model_providers/wenxin/llm/ernie-bot.yaml +++ b/api/core/model_runtime/model_providers/wenxin/llm/ernie-bot.yaml @@ -34,3 +34,5 @@ parameter_rules: zh_Hans: 禁用模型自行进行外部搜索。 en_US: Disable the model to perform external search. required: false + - name: response_format + use_template: response_format diff --git a/api/core/model_runtime/model_providers/wenxin/llm/llm.py b/api/core/model_runtime/model_providers/wenxin/llm/llm.py index 51b3c97497..d39d63deee 100644 --- a/api/core/model_runtime/model_providers/wenxin/llm/llm.py +++ b/api/core/model_runtime/model_providers/wenxin/llm/llm.py @@ -1,6 +1,7 @@ from collections.abc import Generator -from typing import cast +from typing import Optional, Union, cast +from core.model_runtime.callbacks.base_callback import Callback from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta from core.model_runtime.entities.message_entities import ( AssistantPromptMessage, @@ -29,8 +30,18 @@ from core.model_runtime.model_providers.wenxin.llm.ernie_bot_errors import ( RateLimitReachedError, ) +ERNIE_BOT_BLOCK_MODE_PROMPT = """You should always follow the instructions and output a valid {{block}} object. +The structure of the {{block}} object you can found in the instructions, use {"answer": "$your_answer"} as the default structure +if you are not sure about the structure. -class ErnieBotLarguageModel(LargeLanguageModel): + +{{instructions}} + + +You should also complete the text started with ``` but not tell ``` directly. 
+""" + +class ErnieBotLargeLanguageModel(LargeLanguageModel): def _invoke(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None, @@ -39,6 +50,62 @@ class ErnieBotLarguageModel(LargeLanguageModel): return self._generate(model=model, credentials=credentials, prompt_messages=prompt_messages, model_parameters=model_parameters, tools=tools, stop=stop, stream=stream, user=user) + def _code_block_mode_wrapper(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], + model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None, + callbacks: list[Callback] = None) -> Union[LLMResult, Generator]: + """ + Code block mode wrapper for invoking large language model + """ + if 'response_format' in model_parameters and model_parameters['response_format'] in ['JSON', 'XML']: + response_format = model_parameters['response_format'] + stop = stop or [] + self._transform_json_prompts(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user, response_format) + model_parameters.pop('response_format') + if stream: + return self._code_block_mode_stream_processor( + model=model, + prompt_messages=prompt_messages, + input_generator=self._invoke(model=model, credentials=credentials, prompt_messages=prompt_messages, + model_parameters=model_parameters, tools=tools, stop=stop, stream=stream, user=user) + ) + + return self._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user) + + def _transform_json_prompts(self, model: str, credentials: dict, + prompt_messages: list[PromptMessage], model_parameters: dict, + tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None, + stream: bool = True, user: str | None = None, response_format: str = 'JSON') \ + -> None: + """ + Transform json prompts to model prompts + """ + + # check if there is a system message + if len(prompt_messages) > 0 and isinstance(prompt_messages[0], SystemPromptMessage): + # override the system message + prompt_messages[0] = SystemPromptMessage( + content=ERNIE_BOT_BLOCK_MODE_PROMPT + .replace("{{instructions}}", prompt_messages[0].content) + .replace("{{block}}", response_format) + ) + else: + # insert the system message + prompt_messages.insert(0, SystemPromptMessage( + content=ERNIE_BOT_BLOCK_MODE_PROMPT + .replace("{{instructions}}", f"Please output a valid {response_format} object.") + .replace("{{block}}", response_format) + )) + + if len(prompt_messages) > 0 and isinstance(prompt_messages[-1], UserPromptMessage): + # add ```JSON\n to the last message + prompt_messages[-1].content += "\n```JSON\n{\n" + else: + # append a user message + prompt_messages.append(UserPromptMessage( + content="```JSON\n{\n" + )) + def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], tools: list[PromptMessageTool] | None = None) -> int: # tools is not supported yet diff --git a/api/core/model_runtime/model_providers/zhipuai/llm/llm.py b/api/core/model_runtime/model_providers/zhipuai/llm/llm.py index c62422dfb0..27277164c9 100644 --- a/api/core/model_runtime/model_providers/zhipuai/llm/llm.py +++ b/api/core/model_runtime/model_providers/zhipuai/llm/llm.py @@ -19,6 +19,17 @@ from core.model_runtime.model_providers.zhipuai.zhipuai_sdk.types.chat.chat_comp from 
core.model_runtime.model_providers.zhipuai.zhipuai_sdk.types.chat.chat_completion_chunk import ChatCompletionChunk from core.model_runtime.utils import helper +GLM_JSON_MODE_PROMPT = """You should always follow the instructions and output a valid JSON object. +The structure of the JSON object you can found in the instructions, use {"answer": "$your_answer"} as the default structure +if you are not sure about the structure. + +And you should always end the block with a "```" to indicate the end of the JSON object. + + +{{instructions}} + + +```JSON""" class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel): @@ -44,8 +55,42 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel): credentials_kwargs = self._to_credential_kwargs(credentials) # invoke model + # stop = stop or [] + # self._transform_json_prompts(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user) return self._generate(model, credentials_kwargs, prompt_messages, model_parameters, tools, stop, stream, user) + # def _transform_json_prompts(self, model: str, credentials: dict, + # prompt_messages: list[PromptMessage], model_parameters: dict, + # tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None, + # stream: bool = True, user: str | None = None) \ + # -> None: + # """ + # Transform json prompts to model prompts + # """ + # if "}\n\n" not in stop: + # stop.append("}\n\n") + + # # check if there is a system message + # if len(prompt_messages) > 0 and isinstance(prompt_messages[0], SystemPromptMessage): + # # override the system message + # prompt_messages[0] = SystemPromptMessage( + # content=GLM_JSON_MODE_PROMPT.replace("{{instructions}}", prompt_messages[0].content) + # ) + # else: + # # insert the system message + # prompt_messages.insert(0, SystemPromptMessage( + # content=GLM_JSON_MODE_PROMPT.replace("{{instructions}}", "Please output a valid JSON object.") + # )) + # # check if the last message is a user message + # if len(prompt_messages) > 0 and isinstance(prompt_messages[-1], UserPromptMessage): + # # add ```JSON\n to the last message + # prompt_messages[-1].content += "\n```JSON\n" + # else: + # # append a user message + # prompt_messages.append(UserPromptMessage( + # content="```JSON\n" + # )) + def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], tools: Optional[list[PromptMessageTool]] = None) -> int: """ @@ -106,7 +151,7 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel): """ extra_model_kwargs = {} if stop: - extra_model_kwargs['stop_sequences'] = stop + extra_model_kwargs['stop'] = stop client = ZhipuAI( api_key=credentials_kwargs['api_key'] @@ -256,10 +301,10 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel): ] if stream: - response = client.chat.completions.create(stream=stream, **params) + response = client.chat.completions.create(stream=stream, **params, **extra_model_kwargs) return self._handle_generate_stream_response(model, credentials_kwargs, tools, response, prompt_messages) - response = client.chat.completions.create(**params) + response = client.chat.completions.create(**params, **extra_model_kwargs) return self._handle_generate_response(model, credentials_kwargs, tools, response, prompt_messages) def _handle_generate_response(self, model: str, diff --git a/api/tests/integration_tests/model_runtime/wenxin/test_llm.py b/api/tests/integration_tests/model_runtime/wenxin/test_llm.py index 1af21f147e..0d6c144929 100644 --- 
a/api/tests/integration_tests/model_runtime/wenxin/test_llm.py +++ b/api/tests/integration_tests/model_runtime/wenxin/test_llm.py @@ -7,18 +7,18 @@ from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, from core.model_runtime.entities.message_entities import AssistantPromptMessage, SystemPromptMessage, UserPromptMessage from core.model_runtime.entities.model_entities import AIModelEntity from core.model_runtime.errors.validate import CredentialsValidateFailedError -from core.model_runtime.model_providers.wenxin.llm.llm import ErnieBotLarguageModel +from core.model_runtime.model_providers.wenxin.llm.llm import ErnieBotLargeLanguageModel def test_predefined_models(): - model = ErnieBotLarguageModel() + model = ErnieBotLargeLanguageModel() model_schemas = model.predefined_models() assert len(model_schemas) >= 1 assert isinstance(model_schemas[0], AIModelEntity) def test_validate_credentials_for_chat_model(): sleep(3) - model = ErnieBotLarguageModel() + model = ErnieBotLargeLanguageModel() with pytest.raises(CredentialsValidateFailedError): model.validate_credentials( @@ -39,7 +39,7 @@ def test_validate_credentials_for_chat_model(): def test_invoke_model_ernie_bot(): sleep(3) - model = ErnieBotLarguageModel() + model = ErnieBotLargeLanguageModel() response = model.invoke( model='ernie-bot', @@ -67,7 +67,7 @@ def test_invoke_model_ernie_bot(): def test_invoke_model_ernie_bot_turbo(): sleep(3) - model = ErnieBotLarguageModel() + model = ErnieBotLargeLanguageModel() response = model.invoke( model='ernie-bot-turbo', @@ -95,7 +95,7 @@ def test_invoke_model_ernie_bot_turbo(): def test_invoke_model_ernie_8k(): sleep(3) - model = ErnieBotLarguageModel() + model = ErnieBotLargeLanguageModel() response = model.invoke( model='ernie-bot-8k', @@ -123,7 +123,7 @@ def test_invoke_model_ernie_8k(): def test_invoke_model_ernie_bot_4(): sleep(3) - model = ErnieBotLarguageModel() + model = ErnieBotLargeLanguageModel() response = model.invoke( model='ernie-bot-4', @@ -151,7 +151,7 @@ def test_invoke_model_ernie_bot_4(): def test_invoke_stream_model(): sleep(3) - model = ErnieBotLarguageModel() + model = ErnieBotLargeLanguageModel() response = model.invoke( model='ernie-bot', @@ -182,7 +182,7 @@ def test_invoke_stream_model(): def test_invoke_model_with_system(): sleep(3) - model = ErnieBotLarguageModel() + model = ErnieBotLargeLanguageModel() response = model.invoke( model='ernie-bot', @@ -212,7 +212,7 @@ def test_invoke_model_with_system(): def test_invoke_with_search(): sleep(3) - model = ErnieBotLarguageModel() + model = ErnieBotLargeLanguageModel() response = model.invoke( model='ernie-bot', @@ -250,7 +250,7 @@ def test_invoke_with_search(): def test_get_num_tokens(): sleep(3) - model = ErnieBotLarguageModel() + model = ErnieBotLargeLanguageModel() response = model.get_num_tokens( model='ernie-bot',
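

Usage note (not part of the patch): a minimal sketch of how the new JSON mode added here would be invoked, assuming the `model.invoke(...)` pattern shown in the wenxin integration tests above. The credential key names, environment variable names, prompt text, and parameter values below are placeholders for illustration only; only the `response_format` parameter and the import paths come from this patch.

import os

from core.model_runtime.entities.message_entities import SystemPromptMessage, UserPromptMessage
from core.model_runtime.model_providers.wenxin.llm.llm import ErnieBotLargeLanguageModel

# Hypothetical invocation; credential fields and env var names are placeholders.
model = ErnieBotLargeLanguageModel()
result = model.invoke(
    model='ernie-bot',
    credentials={
        'api_key': os.environ.get('WENXIN_API_KEY', ''),
        'secret_key': os.environ.get('WENXIN_SECRET_KEY', ''),
    },
    prompt_messages=[
        SystemPromptMessage(content='Reply with a field named "answer".'),
        UserPromptMessage(content='What is the capital of France?'),
    ],
    model_parameters={
        'temperature': 0.7,
        # New in this patch: routes the call through _code_block_mode_wrapper,
        # which rewrites the prompts to request a fenced JSON block and adds
        # the matching stop sequences.
        'response_format': 'JSON',
    },
    stop=[],
    stream=False,
    user='abc-123',
)
print(result.message.content)  # expected to lean toward a JSON object such as {"answer": "Paris"}

Design-wise, the patch takes two routes: the OpenAI chat model YAMLs expose the provider-native `json_object`/`text` response_format options passed through to the API, while Anthropic, Google, Tongyi, Wenxin, and ZhipuAI reuse the shared `response_format` template (JSON/XML) and implement it via prompt rewriting plus stream post-processing that strips the markdown code fences.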