import json
import re
from collections.abc import Generator
from typing import Literal, Union

from core.application_queue_manager import PublishFrom
from core.entities.application_entities import AgentPromptEntity, AgentScratchpadUnit
from core.features.assistant_base_runner import BaseAssistantApplicationRunner
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
from core.model_runtime.entities.message_entities import (
    AssistantPromptMessage,
    PromptMessage,
    PromptMessageTool,
    SystemPromptMessage,
    ToolPromptMessage,
    UserPromptMessage,
)
from core.model_runtime.utils.encoders import jsonable_encoder
from core.tools.errors import (
    ToolInvokeError,
    ToolNotFoundError,
    ToolNotSupportedError,
    ToolParameterValidationError,
    ToolProviderCredentialValidationError,
    ToolProviderNotFoundError,
)
from models.model import Conversation, Message


class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
    """
    Assistant runner that drives a chain-of-thought (CoT) loop: the model
    streams free text plus JSON "action" blobs, each action is either a tool
    call (whose observation is fed back to the model) or the final answer.
    """

    def run(self, conversation: Conversation,
            message: Message,
            query: str,
            inputs: dict[str, str],
            ) -> Union[Generator, LLMResult]:
        """
        Run Cot agent application

        This method is a generator: it yields ``LLMResultChunk`` objects for the
        streamed thought text and, at the end, one chunk carrying the final answer.

        :param conversation: conversation the message belongs to (not read here)
        :param message: message being answered; its id is attached to agent thoughts
        :param query: user query, substituted into the prompt template
        :param inputs: template variables filled into the instruction prompt
        :raises ValueError: if the model invocation returns a falsy result, or
            (via prompt organisation) if the agent prompt is malformed
        """
        app_orchestration_config = self.app_orchestration_config
        self._repack_app_orchestration_config(app_orchestration_config)

        agent_scratchpad: list[AgentScratchpadUnit] = []
        self._init_agent_scratchpad(agent_scratchpad, self.history_prompt_messages)

        # check model mode
        if self.app_orchestration_config.model_config.mode == "completion":
            # TODO: stop words
            if 'Observation' not in app_orchestration_config.model_config.stop:
                app_orchestration_config.model_config.stop.append('Observation')

        # override inputs
        inputs = inputs or {}
        instruction = self.app_orchestration_config.prompt_template.simple_prompt_template
        instruction = self._fill_in_inputs_from_external_data_tools(instruction, inputs)

        iteration_step = 1
        # one extra step beyond the (capped) configured maximum; on that last
        # step the tool list is emptied (see the loop below)
        max_iteration_steps = min(self.app_orchestration_config.agent.max_iteration, 5) + 1

        prompt_messages = self.history_prompt_messages

        # convert tools into ModelRuntime Tool format
        prompt_messages_tools: list[PromptMessageTool] = []
        tool_instances = {}
        for tool in self.app_orchestration_config.agent.tools if self.app_orchestration_config.agent else []:
            try:
                prompt_tool, tool_entity = self._convert_tool_to_prompt_message_tool(tool)
            except Exception:
                # api tool may be deleted
                continue
            # save tool entity
            tool_instances[tool.tool_name] = tool_entity
            # save prompt tool
            prompt_messages_tools.append(prompt_tool)

        # convert dataset tools into ModelRuntime Tool format
        for dataset_tool in self.dataset_tools:
            prompt_tool = self._convert_dataset_retriever_tool_to_prompt_message_tool(dataset_tool)
            # save prompt tool
            prompt_messages_tools.append(prompt_tool)
            # save tool entity
            tool_instances[dataset_tool.identity.name] = dataset_tool

        function_call_state = True
        llm_usage = {
            'usage': None
        }
        final_answer = ''

        def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage):
            """Accumulate ``usage`` into the running total held in the dict."""
            if not final_llm_usage_dict['usage']:
                final_llm_usage_dict['usage'] = usage
            else:
                total_usage = final_llm_usage_dict['usage']
                total_usage.prompt_tokens += usage.prompt_tokens
                total_usage.completion_tokens += usage.completion_tokens
                total_usage.prompt_price += usage.prompt_price
                total_usage.completion_price += usage.completion_price

        model_instance = self.model_instance

        while function_call_state and iteration_step <= max_iteration_steps:
            # continue to run until there is not any tool call
            function_call_state = False

            if iteration_step == max_iteration_steps:
                # the last iteration, remove all tools
                prompt_messages_tools = []

            message_file_ids = []

            agent_thought = self.create_agent_thought(
                message_id=message.id,
                message='',
                tool_name='',
                tool_input='',
                messages_ids=message_file_ids
            )

            if iteration_step > 1:
                self.queue_manager.publish_agent_thought(agent_thought, PublishFrom.APPLICATION_MANAGER)

            # update prompt messages
            prompt_messages = self._organize_cot_prompt_messages(
                mode=app_orchestration_config.model_config.mode,
                prompt_messages=prompt_messages,
                tools=prompt_messages_tools,
                agent_scratchpad=agent_scratchpad,
                agent_prompt_message=app_orchestration_config.agent.prompt,
                instruction=instruction,
                input=query
            )

            # recalculate llm max tokens
            self.recale_llm_max_tokens(self.model_config, prompt_messages)
            # invoke model
            chunks: Generator[LLMResultChunk, None, None] = model_instance.invoke_llm(
                prompt_messages=prompt_messages,
                model_parameters=app_orchestration_config.model_config.parameters,
                tools=[],
                stop=app_orchestration_config.model_config.stop,
                stream=True,
                user=self.user_id,
                callbacks=[],
            )

            # check llm result
            if not chunks:
                raise ValueError("failed to invoke llm")

            # filled in by _handle_stream_react while the stream is consumed
            usage_dict = {}
            react_chunks = self._handle_stream_react(chunks, usage_dict)
            scratchpad = AgentScratchpadUnit(
                agent_response='',
                thought='',
                action_str='',
                observation='',
                action=None
            )

            # publish agent thought if it's first iteration
            if iteration_step == 1:
                self.queue_manager.publish_agent_thought(agent_thought, PublishFrom.APPLICATION_MANAGER)

            for chunk in react_chunks:
                if isinstance(chunk, dict):
                    chunk_json = json.dumps(chunk)
                    scratchpad.agent_response += chunk_json

                    # only the first well-formed action blob is taken as the action;
                    # any other dict is treated as part of the thought text
                    accepted_as_action = False
                    if not scratchpad.action:
                        try:
                            scratchpad.action_str = chunk_json
                            scratchpad.action = AgentScratchpadUnit.Action(
                                action_name=chunk['action'],
                                action_input=chunk['action_input']
                            )
                            accepted_as_action = True
                        except Exception:
                            # missing keys or invalid values: not an action blob
                            accepted_as_action = False

                    if not accepted_as_action:
                        scratchpad.thought += chunk_json
                        yield LLMResultChunk(
                            model=self.model_config.model,
                            prompt_messages=prompt_messages,
                            system_fingerprint='',
                            delta=LLMResultChunkDelta(
                                index=0,
                                message=AssistantPromptMessage(
                                    content=chunk_json
                                ),
                                usage=None
                            )
                        )
                else:
                    scratchpad.agent_response += chunk
                    scratchpad.thought += chunk
                    yield LLMResultChunk(
                        model=self.model_config.model,
                        prompt_messages=prompt_messages,
                        system_fingerprint='',
                        delta=LLMResultChunkDelta(
                            index=0,
                            message=AssistantPromptMessage(
                                content=chunk
                            ),
                            usage=None
                        )
                    )

            agent_scratchpad.append(scratchpad)

            # get llm usage
            if 'usage' in usage_dict:
                increase_usage(llm_usage, usage_dict['usage'])
            else:
                usage_dict['usage'] = LLMUsage.empty_usage()

            self.save_agent_thought(agent_thought=agent_thought,
                                    tool_name=scratchpad.action.action_name if scratchpad.action else '',
                                    tool_input=scratchpad.action.action_input if scratchpad.action else '',
                                    thought=scratchpad.thought,
                                    observation='',
                                    answer=scratchpad.agent_response,
                                    messages_ids=[],
                                    llm_usage=usage_dict['usage'])

            if scratchpad.action and scratchpad.action.action_name.lower() != "final answer":
                self.queue_manager.publish_agent_thought(agent_thought, PublishFrom.APPLICATION_MANAGER)

            if not scratchpad.action:
                # failed to extract action, return final answer directly
                final_answer = scratchpad.agent_response or ''
            else:
                if scratchpad.action.action_name.lower() == "final answer":
                    # action is final answer, return final answer directly
                    action_input = scratchpad.action.action_input
                    if isinstance(action_input, str):
                        final_answer = action_input
                    else:
                        try:
                            final_answer = json.dumps(action_input)
                        except (TypeError, ValueError):
                            # json.dumps signals unserialisable input with
                            # TypeError/ValueError; fall back to str()
                            final_answer = f'{action_input}'
                else:
                    function_call_state = True

                    # action is tool call, invoke tool
                    tool_call_name = scratchpad.action.action_name
                    tool_call_args = scratchpad.action.action_input
                    tool_instance = tool_instances.get(tool_call_name)
                    if not tool_instance:
                        answer = f"there is not a tool named {tool_call_name}"
                        self.save_agent_thought(agent_thought=agent_thought,
                                                tool_name='',
                                                tool_input='',
                                                thought=None,
                                                observation=answer,
                                                answer=answer,
                                                messages_ids=[])
                        self.queue_manager.publish_agent_thought(agent_thought, PublishFrom.APPLICATION_MANAGER)
                    else:
                        # invoke tool
                        error_response = None
                        try:
                            tool_response = tool_instance.invoke(
                                user_id=self.user_id,
                                tool_parameters=tool_call_args if isinstance(tool_call_args, dict)
                                else json.loads(tool_call_args)
                            )
                            # transform tool response to llm friendly response
                            tool_response = self.transform_tool_invoke_messages(tool_response)
                            # extract binary data from tool invoke message
                            binary_files = self.extract_tool_response_binary(tool_response)
                            # create message file
                            message_files = self.create_message_files(binary_files)
                            # publish files
                            for message_file, save_as in message_files:
                                if save_as:
                                    self.variables_pool.set_file(tool_name=tool_call_name,
                                                                 value=message_file.id,
                                                                 name=save_as)
                                self.queue_manager.publish_message_file(message_file,
                                                                        PublishFrom.APPLICATION_MANAGER)

                            message_file_ids = [message_file.id for message_file, _ in message_files]
                        except ToolProviderCredentialValidationError:
                            error_response = "Please check your tool provider credentials"
                        except (
                            ToolNotFoundError, ToolNotSupportedError, ToolProviderNotFoundError
                        ):
                            error_response = f"there is not a tool named {tool_call_name}"
                        except ToolParameterValidationError as e:
                            error_response = f"tool parameters validation error: {e}, please check your tool parameters"
                        except ToolInvokeError as e:
                            error_response = f"tool invoke error: {e}"
                        except Exception as e:
                            # any failure is reported back to the model as the
                            # observation instead of aborting the run
                            error_response = f"unknown error: {e}"

                        if error_response:
                            observation = error_response
                        else:
                            observation = self._convert_tool_response_to_str(tool_response)

                        # save scratchpad
                        scratchpad.observation = observation

                        # save agent thought
                        self.save_agent_thought(
                            agent_thought=agent_thought,
                            tool_name=tool_call_name,
                            tool_input=tool_call_args,
                            thought=None,
                            observation=observation,
                            answer=scratchpad.agent_response,
                            messages_ids=message_file_ids,
                        )
                        self.queue_manager.publish_agent_thought(agent_thought, PublishFrom.APPLICATION_MANAGER)

                # update prompt tool message
                # NOTE(review): nesting reconstructed from whitespace-collapsed source;
                # confirm this loop belongs to the "action present" branch.
                for prompt_tool in prompt_messages_tools:
                    self.update_prompt_message_tool(tool_instances[prompt_tool.name], prompt_tool)

            iteration_step += 1

        yield LLMResultChunk(
            model=model_instance.model,
            prompt_messages=prompt_messages,
            delta=LLMResultChunkDelta(
                index=0,
                message=AssistantPromptMessage(
                    content=final_answer
                ),
                usage=llm_usage['usage']
            ),
            system_fingerprint=''
        )

        # save agent thought
        self.save_agent_thought(
            agent_thought=agent_thought,
            tool_name='',
            tool_input='',
            thought=final_answer,
            observation='',
            answer=final_answer,
            messages_ids=[]
        )

        self.update_db_variables(self.variables_pool, self.db_variables_pool)
        # publish end event
        self.queue_manager.publish_message_end(LLMResult(
            model=model_instance.model,
            prompt_messages=prompt_messages,
            message=AssistantPromptMessage(
                content=final_answer
            ),
            usage=llm_usage['usage'] if llm_usage['usage'] else LLMUsage.empty_usage(),
            system_fingerprint=''
        ), PublishFrom.APPLICATION_MANAGER)

    def _handle_stream_react(self, llm_response: Generator[LLMResultChunk, None, None], usage: dict) \
            -> Generator[Union[str, dict], None, None]:
        """
        Split a streamed LLM response into plain text pieces and parsed JSON blobs.

        The stream is scanned one character at a time. Text outside of JSON is
        yielded as it arrives (with backticks stripped); the content of a
        triple-backtick code block and the content of a balanced ``{...}`` run
        are each passed through ``json.loads`` and yielded as the parsed value,
        or as the raw string when parsing fails.

        :param llm_response: streamed result chunks from the model
        :param usage: output dict; ``usage['usage']`` is set to the most recent
            non-empty ``delta.usage`` seen on the stream
        :return: generator of text fragments (``str``) and parsed JSON values
        """
        def parse_json(json_str):
            """Return the decoded JSON value, or the input unchanged if it is not JSON."""
            try:
                return json.loads(json_str.strip())
            except Exception:
                return json_str

        def extra_json_from_code_block(code_block) -> Generator[Union[dict, str], None, None]:
            """Yield the parsed content of every ```-fenced section in ``code_block``."""
            code_blocks = re.findall(r'```(.*?)```', code_block, re.DOTALL)
            if not code_blocks:
                return
            for block in code_blocks:
                # drop a leading language tag line such as "json"
                json_text = re.sub(r'^[a-zA-Z]+\n', '', block.strip(), flags=re.MULTILINE)
                yield parse_json(json_text)

        code_block_cache = ''
        code_block_delimiter_count = 0
        in_code_block = False
        json_cache = ''
        json_quote_count = 0  # current nesting depth of '{' ... '}'
        in_json = False
        got_json = False

        for response in llm_response:
            # record token usage so the caller can account for it; without this
            # the caller's usage dict would always stay empty
            if response.delta.usage:
                usage['usage'] = response.delta.usage

            response = response.delta.message.content
            if not isinstance(response, str):
                continue

            # stream
            index = 0
            while index < len(response):
                steps = 1
                delta = response[index:index + steps]
                if delta == '`':
                    code_block_cache += delta
                    code_block_delimiter_count += 1
                else:
                    if not in_code_block:
                        if code_block_delimiter_count > 0:
                            # fewer than three backticks: not a fence, flush them
                            yield code_block_cache
                        code_block_cache = ''
                    else:
                        code_block_cache += delta
                    code_block_delimiter_count = 0

                if code_block_delimiter_count == 3:
                    if in_code_block:
                        # closing fence reached: emit what the block contained
                        yield from extra_json_from_code_block(code_block_cache)
                        code_block_cache = ''

                    in_code_block = not in_code_block
                    code_block_delimiter_count = 0

                if not in_code_block:
                    # handle single json
                    if delta == '{':
                        json_quote_count += 1
                        in_json = True
                        json_cache += delta
                    elif delta == '}':
                        json_cache += delta
                        if json_quote_count > 0:
                            json_quote_count -= 1
                            if json_quote_count == 0:
                                # braces balanced; the blob is emitted on the
                                # next character processed
                                in_json = False
                                got_json = True
                                index += steps
                                continue
                    else:
                        if in_json:
                            json_cache += delta

                    if got_json:
                        got_json = False
                        yield parse_json(json_cache)
                        json_cache = ''
                        json_quote_count = 0
                        in_json = False

                if not in_code_block and not in_json:
                    yield delta.replace('`', '')

                index += steps

        # flush whatever is still buffered when the stream ends
        if code_block_cache:
            yield code_block_cache

        if json_cache:
            yield parse_json(json_cache)

    def _fill_in_inputs_from_external_data_tools(self, instruction: str, inputs: dict) -> str:
        """
        fill in inputs from external data tools

        Every ``{{key}}`` placeholder in ``instruction`` is replaced by
        ``str(value)`` for the matching entry of ``inputs``.

        :param instruction: prompt template text
        :param inputs: mapping of placeholder name to value
        :return: the instruction with all known placeholders substituted
        """
        for key, value in inputs.items():
            try:
                instruction = instruction.replace(f'{{{{{key}}}}}', str(value))
            except Exception:
                # best effort: skip a value that cannot be rendered
                continue

        return instruction

    def _init_agent_scratchpad(self,
                               agent_scratchpad: list[AgentScratchpadUnit],
                               messages: list[PromptMessage]
                               ) -> list[AgentScratchpadUnit]:
        """
        init agent scratchpad

        Appends one scratchpad unit per assistant message found in ``messages``;
        a tool message that follows sets the observation of the latest unit.

        :param agent_scratchpad: list that is extended in place
        :param messages: prompt messages to replay (history)
        :return: the same ``agent_scratchpad`` list
        """
        current_scratchpad: AgentScratchpadUnit = None
        for message in messages:
            if isinstance(message, AssistantPromptMessage):
                current_scratchpad = AgentScratchpadUnit(
                    agent_response=message.content,
                    thought=message.content,
                    action_str='',
                    action=None,
                    observation=None
                )
                if message.tool_calls:
                    try:
                        current_scratchpad.action = AgentScratchpadUnit.Action(
                            action_name=message.tool_calls[0].function.name,
                            action_input=json.loads(message.tool_calls[0].function.arguments)
                        )
                    except Exception:
                        # best effort: a tool call that cannot be decoded
                        # leaves the unit without an action
                        pass

                agent_scratchpad.append(current_scratchpad)
            elif isinstance(message, ToolPromptMessage):
                if current_scratchpad:
                    current_scratchpad.observation = message.content

        return agent_scratchpad

    def _check_cot_prompt_messages(self, mode: Literal["completion", "chat"],
                                   agent_prompt_message: AgentPromptEntity,
                                   ):
        """
            check chain of thought prompt messages, a standard prompt message is like:
                Respond to the human as helpfully and accurately as possible.

                {{instruction}}

                You have access to the following tools:

                {{tools}}

                Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).
                Valid action values: "Final Answer" or {{tool_names}}

                Provide only ONE action per $JSON_BLOB, as shown:

                ```
                {
                "action": $TOOL_NAME,
                "action_input": $ACTION_INPUT
                }
                ```

            :param mode: model mode; "completion" requires additional slots
            :param agent_prompt_message: prompt entity providing ``first_prompt``
                and ``next_iteration``
            :raises ValueError: if a prompt is not a string or a required slot is missing
        """

        # parse agent prompt message
        first_prompt = agent_prompt_message.first_prompt
        next_iteration = agent_prompt_message.next_iteration

        if not isinstance(first_prompt, str) or not isinstance(next_iteration, str):
            raise ValueError("first_prompt or next_iteration is required in CoT agent mode")

        # check instruction, tools, and tool_names slots
        if "{{instruction}}" not in first_prompt:
            raise ValueError("{{instruction}} is required in first_prompt")
        if "{{tools}}" not in first_prompt:
            raise ValueError("{{tools}} is required in first_prompt")
        if "{{tool_names}}" not in first_prompt:
            raise ValueError("{{tool_names}} is required in first_prompt")

        if mode == "completion":
            # completion mode renders everything into a single prompt, so the
            # query, scratchpad and observation slots must all be present
            if "{{query}}" not in first_prompt:
                raise ValueError("{{query}} is required in first_prompt")
            if "{{agent_scratchpad}}" not in first_prompt:
                raise ValueError("{{agent_scratchpad}} is required in first_prompt")
            if "{{observation}}" not in next_iteration:
                raise ValueError("{{observation}} is required in next_iteration")

    def _convert_scratchpad_list_to_str(self, agent_scratchpad: list[AgentScratchpadUnit]) -> str:
        """
            convert agent scratchpad list to str

            Each unit contributes its thought, followed by the ``next_iteration``
            template with ``{{observation}}`` filled in, followed by a newline.

            :param agent_scratchpad: scratchpad units to render
            :return: concatenated text of all units
        """

        next_iteration = self.app_orchestration_config.agent.prompt.next_iteration

        # a missing thought/observation is rendered as an empty string
        return ''.join(
            (scratchpad.thought or '')
            + next_iteration.replace("{{observation}}", scratchpad.observation or '')
            + "\n"
            for scratchpad in agent_scratchpad
        )

    def _organize_cot_prompt_messages(self, mode: Literal["completion", "chat"],
                                      prompt_messages: list[PromptMessage],
                                      tools: list[PromptMessageTool],
                                      agent_scratchpad: list[AgentScratchpadUnit],
                                      agent_prompt_message: AgentPromptEntity,
                                      instruction: str,
                                      input: str,
                                      ) -> list[PromptMessage]:
        """
            organize chain of thought prompt messages, a standard prompt message is like:
                Respond to the human as helpfully and accurately as possible.

                {{instruction}}

                You have access to the following tools:

                {{tools}}

                Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).
                Valid action values: "Final Answer" or {{tool_names}}

                Provide only ONE action per $JSON_BLOB, as shown:

                ```
                {{{{
                "action": $TOOL_NAME,
                "action_input": $ACTION_INPUT
                }}}}
                ```

            :param mode: "chat" or "completion"
            :param prompt_messages: messages accumulated so far (copied, not extended in place)
            :param tools: tools advertised to the model
            :param agent_scratchpad: scratchpad units; in chat mode only the last one is appended
            :param agent_prompt_message: prompt entity providing the templates
            :param instruction: text substituted for ``{{instruction}}``
            :param input: user query, substituted for ``{{query}}`` in completion mode
            :return: the prompt messages to send to the model
            :raises ValueError: if the prompt is malformed or the mode is unsupported
        """

        self._check_cot_prompt_messages(mode, agent_prompt_message)

        # parse agent prompt message
        first_prompt = agent_prompt_message.first_prompt

        # parse tools
        tools_str = self._jsonify_tool_prompt_messages(tools)

        # parse tools name
        tool_names = '"' + '","'.join([tool.name for tool in tools]) + '"'

        # get system message
        system_message = first_prompt.replace("{{instruction}}", instruction) \
                                     .replace("{{tools}}", tools_str) \
                                     .replace("{{tool_names}}", tool_names)

        # organize prompt messages
        if mode == "chat":
            # override system message
            overridden = False
            # shallow copy: the list is new, but an existing system message
            # object is updated in place
            prompt_messages = prompt_messages.copy()
            for prompt_message in prompt_messages:
                if isinstance(prompt_message, SystemPromptMessage):
                    prompt_message.content = system_message
                    overridden = True
                    break

            if not overridden:
                prompt_messages.insert(0, SystemPromptMessage(
                    content=system_message,
                ))

            if len(agent_scratchpad) > 0:
                # add assistant message
                prompt_messages.append(AssistantPromptMessage(
                    content=(agent_scratchpad[-1].thought or '')
                ))

                # add user message
                prompt_messages.append(UserPromptMessage(
                    content=(agent_scratchpad[-1].observation or ''),
                ))

            return prompt_messages
        elif mode == "completion":
            # parse agent scratchpad
            agent_scratchpad_str = self._convert_scratchpad_list_to_str(agent_scratchpad)
            # parse prompt messages
            return [UserPromptMessage(
                content=system_message.replace("{{query}}", input)
                                      .replace("{{agent_scratchpad}}", agent_scratchpad_str),
            )]
        else:
            raise ValueError(f"mode {mode} is not supported")

    def _jsonify_tool_prompt_messages(self, tools: list[PromptMessageTool]) -> str:
        """
            jsonify tool prompt messages

            :param tools: tools to serialise
            :return: JSON text of the encoded tools, with non-ASCII characters kept as-is
        """
        tools = jsonable_encoder(tools)
        # json.dumps never raises JSONDecodeError (that is a decoding error),
        # so no fallback is needed here
        return json.dumps(tools, ensure_ascii=False)