From 90150a6ca9bd7e45b5614eab95e2fc2ee909977e Mon Sep 17 00:00:00 2001
From: John Wang
Date: Tue, 23 May 2023 12:26:28 +0800
Subject: [PATCH] Feat/optimize chat prompt (#158)

---
 api/core/completion.py | 69 ++++++++++++++++++++++--------------------
 1 file changed, 37 insertions(+), 32 deletions(-)

diff --git a/api/core/completion.py b/api/core/completion.py
index 8e79e03a93..afa40b45cd 100644
--- a/api/core/completion.py
+++ b/api/core/completion.py
@@ -39,7 +39,8 @@ class Completion:
         memory = cls.get_memory_from_conversation(
             tenant_id=app.tenant_id,
             app_model_config=app_model_config,
-            conversation=conversation
+            conversation=conversation,
+            return_messages=False
         )
 
         inputs = conversation.inputs
@@ -119,7 +120,8 @@ class Completion:
         return response
 
     @classmethod
-    def get_main_llm_prompt(cls, mode: str, llm: BaseLanguageModel, pre_prompt: str, query: str, inputs: dict, chain_output: Optional[str],
+    def get_main_llm_prompt(cls, mode: str, llm: BaseLanguageModel, pre_prompt: str, query: str, inputs: dict,
+                            chain_output: Optional[str],
                             memory: Optional[ReadOnlyConversationTokenDBBufferSharedMemory]) -> \
             Union[str | List[BaseMessage]]:
         pre_prompt = PromptBuilder.process_template(pre_prompt) if pre_prompt else pre_prompt
@@ -161,11 +163,19 @@ And answer according to the language of the user's question.
                 "query": query
             }
 
-            human_message_prompt = "{query}"
+            human_message_prompt = ""
+
+            if pre_prompt:
+                pre_prompt_inputs = {k: inputs[k] for k in
+                                     OutLinePromptTemplate.from_template(template=pre_prompt).input_variables
+                                     if k in inputs}
+
+                if pre_prompt_inputs:
+                    human_inputs.update(pre_prompt_inputs)
 
             if chain_output:
                 human_inputs['context'] = chain_output
-                human_message_instruction = """Use the following CONTEXT as your learned knowledge.
+                human_message_prompt += """Use the following CONTEXT as your learned knowledge.
 [CONTEXT]
 {context}
 [END CONTEXT]
@@ -176,23 +186,27 @@ When answer to user:
 Avoid mentioning that you obtained the information from the context.
 And answer according to the language of the user's question.
 """
-                if pre_prompt:
-                    extra_inputs = {k: inputs[k] for k in
-                                    OutLinePromptTemplate.from_template(template=pre_prompt).input_variables
-                                    if k in inputs}
-                    if extra_inputs:
-                        human_inputs.update(extra_inputs)
-                    human_message_instruction += pre_prompt + "\n"
-
-                human_message_prompt = human_message_instruction + "Q:{query}\nA:"
-            else:
-                if pre_prompt:
-                    extra_inputs = {k: inputs[k] for k in
-                                    OutLinePromptTemplate.from_template(template=pre_prompt).input_variables
-                                    if k in inputs}
-                    if extra_inputs:
-                        human_inputs.update(extra_inputs)
-                    human_message_prompt = pre_prompt + "\n" + human_message_prompt
+            if pre_prompt:
+                human_message_prompt += pre_prompt
+
+            query_prompt = "\nHuman: {query}\nAI: "
+
+            if memory:
+                # append chat histories
+                tmp_human_message = PromptBuilder.to_human_message(
+                    prompt_content=human_message_prompt + query_prompt,
+                    inputs=human_inputs
+                )
+
+                curr_message_tokens = memory.llm.get_messages_tokens([tmp_human_message])
+                rest_tokens = llm_constant.max_context_token_length[memory.llm.model_name] \
+                              - memory.llm.max_tokens - curr_message_tokens
+                rest_tokens = max(rest_tokens, 0)
+                history_messages = cls.get_history_messages_from_memory(memory, rest_tokens)
+
+                human_message_prompt += "\n\n" + history_messages
+
+            human_message_prompt += query_prompt
 
             # construct main prompt
             human_message = PromptBuilder.to_human_message(
@@ -199,24 +213,15 @@ And answer according to the language of the user's question.
                 prompt_content=human_message_prompt,
                 inputs=human_inputs
             )
 
-            if memory:
-                # append chat histories
-                tmp_messages = messages.copy() + [human_message]
-                curr_message_tokens = memory.llm.get_messages_tokens(tmp_messages)
-                rest_tokens = llm_constant.max_context_token_length[
-                                  memory.llm.model_name] - memory.llm.max_tokens - curr_message_tokens
-                rest_tokens = max(rest_tokens, 0)
-                history_messages = cls.get_history_messages_from_memory(memory, rest_tokens)
-                messages += history_messages
-
             messages.append(human_message)
 
             return messages
 
     @classmethod
     def get_llm_callback_manager(cls, llm: Union[StreamableOpenAI, StreamableChatOpenAI],
-                                 streaming: bool, conversation_message_task: ConversationMessageTask) -> CallbackManager:
+                                 streaming: bool,
+                                 conversation_message_task: ConversationMessageTask) -> CallbackManager:
         llm_callback_handler = LLMCallbackHandler(llm, conversation_message_task)
         if streaming:
             callback_handlers = [llm_callback_handler, DifyStreamingStdOutCallbackHandler()]
@@ -228,7 +233,7 @@ And answer according to the language of the user's question.
     @classmethod
     def get_history_messages_from_memory(cls, memory: ReadOnlyConversationTokenDBBufferSharedMemory,
                                          max_token_limit: int) -> \
-            List[BaseMessage]:
+            str:
         """Get memory messages."""
         memory.max_token_limit = max_token_limit
         memory_key = memory.memory_variables[0]
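
A note on the token budgeting this patch moves ahead of prompt construction: the
history budget is now computed from the assembled human prompt plus the query
suffix, before history is spliced in. Below is a minimal sketch of that
arithmetic; the constant table and tokenizer are hypothetical stand-ins for
llm_constant.max_context_token_length and memory.llm.get_messages_tokens, not
Dify's actual implementations.

    # Sketch of the history token budget computed in get_main_llm_prompt.
    # MAX_CONTEXT_TOKEN_LENGTH and count_tokens are assumed stand-ins.
    MAX_CONTEXT_TOKEN_LENGTH = {"gpt-3.5-turbo": 4096}  # assumed window size


    def count_tokens(text: str) -> int:
        # A real implementation would use the model's tokenizer (e.g. tiktoken);
        # whitespace splitting just keeps the sketch self-contained.
        return len(text.split())


    def history_token_budget(model_name: str, max_completion_tokens: int,
                             prompt_so_far: str) -> int:
        # Context window, minus the tokens reserved for the completion, minus
        # the prompt built so far; floored at zero exactly as the patch does.
        rest_tokens = (MAX_CONTEXT_TOKEN_LENGTH[model_name]
                       - max_completion_tokens
                       - count_tokens(prompt_so_far))
        return max(rest_tokens, 0)

With a 4096-token window and 512 tokens reserved for the completion, a
300-token prompt leaves 3284 tokens for get_history_messages_from_memory to
fill.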
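
On the return-type change of get_history_messages_from_memory
(List[BaseMessage] -> str): together with return_messages=False in the first
hunk, history now arrives as one flat transcript string that can be
concatenated into the human prompt. A rough sketch of that flattening,
assuming LangChain-style message classes (LangChain's get_buffer_string does
essentially this; the function below is illustrative, not Dify's code):

    from typing import List

    from langchain.schema import AIMessage, BaseMessage, HumanMessage


    def render_history(messages: List[BaseMessage]) -> str:
        # Flatten chat messages into the "Human: ... / AI: ..." transcript
        # style matched by the new "\nHuman: {query}\nAI: " query prompt.
        lines = []
        for message in messages:
            if isinstance(message, HumanMessage):
                prefix = "Human"
            elif isinstance(message, AIMessage):
                prefix = "AI"
            else:
                prefix = message.type  # fall back to the message's role name
            lines.append(f"{prefix}: {message.content}")
        return "\n".join(lines)

For example, render_history([HumanMessage(content="hi"),
AIMessage(content="hello")]) yields "Human: hi\nAI: hello".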