diff --git a/api/core/model_runtime/entities/defaults.py b/api/core/model_runtime/entities/defaults.py
index 856f4ce7d1..776f6802e6 100644
--- a/api/core/model_runtime/entities/defaults.py
+++ b/api/core/model_runtime/entities/defaults.py
@@ -81,5 +81,18 @@ PARAMETER_RULE_TEMPLATE: dict[DefaultParameterName, dict] = {
'min': 1,
'max': 2048,
'precision': 0,
+ },
+ DefaultParameterName.RESPONSE_FORMAT: {
+ 'label': {
+ 'en_US': 'Response Format',
+ 'zh_Hans': '回复格式',
+ },
+ 'type': 'string',
+ 'help': {
+ 'en_US': 'Set a response format, ensure the output from llm is a valid code block as possible, such as JSON, XML, etc.',
+ 'zh_Hans': '设置一个返回格式,确保llm的输出尽可能是有效的代码块,如JSON、XML等',
+ },
+ 'required': False,
+ 'options': ['JSON', 'XML'],
}
}
\ No newline at end of file
diff --git a/api/core/model_runtime/entities/model_entities.py b/api/core/model_runtime/entities/model_entities.py
index e35be27f86..52c2d66f9f 100644
--- a/api/core/model_runtime/entities/model_entities.py
+++ b/api/core/model_runtime/entities/model_entities.py
@@ -91,6 +91,7 @@ class DefaultParameterName(Enum):
PRESENCE_PENALTY = "presence_penalty"
FREQUENCY_PENALTY = "frequency_penalty"
MAX_TOKENS = "max_tokens"
+ RESPONSE_FORMAT = "response_format"
@classmethod
def value_of(cls, value: Any) -> 'DefaultParameterName':
diff --git a/api/core/model_runtime/model_providers/__base/ai_model.py b/api/core/model_runtime/model_providers/__base/ai_model.py
index a9f7a539e2..026e6eca21 100644
--- a/api/core/model_runtime/model_providers/__base/ai_model.py
+++ b/api/core/model_runtime/model_providers/__base/ai_model.py
@@ -262,23 +262,23 @@ class AIModel(ABC):
try:
default_parameter_name = DefaultParameterName.value_of(parameter_rule.use_template)
default_parameter_rule = self._get_default_parameter_rule_variable_map(default_parameter_name)
- if not parameter_rule.max:
+ if not parameter_rule.max and 'max' in default_parameter_rule:
parameter_rule.max = default_parameter_rule['max']
- if not parameter_rule.min:
+ if not parameter_rule.min and 'min' in default_parameter_rule:
parameter_rule.min = default_parameter_rule['min']
- if not parameter_rule.precision:
+ if not parameter_rule.default and 'default' in default_parameter_rule:
parameter_rule.default = default_parameter_rule['default']
- if not parameter_rule.precision:
+ if not parameter_rule.precision and 'precision' in default_parameter_rule:
parameter_rule.precision = default_parameter_rule['precision']
- if not parameter_rule.required:
+ if not parameter_rule.required and 'required' in default_parameter_rule:
parameter_rule.required = default_parameter_rule['required']
- if not parameter_rule.help:
+ if not parameter_rule.help and 'help' in default_parameter_rule:
parameter_rule.help = I18nObject(
en_US=default_parameter_rule['help']['en_US'],
)
- if not parameter_rule.help.en_US:
+ if not parameter_rule.help.en_US and ('help' in default_parameter_rule and 'en_US' in default_parameter_rule['help']):
parameter_rule.help.en_US = default_parameter_rule['help']['en_US']
- if not parameter_rule.help.zh_Hans:
+ if not parameter_rule.help.zh_Hans and ('help' in default_parameter_rule and 'zh_Hans' in default_parameter_rule['help']):
parameter_rule.help.zh_Hans = default_parameter_rule['help'].get('zh_Hans', default_parameter_rule['help']['en_US'])
except ValueError:
pass
diff --git a/api/core/model_runtime/model_providers/__base/large_language_model.py b/api/core/model_runtime/model_providers/__base/large_language_model.py
index 1f7edd245f..4b546a5356 100644
--- a/api/core/model_runtime/model_providers/__base/large_language_model.py
+++ b/api/core/model_runtime/model_providers/__base/large_language_model.py
@@ -9,7 +9,13 @@ from typing import Optional, Union
from core.model_runtime.callbacks.base_callback import Callback
from core.model_runtime.callbacks.logging_callback import LoggingCallback
from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
-from core.model_runtime.entities.message_entities import AssistantPromptMessage, PromptMessage, PromptMessageTool
+from core.model_runtime.entities.message_entities import (
+ AssistantPromptMessage,
+ PromptMessage,
+ PromptMessageTool,
+ SystemPromptMessage,
+ UserPromptMessage,
+)
from core.model_runtime.entities.model_entities import (
ModelPropertyKey,
ModelType,
@@ -74,7 +80,20 @@ class LargeLanguageModel(AIModel):
)
try:
- result = self._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)
+ if "response_format" in model_parameters:
+ result = self._code_block_mode_wrapper(
+ model=model,
+ credentials=credentials,
+ prompt_messages=prompt_messages,
+ model_parameters=model_parameters,
+ tools=tools,
+ stop=stop,
+ stream=stream,
+ user=user,
+ callbacks=callbacks
+ )
+ else:
+ result = self._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)
except Exception as e:
self._trigger_invoke_error_callbacks(
model=model,
@@ -120,6 +139,239 @@ class LargeLanguageModel(AIModel):
return result
+ def _code_block_mode_wrapper(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
+ model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None,
+ stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None,
+ callbacks: list[Callback] = None) -> Union[LLMResult, Generator]:
+ """
+ Code block mode wrapper, ensure the response is a code block with output markdown quote
+
+ :param model: model name
+ :param credentials: model credentials
+ :param prompt_messages: prompt messages
+ :param model_parameters: model parameters
+ :param tools: tools for tool calling
+ :param stop: stop words
+ :param stream: is stream response
+ :param user: unique user id
+ :param callbacks: callbacks
+ :return: full response or stream response chunk generator result
+ """
+
+ block_prompts = """You should always follow the instructions and output a valid {{block}} object.
+The structure of the {{block}} object you can found in the instructions, use {"answer": "$your_answer"} as the default structure
+if you are not sure about the structure.
+
+
+{{instructions}}
+
+"""
+
+ code_block = model_parameters.get("response_format", "")
+ if not code_block:
+ return self._invoke(
+ model=model,
+ credentials=credentials,
+ prompt_messages=prompt_messages,
+ model_parameters=model_parameters,
+ tools=tools,
+ stop=stop,
+ stream=stream,
+ user=user
+ )
+
+ model_parameters.pop("response_format")
+ stop = stop or []
+ stop.extend(["\n```", "```\n"])
+ block_prompts = block_prompts.replace("{{block}}", code_block)
+
+ # check if there is a system message
+ if len(prompt_messages) > 0 and isinstance(prompt_messages[0], SystemPromptMessage):
+ # override the system message
+ prompt_messages[0] = SystemPromptMessage(
+ content=block_prompts
+ .replace("{{instructions}}", prompt_messages[0].content)
+ )
+ else:
+ # insert the system message
+ prompt_messages.insert(0, SystemPromptMessage(
+ content=block_prompts
+ .replace("{{instructions}}", f"Please output a valid {code_block} object.")
+ ))
+
+ if len(prompt_messages) > 0 and isinstance(prompt_messages[-1], UserPromptMessage):
+ # add ```JSON\n to the last message
+ prompt_messages[-1].content += f"\n```{code_block}\n"
+ else:
+ # append a user message
+ prompt_messages.append(UserPromptMessage(
+ content=f"```{code_block}\n"
+ ))
+
+ response = self._invoke(
+ model=model,
+ credentials=credentials,
+ prompt_messages=prompt_messages,
+ model_parameters=model_parameters,
+ tools=tools,
+ stop=stop,
+ stream=stream,
+ user=user
+ )
+
+ if isinstance(response, Generator):
+ first_chunk = next(response)
+ def new_generator():
+ yield first_chunk
+ yield from response
+
+ if first_chunk.delta.message.content and first_chunk.delta.message.content.startswith("`"):
+ return self._code_block_mode_stream_processor_with_backtick(
+ model=model,
+ prompt_messages=prompt_messages,
+ input_generator=new_generator()
+ )
+ else:
+ return self._code_block_mode_stream_processor(
+ model=model,
+ prompt_messages=prompt_messages,
+ input_generator=new_generator()
+ )
+
+ return response
+
+ def _code_block_mode_stream_processor(self, model: str, prompt_messages: list[PromptMessage],
+ input_generator: Generator[LLMResultChunk, None, None]
+ ) -> Generator[LLMResultChunk, None, None]:
+ """
+ Code block mode stream processor, ensure the response is a code block with output markdown quote
+
+ :param model: model name
+ :param prompt_messages: prompt messages
+ :param input_generator: input generator
+ :return: output generator
+ """
+ state = "normal"
+ backtick_count = 0
+ for piece in input_generator:
+ if piece.delta.message.content:
+ content = piece.delta.message.content
+ piece.delta.message.content = ""
+ yield piece
+ piece = content
+ else:
+ yield piece
+ continue
+ new_piece = ""
+ for char in piece:
+ if state == "normal":
+ if char == "`":
+ state = "in_backticks"
+ backtick_count = 1
+ else:
+ new_piece += char
+ elif state == "in_backticks":
+ if char == "`":
+ backtick_count += 1
+ if backtick_count == 3:
+ state = "skip_content"
+ backtick_count = 0
+ else:
+ new_piece += "`" * backtick_count + char
+ state = "normal"
+ backtick_count = 0
+ elif state == "skip_content":
+ if char.isspace():
+ state = "normal"
+
+ if new_piece:
+ yield LLMResultChunk(
+ model=model,
+ prompt_messages=prompt_messages,
+ delta=LLMResultChunkDelta(
+ index=0,
+ message=AssistantPromptMessage(
+ content=new_piece,
+ tool_calls=[]
+ ),
+ )
+ )
+
+ def _code_block_mode_stream_processor_with_backtick(self, model: str, prompt_messages: list,
+ input_generator: Generator[LLMResultChunk, None, None]) \
+ -> Generator[LLMResultChunk, None, None]:
+ """
+ Code block mode stream processor, ensure the response is a code block with output markdown quote.
+ This version skips the language identifier that follows the opening triple backticks.
+
+ :param model: model name
+ :param prompt_messages: prompt messages
+ :param input_generator: input generator
+ :return: output generator
+ """
+ state = "search_start"
+ backtick_count = 0
+
+ for piece in input_generator:
+ if piece.delta.message.content:
+ content = piece.delta.message.content
+ # Reset content to ensure we're only processing and yielding the relevant parts
+ piece.delta.message.content = ""
+ # Yield a piece with cleared content before processing it to maintain the generator structure
+ yield piece
+ piece = content
+ else:
+ # Yield pieces without content directly
+ yield piece
+ continue
+
+ if state == "done":
+ continue
+
+ new_piece = ""
+ for char in piece:
+ if state == "search_start":
+ if char == "`":
+ backtick_count += 1
+ if backtick_count == 3:
+ state = "skip_language"
+ backtick_count = 0
+ else:
+ backtick_count = 0
+ elif state == "skip_language":
+ # Skip everything until the first newline, marking the end of the language identifier
+ if char == "\n":
+ state = "in_code_block"
+ elif state == "in_code_block":
+ if char == "`":
+ backtick_count += 1
+ if backtick_count == 3:
+ state = "done"
+ break
+ else:
+ if backtick_count > 0:
+ # If backticks were counted but we're still collecting content, it was a false start
+ new_piece += "`" * backtick_count
+ backtick_count = 0
+ new_piece += char
+
+ elif state == "done":
+ break
+
+ if new_piece:
+ # Only yield content collected within the code block
+ yield LLMResultChunk(
+ model=model,
+ prompt_messages=prompt_messages,
+ delta=LLMResultChunkDelta(
+ index=0,
+ message=AssistantPromptMessage(
+ content=new_piece,
+ tool_calls=[]
+ ),
+ )
+ )
+
def _invoke_result_generator(self, model: str, result: Generator, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None,
@@ -204,7 +456,7 @@ class LargeLanguageModel(AIModel):
:return: full response or stream response chunk generator result
"""
raise NotImplementedError
-
+
@abstractmethod
def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]] = None) -> int:
diff --git a/api/core/model_runtime/model_providers/anthropic/llm/claude-2.1.yaml b/api/core/model_runtime/model_providers/anthropic/llm/claude-2.1.yaml
index 08beef3caa..6707c34594 100644
--- a/api/core/model_runtime/model_providers/anthropic/llm/claude-2.1.yaml
+++ b/api/core/model_runtime/model_providers/anthropic/llm/claude-2.1.yaml
@@ -27,6 +27,8 @@ parameter_rules:
default: 4096
min: 1
max: 4096
+ - name: response_format
+ use_template: response_format
pricing:
input: '8.00'
output: '24.00'
diff --git a/api/core/model_runtime/model_providers/anthropic/llm/claude-2.yaml b/api/core/model_runtime/model_providers/anthropic/llm/claude-2.yaml
index 3c49067630..12faf60bc9 100644
--- a/api/core/model_runtime/model_providers/anthropic/llm/claude-2.yaml
+++ b/api/core/model_runtime/model_providers/anthropic/llm/claude-2.yaml
@@ -27,6 +27,8 @@ parameter_rules:
default: 4096
min: 1
max: 4096
+ - name: response_format
+ use_template: response_format
pricing:
input: '8.00'
output: '24.00'
diff --git a/api/core/model_runtime/model_providers/anthropic/llm/claude-instant-1.yaml b/api/core/model_runtime/model_providers/anthropic/llm/claude-instant-1.yaml
index d44859faa3..25d32a09af 100644
--- a/api/core/model_runtime/model_providers/anthropic/llm/claude-instant-1.yaml
+++ b/api/core/model_runtime/model_providers/anthropic/llm/claude-instant-1.yaml
@@ -26,6 +26,8 @@ parameter_rules:
default: 4096
min: 1
max: 4096
+ - name: response_format
+ use_template: response_format
pricing:
input: '1.63'
output: '5.51'
diff --git a/api/core/model_runtime/model_providers/anthropic/llm/llm.py b/api/core/model_runtime/model_providers/anthropic/llm/llm.py
index c743708896..00e5ef6fda 100644
--- a/api/core/model_runtime/model_providers/anthropic/llm/llm.py
+++ b/api/core/model_runtime/model_providers/anthropic/llm/llm.py
@@ -6,6 +6,7 @@ from anthropic import Anthropic, Stream
from anthropic.types import Completion, completion_create_params
from httpx import Timeout
+from core.model_runtime.callbacks.base_callback import Callback
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
@@ -25,9 +26,16 @@ from core.model_runtime.errors.invoke import (
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
+ANTHROPIC_BLOCK_MODE_PROMPT = """You should always follow the instructions and output a valid {{block}} object.
+The structure of the {{block}} object you can found in the instructions, use {"answer": "$your_answer"} as the default structure
+if you are not sure about the structure.
+
+
+{{instructions}}
+
+"""
class AnthropicLargeLanguageModel(LargeLanguageModel):
-
def _invoke(self, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
@@ -48,6 +56,53 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
"""
# invoke model
return self._generate(model, credentials, prompt_messages, model_parameters, stop, stream, user)
+
+ def _code_block_mode_wrapper(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
+ model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None,
+ stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None,
+ callbacks: list[Callback] = None) -> Union[LLMResult, Generator]:
+ """
+ Code block mode wrapper for invoking large language model
+ """
+ if 'response_format' in model_parameters and model_parameters['response_format']:
+ stop = stop or []
+ self._transform_json_prompts(
+ model, credentials, prompt_messages, model_parameters, tools, stop, stream, user, model_parameters['response_format']
+ )
+ model_parameters.pop('response_format')
+
+ return self._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)
+
+ def _transform_json_prompts(self, model: str, credentials: dict,
+ prompt_messages: list[PromptMessage], model_parameters: dict,
+ tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None,
+ stream: bool = True, user: str | None = None, response_format: str = 'JSON') \
+ -> None:
+ """
+ Transform json prompts
+ """
+ if "```\n" not in stop:
+ stop.append("```\n")
+
+ # check if there is a system message
+ if len(prompt_messages) > 0 and isinstance(prompt_messages[0], SystemPromptMessage):
+ # override the system message
+ prompt_messages[0] = SystemPromptMessage(
+ content=ANTHROPIC_BLOCK_MODE_PROMPT
+ .replace("{{instructions}}", prompt_messages[0].content)
+ .replace("{{block}}", response_format)
+ )
+ else:
+ # insert the system message
+ prompt_messages.insert(0, SystemPromptMessage(
+ content=ANTHROPIC_BLOCK_MODE_PROMPT
+ .replace("{{instructions}}", f"Please output a valid {response_format} object.")
+ .replace("{{block}}", response_format)
+ ))
+
+ prompt_messages.append(AssistantPromptMessage(
+ content=f"```{response_format}\n"
+ ))
def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]] = None) -> int:
diff --git a/api/core/model_runtime/model_providers/google/llm/gemini-pro.yaml b/api/core/model_runtime/model_providers/google/llm/gemini-pro.yaml
index 3b98e615e6..ffdc9c3659 100644
--- a/api/core/model_runtime/model_providers/google/llm/gemini-pro.yaml
+++ b/api/core/model_runtime/model_providers/google/llm/gemini-pro.yaml
@@ -27,6 +27,8 @@ parameter_rules:
default: 2048
min: 1
max: 2048
+ - name: response_format
+ use_template: response_format
pricing:
input: '0.00'
output: '0.00'
diff --git a/api/core/model_runtime/model_providers/google/llm/llm.py b/api/core/model_runtime/model_providers/google/llm/llm.py
index 686761ab5f..2feff8ebe9 100644
--- a/api/core/model_runtime/model_providers/google/llm/llm.py
+++ b/api/core/model_runtime/model_providers/google/llm/llm.py
@@ -31,6 +31,16 @@ from core.model_runtime.model_providers.__base.large_language_model import Large
logger = logging.getLogger(__name__)
+GEMINI_BLOCK_MODE_PROMPT = """You should always follow the instructions and output a valid {{block}} object.
+The structure of the {{block}} object you can found in the instructions, use {"answer": "$your_answer"} as the default structure
+if you are not sure about the structure.
+
+
+{{instructions}}
+
+"""
+
+
class GoogleLargeLanguageModel(LargeLanguageModel):
def _invoke(self, model: str, credentials: dict,
@@ -53,7 +63,7 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
"""
# invoke model
return self._generate(model, credentials, prompt_messages, model_parameters, stop, stream, user)
-
+
def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]] = None) -> int:
"""
diff --git a/api/core/model_runtime/model_providers/openai/llm/gpt-3.5-turbo-0125.yaml b/api/core/model_runtime/model_providers/openai/llm/gpt-3.5-turbo-0125.yaml
index 3e40db01f9..c1602b2efc 100644
--- a/api/core/model_runtime/model_providers/openai/llm/gpt-3.5-turbo-0125.yaml
+++ b/api/core/model_runtime/model_providers/openai/llm/gpt-3.5-turbo-0125.yaml
@@ -24,6 +24,18 @@ parameter_rules:
default: 512
min: 1
max: 4096
+ - name: response_format
+ label:
+ zh_Hans: 回复格式
+ en_US: response_format
+ type: string
+ help:
+ zh_Hans: 指定模型必须输出的格式
+ en_US: specifying the format that the model must output
+ required: false
+ options:
+ - text
+ - json_object
pricing:
input: '0.0005'
output: '0.0015'
diff --git a/api/core/model_runtime/model_providers/openai/llm/gpt-3.5-turbo-0613.yaml b/api/core/model_runtime/model_providers/openai/llm/gpt-3.5-turbo-0613.yaml
index 6d519cbee6..31dc53e89f 100644
--- a/api/core/model_runtime/model_providers/openai/llm/gpt-3.5-turbo-0613.yaml
+++ b/api/core/model_runtime/model_providers/openai/llm/gpt-3.5-turbo-0613.yaml
@@ -24,6 +24,8 @@ parameter_rules:
default: 512
min: 1
max: 4096
+ - name: response_format
+ use_template: response_format
pricing:
input: '0.0015'
output: '0.002'
diff --git a/api/core/model_runtime/model_providers/openai/llm/gpt-3.5-turbo-1106.yaml b/api/core/model_runtime/model_providers/openai/llm/gpt-3.5-turbo-1106.yaml
index 499792e39d..56ab965c39 100644
--- a/api/core/model_runtime/model_providers/openai/llm/gpt-3.5-turbo-1106.yaml
+++ b/api/core/model_runtime/model_providers/openai/llm/gpt-3.5-turbo-1106.yaml
@@ -24,6 +24,18 @@ parameter_rules:
default: 512
min: 1
max: 4096
+ - name: response_format
+ label:
+ zh_Hans: 回复格式
+ en_US: response_format
+ type: string
+ help:
+ zh_Hans: 指定模型必须输出的格式
+ en_US: specifying the format that the model must output
+ required: false
+ options:
+ - text
+ - json_object
pricing:
input: '0.001'
output: '0.002'
diff --git a/api/core/model_runtime/model_providers/openai/llm/gpt-3.5-turbo-16k-0613.yaml b/api/core/model_runtime/model_providers/openai/llm/gpt-3.5-turbo-16k-0613.yaml
index a86bacb34f..4a0e2ef191 100644
--- a/api/core/model_runtime/model_providers/openai/llm/gpt-3.5-turbo-16k-0613.yaml
+++ b/api/core/model_runtime/model_providers/openai/llm/gpt-3.5-turbo-16k-0613.yaml
@@ -24,6 +24,8 @@ parameter_rules:
default: 512
min: 1
max: 16385
+ - name: response_format
+ use_template: response_format
pricing:
input: '0.003'
output: '0.004'
diff --git a/api/core/model_runtime/model_providers/openai/llm/gpt-3.5-turbo-16k.yaml b/api/core/model_runtime/model_providers/openai/llm/gpt-3.5-turbo-16k.yaml
index 467041e842..3684c1945c 100644
--- a/api/core/model_runtime/model_providers/openai/llm/gpt-3.5-turbo-16k.yaml
+++ b/api/core/model_runtime/model_providers/openai/llm/gpt-3.5-turbo-16k.yaml
@@ -24,6 +24,8 @@ parameter_rules:
default: 512
min: 1
max: 16385
+ - name: response_format
+ use_template: response_format
pricing:
input: '0.003'
output: '0.004'
diff --git a/api/core/model_runtime/model_providers/openai/llm/gpt-3.5-turbo-instruct.yaml b/api/core/model_runtime/model_providers/openai/llm/gpt-3.5-turbo-instruct.yaml
index 926ee05d97..ad831539e0 100644
--- a/api/core/model_runtime/model_providers/openai/llm/gpt-3.5-turbo-instruct.yaml
+++ b/api/core/model_runtime/model_providers/openai/llm/gpt-3.5-turbo-instruct.yaml
@@ -21,6 +21,8 @@ parameter_rules:
default: 512
min: 1
max: 4096
+ - name: response_format
+ use_template: response_format
pricing:
input: '0.0015'
output: '0.002'
diff --git a/api/core/model_runtime/model_providers/openai/llm/gpt-3.5-turbo.yaml b/api/core/model_runtime/model_providers/openai/llm/gpt-3.5-turbo.yaml
index fddf1836c4..4ffd31a814 100644
--- a/api/core/model_runtime/model_providers/openai/llm/gpt-3.5-turbo.yaml
+++ b/api/core/model_runtime/model_providers/openai/llm/gpt-3.5-turbo.yaml
@@ -24,6 +24,18 @@ parameter_rules:
default: 512
min: 1
max: 4096
+ - name: response_format
+ label:
+ zh_Hans: 回复格式
+ en_US: response_format
+ type: string
+ help:
+ zh_Hans: 指定模型必须输出的格式
+ en_US: specifying the format that the model must output
+ required: false
+ options:
+ - text
+ - json_object
pricing:
input: '0.001'
output: '0.002'
diff --git a/api/core/model_runtime/model_providers/openai/llm/llm.py b/api/core/model_runtime/model_providers/openai/llm/llm.py
index 2a1137d443..2ea65780f1 100644
--- a/api/core/model_runtime/model_providers/openai/llm/llm.py
+++ b/api/core/model_runtime/model_providers/openai/llm/llm.py
@@ -9,6 +9,7 @@ from openai.types.chat import ChatCompletion, ChatCompletionChunk, ChatCompletio
from openai.types.chat.chat_completion_chunk import ChoiceDeltaFunctionCall, ChoiceDeltaToolCall
from openai.types.chat.chat_completion_message import FunctionCall
+from core.model_runtime.callbacks.base_callback import Callback
from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
@@ -28,6 +29,14 @@ from core.model_runtime.model_providers.openai._common import _CommonOpenAI
logger = logging.getLogger(__name__)
+OPENAI_BLOCK_MODE_PROMPT = """You should always follow the instructions and output a valid {{block}} object.
+The structure of the {{block}} object you can found in the instructions, use {"answer": "$your_answer"} as the default structure
+if you are not sure about the structure.
+
+
+{{instructions}}
+
+"""
class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
"""
@@ -84,6 +93,131 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
user=user
)
+ def _code_block_mode_wrapper(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
+ model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None,
+ stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None,
+ callbacks: list[Callback] = None) -> Union[LLMResult, Generator]:
+ """
+ Code block mode wrapper for invoking large language model
+ """
+ # handle fine tune remote models
+ base_model = model
+ if model.startswith('ft:'):
+ base_model = model.split(':')[1]
+
+ # get model mode
+ model_mode = self.get_model_mode(base_model, credentials)
+
+ # transform response format
+ if 'response_format' in model_parameters and model_parameters['response_format'] in ['JSON', 'XML']:
+ stop = stop or []
+ if model_mode == LLMMode.CHAT:
+ # chat model
+ self._transform_chat_json_prompts(
+ model=base_model,
+ credentials=credentials,
+ prompt_messages=prompt_messages,
+ model_parameters=model_parameters,
+ tools=tools,
+ stop=stop,
+ stream=stream,
+ user=user,
+ response_format=model_parameters['response_format']
+ )
+ else:
+ self._transform_completion_json_prompts(
+ model=base_model,
+ credentials=credentials,
+ prompt_messages=prompt_messages,
+ model_parameters=model_parameters,
+ tools=tools,
+ stop=stop,
+ stream=stream,
+ user=user,
+ response_format=model_parameters['response_format']
+ )
+ model_parameters.pop('response_format')
+
+ return self._invoke(
+ model=model,
+ credentials=credentials,
+ prompt_messages=prompt_messages,
+ model_parameters=model_parameters,
+ tools=tools,
+ stop=stop,
+ stream=stream,
+ user=user
+ )
+
+ def _transform_chat_json_prompts(self, model: str, credentials: dict,
+ prompt_messages: list[PromptMessage], model_parameters: dict,
+ tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None,
+ stream: bool = True, user: str | None = None, response_format: str = 'JSON') \
+ -> None:
+ """
+ Transform json prompts
+ """
+ if "```\n" not in stop:
+ stop.append("```\n")
+ if "\n```" not in stop:
+ stop.append("\n```")
+
+ # check if there is a system message
+ if len(prompt_messages) > 0 and isinstance(prompt_messages[0], SystemPromptMessage):
+ # override the system message
+ prompt_messages[0] = SystemPromptMessage(
+ content=OPENAI_BLOCK_MODE_PROMPT
+ .replace("{{instructions}}", prompt_messages[0].content)
+ .replace("{{block}}", response_format)
+ )
+ prompt_messages.append(AssistantPromptMessage(content=f"\n```{response_format}\n"))
+ else:
+ # insert the system message
+ prompt_messages.insert(0, SystemPromptMessage(
+ content=OPENAI_BLOCK_MODE_PROMPT
+ .replace("{{instructions}}", f"Please output a valid {response_format} object.")
+ .replace("{{block}}", response_format)
+ ))
+ prompt_messages.append(AssistantPromptMessage(content=f"\n```{response_format}"))
+
+ def _transform_completion_json_prompts(self, model: str, credentials: dict,
+ prompt_messages: list[PromptMessage], model_parameters: dict,
+ tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None,
+ stream: bool = True, user: str | None = None, response_format: str = 'JSON') \
+ -> None:
+ """
+ Transform json prompts
+ """
+ if "```\n" not in stop:
+ stop.append("```\n")
+ if "\n```" not in stop:
+ stop.append("\n```")
+
+ # override the last user message
+ user_message = None
+ for i in range(len(prompt_messages) - 1, -1, -1):
+ if isinstance(prompt_messages[i], UserPromptMessage):
+ user_message = prompt_messages[i]
+ break
+
+ if user_message:
+ if prompt_messages[i].content[-11:] == 'Assistant: ':
+ # now we are in the chat app, remove the last assistant message
+ prompt_messages[i].content = prompt_messages[i].content[:-11]
+ prompt_messages[i] = UserPromptMessage(
+ content=OPENAI_BLOCK_MODE_PROMPT
+ .replace("{{instructions}}", user_message.content)
+ .replace("{{block}}", response_format)
+ )
+ prompt_messages[i].content += f"Assistant:\n```{response_format}\n"
+ else:
+ prompt_messages[i] = UserPromptMessage(
+ content=OPENAI_BLOCK_MODE_PROMPT
+ .replace("{{instructions}}", user_message.content)
+ .replace("{{block}}", response_format)
+ )
+ prompt_messages[i].content += f"\n```{response_format}\n"
+
def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]] = None) -> int:
"""
diff --git a/api/core/model_runtime/model_providers/tongyi/llm/llm.py b/api/core/model_runtime/model_providers/tongyi/llm/llm.py
index 7ae8b87764..405f93498e 100644
--- a/api/core/model_runtime/model_providers/tongyi/llm/llm.py
+++ b/api/core/model_runtime/model_providers/tongyi/llm/llm.py
@@ -13,6 +13,7 @@ from dashscope.common.error import (
)
from langchain.llms.tongyi import generate_with_retry, stream_generate_with_retry
+from core.model_runtime.callbacks.base_callback import Callback
from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
@@ -57,6 +58,88 @@ class TongyiLargeLanguageModel(LargeLanguageModel):
"""
# invoke model
return self._generate(model, credentials, prompt_messages, model_parameters, stop, stream, user)
+
+ def _code_block_mode_wrapper(self, model: str, credentials: dict,
+ prompt_messages: list[PromptMessage], model_parameters: dict,
+ tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None,
+ stream: bool = True, user: str | None = None, callbacks: list[Callback] = None) \
+ -> LLMResult | Generator:
+ """
+ Wrapper for code block mode
+ """
+ block_prompts = """You should always follow the instructions and output a valid {{block}} object.
+The structure of the {{block}} object you can found in the instructions, use {"answer": "$your_answer"} as the default structure
+if you are not sure about the structure.
+
+
+{{instructions}}
+
+"""
+
+ code_block = model_parameters.get("response_format", "")
+ if not code_block:
+ return self._invoke(
+ model=model,
+ credentials=credentials,
+ prompt_messages=prompt_messages,
+ model_parameters=model_parameters,
+ tools=tools,
+ stop=stop,
+ stream=stream,
+ user=user
+ )
+
+ model_parameters.pop("response_format")
+ stop = stop or []
+ stop.extend(["\n```", "```\n"])
+ block_prompts = block_prompts.replace("{{block}}", code_block)
+
+ # check if there is a system message
+ if len(prompt_messages) > 0 and isinstance(prompt_messages[0], SystemPromptMessage):
+ # override the system message
+ prompt_messages[0] = SystemPromptMessage(
+ content=block_prompts
+ .replace("{{instructions}}", prompt_messages[0].content)
+ )
+ else:
+ # insert the system message
+ prompt_messages.insert(0, SystemPromptMessage(
+ content=block_prompts
+ .replace("{{instructions}}", f"Please output a valid {code_block} object.")
+ ))
+
+ mode = self.get_model_mode(model, credentials)
+ if mode == LLMMode.CHAT:
+ if len(prompt_messages) > 0 and isinstance(prompt_messages[-1], UserPromptMessage):
+ # add ```JSON\n to the last message
+ prompt_messages[-1].content += f"\n```{code_block}\n"
+ else:
+ # append a user message
+ prompt_messages.append(UserPromptMessage(
+ content=f"```{code_block}\n"
+ ))
+ else:
+ prompt_messages.append(AssistantPromptMessage(content=f"```{code_block}\n"))
+
+ response = self._invoke(
+ model=model,
+ credentials=credentials,
+ prompt_messages=prompt_messages,
+ model_parameters=model_parameters,
+ tools=tools,
+ stop=stop,
+ stream=stream,
+ user=user
+ )
+
+ if isinstance(response, Generator):
+ return self._code_block_mode_stream_processor_with_backtick(
+ model=model,
+ prompt_messages=prompt_messages,
+ input_generator=response
+ )
+
+ return response
def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]] = None) -> int:
@@ -117,7 +200,7 @@ class TongyiLargeLanguageModel(LargeLanguageModel):
"""
extra_model_kwargs = {}
if stop:
- extra_model_kwargs['stop_sequences'] = stop
+ extra_model_kwargs['stop'] = stop
# transform credentials to kwargs for model instance
credentials_kwargs = self._to_credential_kwargs(credentials)
@@ -131,7 +214,8 @@ class TongyiLargeLanguageModel(LargeLanguageModel):
params = {
'model': model,
**model_parameters,
- **credentials_kwargs
+ **credentials_kwargs,
+ **extra_model_kwargs,
}
mode = self.get_model_mode(model, credentials)
diff --git a/api/core/model_runtime/model_providers/tongyi/llm/qwen-max-1201.yaml b/api/core/model_runtime/model_providers/tongyi/llm/qwen-max-1201.yaml
index 11eca82736..3461863e67 100644
--- a/api/core/model_runtime/model_providers/tongyi/llm/qwen-max-1201.yaml
+++ b/api/core/model_runtime/model_providers/tongyi/llm/qwen-max-1201.yaml
@@ -57,3 +57,5 @@ parameter_rules:
zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。
en_US: Used to control the repetition of model generation. Increasing the repetition_penalty can reduce the repetition of model generation. 1.0 means no punishment.
required: false
+ - name: response_format
+ use_template: response_format
diff --git a/api/core/model_runtime/model_providers/tongyi/llm/qwen-max-longcontext.yaml b/api/core/model_runtime/model_providers/tongyi/llm/qwen-max-longcontext.yaml
index 58aab20004..9089c5904a 100644
--- a/api/core/model_runtime/model_providers/tongyi/llm/qwen-max-longcontext.yaml
+++ b/api/core/model_runtime/model_providers/tongyi/llm/qwen-max-longcontext.yaml
@@ -57,3 +57,5 @@ parameter_rules:
zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。
en_US: Used to control the repetition of model generation. Increasing the repetition_penalty can reduce the repetition of model generation. 1.0 means no punishment.
required: false
+ - name: response_format
+ use_template: response_format
diff --git a/api/core/model_runtime/model_providers/tongyi/llm/qwen-max.yaml b/api/core/model_runtime/model_providers/tongyi/llm/qwen-max.yaml
index ccfa2356c3..eb1e8ac09b 100644
--- a/api/core/model_runtime/model_providers/tongyi/llm/qwen-max.yaml
+++ b/api/core/model_runtime/model_providers/tongyi/llm/qwen-max.yaml
@@ -57,3 +57,5 @@ parameter_rules:
zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。
en_US: Used to control the repetition of model generation. Increasing the repetition_penalty can reduce the repetition of model generation. 1.0 means no punishment.
required: false
+ - name: response_format
+ use_template: response_format
diff --git a/api/core/model_runtime/model_providers/tongyi/llm/qwen-plus.yaml b/api/core/model_runtime/model_providers/tongyi/llm/qwen-plus.yaml
index 1dd13a1a26..83640371f9 100644
--- a/api/core/model_runtime/model_providers/tongyi/llm/qwen-plus.yaml
+++ b/api/core/model_runtime/model_providers/tongyi/llm/qwen-plus.yaml
@@ -56,6 +56,8 @@ parameter_rules:
help:
zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。
en_US: Used to control the repetition of model generation. Increasing the repetition_penalty can reduce the repetition of model generation. 1.0 means no punishment.
+ - name: response_format
+ use_template: response_format
pricing:
input: '0.02'
output: '0.02'
diff --git a/api/core/model_runtime/model_providers/tongyi/llm/qwen-turbo.yaml b/api/core/model_runtime/model_providers/tongyi/llm/qwen-turbo.yaml
index 8da184ec9e..5455555bbd 100644
--- a/api/core/model_runtime/model_providers/tongyi/llm/qwen-turbo.yaml
+++ b/api/core/model_runtime/model_providers/tongyi/llm/qwen-turbo.yaml
@@ -57,6 +57,8 @@ parameter_rules:
zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。
en_US: Used to control the repetition of model generation. Increasing the repetition_penalty can reduce the repetition of model generation. 1.0 means no punishment.
required: false
+ - name: response_format
+ use_template: response_format
pricing:
input: '0.008'
output: '0.008'
diff --git a/api/core/model_runtime/model_providers/wenxin/llm/ernie-bot-4.yaml b/api/core/model_runtime/model_providers/wenxin/llm/ernie-bot-4.yaml
index 0439506817..de9249ea34 100644
--- a/api/core/model_runtime/model_providers/wenxin/llm/ernie-bot-4.yaml
+++ b/api/core/model_runtime/model_providers/wenxin/llm/ernie-bot-4.yaml
@@ -25,6 +25,8 @@ parameter_rules:
use_template: presence_penalty
- name: frequency_penalty
use_template: frequency_penalty
+ - name: response_format
+ use_template: response_format
- name: disable_search
label:
zh_Hans: 禁用搜索
diff --git a/api/core/model_runtime/model_providers/wenxin/llm/ernie-bot-8k.yaml b/api/core/model_runtime/model_providers/wenxin/llm/ernie-bot-8k.yaml
index fe06eb9975..b709644628 100644
--- a/api/core/model_runtime/model_providers/wenxin/llm/ernie-bot-8k.yaml
+++ b/api/core/model_runtime/model_providers/wenxin/llm/ernie-bot-8k.yaml
@@ -25,6 +25,8 @@ parameter_rules:
use_template: presence_penalty
- name: frequency_penalty
use_template: frequency_penalty
+ - name: response_format
+ use_template: response_format
- name: disable_search
label:
zh_Hans: 禁用搜索
diff --git a/api/core/model_runtime/model_providers/wenxin/llm/ernie-bot-turbo.yaml b/api/core/model_runtime/model_providers/wenxin/llm/ernie-bot-turbo.yaml
index bcd9d1235b..2769c214e0 100644
--- a/api/core/model_runtime/model_providers/wenxin/llm/ernie-bot-turbo.yaml
+++ b/api/core/model_runtime/model_providers/wenxin/llm/ernie-bot-turbo.yaml
@@ -25,3 +25,5 @@ parameter_rules:
use_template: presence_penalty
- name: frequency_penalty
use_template: frequency_penalty
+ - name: response_format
+ use_template: response_format
diff --git a/api/core/model_runtime/model_providers/wenxin/llm/ernie-bot.yaml b/api/core/model_runtime/model_providers/wenxin/llm/ernie-bot.yaml
index 75fb3b1942..5b1237b243 100644
--- a/api/core/model_runtime/model_providers/wenxin/llm/ernie-bot.yaml
+++ b/api/core/model_runtime/model_providers/wenxin/llm/ernie-bot.yaml
@@ -34,3 +34,5 @@ parameter_rules:
zh_Hans: 禁用模型自行进行外部搜索。
en_US: Disable the model to perform external search.
required: false
+ - name: response_format
+ use_template: response_format
diff --git a/api/core/model_runtime/model_providers/wenxin/llm/llm.py b/api/core/model_runtime/model_providers/wenxin/llm/llm.py
index 51b3c97497..d39d63deee 100644
--- a/api/core/model_runtime/model_providers/wenxin/llm/llm.py
+++ b/api/core/model_runtime/model_providers/wenxin/llm/llm.py
@@ -1,6 +1,7 @@
from collections.abc import Generator
-from typing import cast
+from typing import Optional, Union, cast
+from core.model_runtime.callbacks.base_callback import Callback
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
@@ -29,8 +30,18 @@ from core.model_runtime.model_providers.wenxin.llm.ernie_bot_errors import (
RateLimitReachedError,
)
+ERNIE_BOT_BLOCK_MODE_PROMPT = """You should always follow the instructions and output a valid {{block}} object.
+The structure of the {{block}} object you can found in the instructions, use {"answer": "$your_answer"} as the default structure
+if you are not sure about the structure.
-class ErnieBotLarguageModel(LargeLanguageModel):
+
+{{instructions}}
+
+
+You should also complete the text started with ``` but not tell ``` directly.
+"""
+
+class ErnieBotLargeLanguageModel(LargeLanguageModel):
def _invoke(self, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: dict,
tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None,
@@ -39,6 +50,62 @@ class ErnieBotLarguageModel(LargeLanguageModel):
return self._generate(model=model, credentials=credentials, prompt_messages=prompt_messages,
model_parameters=model_parameters, tools=tools, stop=stop, stream=stream, user=user)
+ def _code_block_mode_wrapper(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
+ model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None,
+ stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None,
+ callbacks: list[Callback] = None) -> Union[LLMResult, Generator]:
+ """
+ Code block mode wrapper for invoking large language model
+ """
+ if 'response_format' in model_parameters and model_parameters['response_format'] in ['JSON', 'XML']:
+ response_format = model_parameters['response_format']
+ stop = stop or []
+ self._transform_json_prompts(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user, response_format)
+ model_parameters.pop('response_format')
+ if stream:
+ return self._code_block_mode_stream_processor(
+ model=model,
+ prompt_messages=prompt_messages,
+ input_generator=self._invoke(model=model, credentials=credentials, prompt_messages=prompt_messages,
+ model_parameters=model_parameters, tools=tools, stop=stop, stream=stream, user=user)
+ )
+
+ return self._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)
+
+ def _transform_json_prompts(self, model: str, credentials: dict,
+ prompt_messages: list[PromptMessage], model_parameters: dict,
+ tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None,
+ stream: bool = True, user: str | None = None, response_format: str = 'JSON') \
+ -> None:
+ """
+ Transform json prompts to model prompts
+ """
+
+ # check if there is a system message
+ if len(prompt_messages) > 0 and isinstance(prompt_messages[0], SystemPromptMessage):
+ # override the system message
+ prompt_messages[0] = SystemPromptMessage(
+ content=ERNIE_BOT_BLOCK_MODE_PROMPT
+ .replace("{{instructions}}", prompt_messages[0].content)
+ .replace("{{block}}", response_format)
+ )
+ else:
+ # insert the system message
+ prompt_messages.insert(0, SystemPromptMessage(
+ content=ERNIE_BOT_BLOCK_MODE_PROMPT
+ .replace("{{instructions}}", f"Please output a valid {response_format} object.")
+ .replace("{{block}}", response_format)
+ ))
+
+ if len(prompt_messages) > 0 and isinstance(prompt_messages[-1], UserPromptMessage):
+ # add ```JSON\n to the last message
+ prompt_messages[-1].content += "\n```JSON\n{\n"
+ else:
+ # append a user message
+ prompt_messages.append(UserPromptMessage(
+ content="```JSON\n{\n"
+ ))
+
def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
tools: list[PromptMessageTool] | None = None) -> int:
# tools is not supported yet
diff --git a/api/core/model_runtime/model_providers/zhipuai/llm/llm.py b/api/core/model_runtime/model_providers/zhipuai/llm/llm.py
index c62422dfb0..27277164c9 100644
--- a/api/core/model_runtime/model_providers/zhipuai/llm/llm.py
+++ b/api/core/model_runtime/model_providers/zhipuai/llm/llm.py
@@ -19,6 +19,17 @@ from core.model_runtime.model_providers.zhipuai.zhipuai_sdk.types.chat.chat_comp
from core.model_runtime.model_providers.zhipuai.zhipuai_sdk.types.chat.chat_completion_chunk import ChatCompletionChunk
from core.model_runtime.utils import helper
+GLM_JSON_MODE_PROMPT = """You should always follow the instructions and output a valid JSON object.
+The structure of the JSON object you can found in the instructions, use {"answer": "$your_answer"} as the default structure
+if you are not sure about the structure.
+
+And you should always end the block with a "```" to indicate the end of the JSON object.
+
+
+{{instructions}}
+
+
+```JSON"""
class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
@@ -44,8 +55,42 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
credentials_kwargs = self._to_credential_kwargs(credentials)
# invoke model
+ # stop = stop or []
+ # self._transform_json_prompts(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)
return self._generate(model, credentials_kwargs, prompt_messages, model_parameters, tools, stop, stream, user)
+ # def _transform_json_prompts(self, model: str, credentials: dict,
+ # prompt_messages: list[PromptMessage], model_parameters: dict,
+ # tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None,
+ # stream: bool = True, user: str | None = None) \
+ # -> None:
+ # """
+ # Transform json prompts to model prompts
+ # """
+ # if "}\n\n" not in stop:
+ # stop.append("}\n\n")
+
+ # # check if there is a system message
+ # if len(prompt_messages) > 0 and isinstance(prompt_messages[0], SystemPromptMessage):
+ # # override the system message
+ # prompt_messages[0] = SystemPromptMessage(
+ # content=GLM_JSON_MODE_PROMPT.replace("{{instructions}}", prompt_messages[0].content)
+ # )
+ # else:
+ # # insert the system message
+ # prompt_messages.insert(0, SystemPromptMessage(
+ # content=GLM_JSON_MODE_PROMPT.replace("{{instructions}}", "Please output a valid JSON object.")
+ # ))
+ # # check if the last message is a user message
+ # if len(prompt_messages) > 0 and isinstance(prompt_messages[-1], UserPromptMessage):
+ # # add ```JSON\n to the last message
+ # prompt_messages[-1].content += "\n```JSON\n"
+ # else:
+ # # append a user message
+ # prompt_messages.append(UserPromptMessage(
+ # content="```JSON\n"
+ # ))
+
def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
tools: Optional[list[PromptMessageTool]] = None) -> int:
"""
@@ -106,7 +151,7 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
"""
extra_model_kwargs = {}
if stop:
- extra_model_kwargs['stop_sequences'] = stop
+ extra_model_kwargs['stop'] = stop
client = ZhipuAI(
api_key=credentials_kwargs['api_key']
@@ -256,10 +301,10 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
]
if stream:
- response = client.chat.completions.create(stream=stream, **params)
+ response = client.chat.completions.create(stream=stream, **params, **extra_model_kwargs)
return self._handle_generate_stream_response(model, credentials_kwargs, tools, response, prompt_messages)
- response = client.chat.completions.create(**params)
+ response = client.chat.completions.create(**params, **extra_model_kwargs)
return self._handle_generate_response(model, credentials_kwargs, tools, response, prompt_messages)
def _handle_generate_response(self, model: str,
diff --git a/api/tests/integration_tests/model_runtime/wenxin/test_llm.py b/api/tests/integration_tests/model_runtime/wenxin/test_llm.py
index 1af21f147e..0d6c144929 100644
--- a/api/tests/integration_tests/model_runtime/wenxin/test_llm.py
+++ b/api/tests/integration_tests/model_runtime/wenxin/test_llm.py
@@ -7,18 +7,18 @@ from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk,
from core.model_runtime.entities.message_entities import AssistantPromptMessage, SystemPromptMessage, UserPromptMessage
from core.model_runtime.entities.model_entities import AIModelEntity
from core.model_runtime.errors.validate import CredentialsValidateFailedError
-from core.model_runtime.model_providers.wenxin.llm.llm import ErnieBotLarguageModel
+from core.model_runtime.model_providers.wenxin.llm.llm import ErnieBotLargeLanguageModel
def test_predefined_models():
- model = ErnieBotLarguageModel()
+ model = ErnieBotLargeLanguageModel()
model_schemas = model.predefined_models()
assert len(model_schemas) >= 1
assert isinstance(model_schemas[0], AIModelEntity)
def test_validate_credentials_for_chat_model():
sleep(3)
- model = ErnieBotLarguageModel()
+ model = ErnieBotLargeLanguageModel()
with pytest.raises(CredentialsValidateFailedError):
model.validate_credentials(
@@ -39,7 +39,7 @@ def test_validate_credentials_for_chat_model():
def test_invoke_model_ernie_bot():
sleep(3)
- model = ErnieBotLarguageModel()
+ model = ErnieBotLargeLanguageModel()
response = model.invoke(
model='ernie-bot',
@@ -67,7 +67,7 @@ def test_invoke_model_ernie_bot():
def test_invoke_model_ernie_bot_turbo():
sleep(3)
- model = ErnieBotLarguageModel()
+ model = ErnieBotLargeLanguageModel()
response = model.invoke(
model='ernie-bot-turbo',
@@ -95,7 +95,7 @@ def test_invoke_model_ernie_bot_turbo():
def test_invoke_model_ernie_8k():
sleep(3)
- model = ErnieBotLarguageModel()
+ model = ErnieBotLargeLanguageModel()
response = model.invoke(
model='ernie-bot-8k',
@@ -123,7 +123,7 @@ def test_invoke_model_ernie_8k():
def test_invoke_model_ernie_bot_4():
sleep(3)
- model = ErnieBotLarguageModel()
+ model = ErnieBotLargeLanguageModel()
response = model.invoke(
model='ernie-bot-4',
@@ -151,7 +151,7 @@ def test_invoke_model_ernie_bot_4():
def test_invoke_stream_model():
sleep(3)
- model = ErnieBotLarguageModel()
+ model = ErnieBotLargeLanguageModel()
response = model.invoke(
model='ernie-bot',
@@ -182,7 +182,7 @@ def test_invoke_stream_model():
def test_invoke_model_with_system():
sleep(3)
- model = ErnieBotLarguageModel()
+ model = ErnieBotLargeLanguageModel()
response = model.invoke(
model='ernie-bot',
@@ -212,7 +212,7 @@ def test_invoke_model_with_system():
def test_invoke_with_search():
sleep(3)
- model = ErnieBotLarguageModel()
+ model = ErnieBotLargeLanguageModel()
response = model.invoke(
model='ernie-bot',
@@ -250,7 +250,7 @@ def test_invoke_with_search():
def test_get_num_tokens():
sleep(3)
- model = ErnieBotLarguageModel()
+ model = ErnieBotLargeLanguageModel()
response = model.get_num_tokens(
model='ernie-bot',