From c720f831aff02ce8a3e2a1869cd010cb77e0d44b Mon Sep 17 00:00:00 2001
From: John Wang
Date: Tue, 27 Jun 2023 15:30:38 +0800
Subject: [PATCH] feat: optimize template parse (#460)

---
 api/core/__init__.py                  |  4 ---
 api/core/completion.py                | 47 ++++++++++++++-------------
 api/core/conversation_message_task.py |  7 ++--
 api/core/generator/llm_generator.py   |  7 ++--
 api/core/prompt/prompt_builder.py     | 13 ++++----
 api/core/prompt/prompt_template.py    | 41 +++++++++++++++++++++++
 api/core/prompt/prompts.py            |  8 ++---
 7 files changed, 83 insertions(+), 44 deletions(-)

diff --git a/api/core/__init__.py b/api/core/__init__.py
index d7e00f73fa..0b26044aa5 100644
--- a/api/core/__init__.py
+++ b/api/core/__init__.py
@@ -3,7 +3,6 @@ from typing import Optional
 
 import langchain
 from flask import Flask
-from langchain.prompts.base import DEFAULT_FORMATTER_MAPPING
 from pydantic import BaseModel
 
 from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
@@ -22,9 +21,6 @@ hosted_llm_credentials = HostedLLMCredentials()
 
 
 def init_app(app: Flask):
-    formatter = OneLineFormatter()
-    DEFAULT_FORMATTER_MAPPING['f-string'] = formatter.format
-
     if os.environ.get("DEBUG") and os.environ.get("DEBUG").lower() == 'true':
         langchain.verbose = True
 
diff --git a/api/core/completion.py b/api/core/completion.py
index a999d34c5c..38a81f2807 100644
--- a/api/core/completion.py
+++ b/api/core/completion.py
@@ -23,7 +23,7 @@ from core.memory.read_only_conversation_token_db_buffer_shared_memory import \
 from core.memory.read_only_conversation_token_db_string_buffer_shared_memory import \
     ReadOnlyConversationTokenDBStringBufferSharedMemory
 from core.prompt.prompt_builder import PromptBuilder
-from core.prompt.prompt_template import OutLinePromptTemplate
+from core.prompt.prompt_template import JinjaPromptTemplate
 from core.prompt.prompts import MORE_LIKE_THIS_GENERATE_PROMPT
 from models.model import App, AppModelConfig, Account, Conversation, Message
 
@@ -35,6 +35,8 @@ class Completion:
         """
         errors: ProviderTokenNotInitError
         """
+        query = PromptBuilder.process_template(query)
+
         memory = None
         if conversation:
             # get memory of conversation (read-only)
@@ -141,18 +143,17 @@ class Completion:
                          memory: Optional[ReadOnlyConversationTokenDBBufferSharedMemory]) -> \
             Tuple[Union[str | List[BaseMessage]], Optional[List[str]]]:
         # disable template string in query
-        query_params = OutLinePromptTemplate.from_template(template=query).input_variables
-        if query_params:
-            for query_param in query_params:
-                if query_param not in inputs:
-                    inputs[query_param] = '{' + query_param + '}'
+        # query_params = JinjaPromptTemplate.from_template(template=query).input_variables
+        # if query_params:
+        #     for query_param in query_params:
+        #         if query_param not in inputs:
+        #             inputs[query_param] = '{{' + query_param + '}}'
 
-        pre_prompt = PromptBuilder.process_template(pre_prompt) if pre_prompt else pre_prompt
         if mode == 'completion':
-            prompt_template = OutLinePromptTemplate.from_template(
+            prompt_template = JinjaPromptTemplate.from_template(
                 template=("""Use the following CONTEXT as your learned knowledge:
 [CONTEXT]
-{context}
+{{context}}
 [END CONTEXT]
 
 When answer to user:
 - If you don't know, just say that you don't know.
 - If you don't know when you are not sure, ask for clarification.
 Avoid mentioning that you obtained the information from the context.
 And answer according to the language of the user's question.
 """ if chain_output else "")
                            + (pre_prompt + "\n" if pre_prompt else "")
-                           + "{query}\n"
+                           + "{{query}}\n"
             )
 
             if chain_output:
                 inputs['context'] = chain_output
-                context_params = OutLinePromptTemplate.from_template(template=chain_output).input_variables
-                if context_params:
-                    for context_param in context_params:
-                        if context_param not in inputs:
-                            inputs[context_param] = '{' + context_param + '}'
+                # context_params = JinjaPromptTemplate.from_template(template=chain_output).input_variables
+                # if context_params:
+                #     for context_param in context_params:
+                #         if context_param not in inputs:
+                #             inputs[context_param] = '{{' + context_param + '}}'
 
             prompt_inputs = {k: inputs[k] for k in prompt_template.input_variables if k in inputs}
             prompt_content = prompt_template.format(
@@ -195,7 +196,7 @@ And answer according to the language of the user's question.
 
         if pre_prompt:
             pre_prompt_inputs = {k: inputs[k] for k in
-                                 OutLinePromptTemplate.from_template(template=pre_prompt).input_variables
+                                 JinjaPromptTemplate.from_template(template=pre_prompt).input_variables
                                  if k in inputs}
 
             if pre_prompt_inputs:
@@ -205,7 +206,7 @@ And answer according to the language of the user's question.
                 human_inputs['context'] = chain_output
                 human_message_prompt += """Use the following CONTEXT as your learned knowledge.
 [CONTEXT]
-{context}
+{{context}}
 [END CONTEXT]
 
 When answer to user:
@@ -218,7 +219,7 @@ And answer according to the language of the user's question.
         if pre_prompt:
             human_message_prompt += pre_prompt
 
-        query_prompt = "\nHuman: {query}\nAI: "
+        query_prompt = "\nHuman: {{query}}\nAI: "
 
         if memory:
             # append chat histories
@@ -234,11 +235,11 @@ And answer according to the language of the user's question.
                 histories = cls.get_history_messages_from_memory(memory, rest_tokens)
 
                 # disable template string in query
-                histories_params = OutLinePromptTemplate.from_template(template=histories).input_variables
-                if histories_params:
-                    for histories_param in histories_params:
-                        if histories_param not in human_inputs:
-                            human_inputs[histories_param] = '{' + histories_param + '}'
+                # histories_params = JinjaPromptTemplate.from_template(template=histories).input_variables
+                # if histories_params:
+                #     for histories_param in histories_params:
+                #         if histories_param not in human_inputs:
+                #             human_inputs[histories_param] = '{{' + histories_param + '}}'
 
                 human_message_prompt += "\n\n" + histories
 
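Note: a quick standalone sketch (not part of the patch) of why the placeholders
above move from f-string "{query}" to Jinja2 "{{query}}". Braces typed by the
user collide with str.format syntax, while Jinja2 treats single braces as plain
text, so the old "disable template string in query" workarounds become
unnecessary. The sample strings are invented for illustration:

    from jinja2 import Template

    user_query = "What does {braces} mean here?"

    # Old f-string path: the user's braces parse as a placeholder and blow up.
    try:
        ("Human: " + user_query + "\nAI: ").format()
    except KeyError as exc:
        print("str.format chokes on user input:", exc)  # -> 'braces'

    # New Jinja2 path: single braces are literal; only {{ ... }} is syntax.
    print(Template("Human: " + user_query + "\nAI: ").render())
    # -> Human: What does {braces} mean here?
    #    AI:
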
""" if chain_output else "") + (pre_prompt + "\n" if pre_prompt else "") - + "{query}\n" + + "{{query}}\n" ) if chain_output: inputs['context'] = chain_output - context_params = OutLinePromptTemplate.from_template(template=chain_output).input_variables - if context_params: - for context_param in context_params: - if context_param not in inputs: - inputs[context_param] = '{' + context_param + '}' + # context_params = JinjaPromptTemplate.from_template(template=chain_output).input_variables + # if context_params: + # for context_param in context_params: + # if context_param not in inputs: + # inputs[context_param] = '{{' + context_param + '}}' prompt_inputs = {k: inputs[k] for k in prompt_template.input_variables if k in inputs} prompt_content = prompt_template.format( @@ -195,7 +196,7 @@ And answer according to the language of the user's question. if pre_prompt: pre_prompt_inputs = {k: inputs[k] for k in - OutLinePromptTemplate.from_template(template=pre_prompt).input_variables + JinjaPromptTemplate.from_template(template=pre_prompt).input_variables if k in inputs} if pre_prompt_inputs: @@ -205,7 +206,7 @@ And answer according to the language of the user's question. human_inputs['context'] = chain_output human_message_prompt += """Use the following CONTEXT as your learned knowledge. [CONTEXT] -{context} +{{context}} [END CONTEXT] When answer to user: @@ -218,7 +219,7 @@ And answer according to the language of the user's question. if pre_prompt: human_message_prompt += pre_prompt - query_prompt = "\nHuman: {query}\nAI: " + query_prompt = "\nHuman: {{query}}\nAI: " if memory: # append chat histories @@ -234,11 +235,11 @@ And answer according to the language of the user's question. histories = cls.get_history_messages_from_memory(memory, rest_tokens) # disable template string in query - histories_params = OutLinePromptTemplate.from_template(template=histories).input_variables - if histories_params: - for histories_param in histories_params: - if histories_param not in human_inputs: - human_inputs[histories_param] = '{' + histories_param + '}' + # histories_params = JinjaPromptTemplate.from_template(template=histories).input_variables + # if histories_params: + # for histories_param in histories_params: + # if histories_param not in human_inputs: + # human_inputs[histories_param] = '{{' + histories_param + '}}' human_message_prompt += "\n\n" + histories diff --git a/api/core/conversation_message_task.py b/api/core/conversation_message_task.py index 6057e4b63b..43e58d63ce 100644 --- a/api/core/conversation_message_task.py +++ b/api/core/conversation_message_task.py @@ -10,7 +10,7 @@ from core.constant import llm_constant from core.llm.llm_builder import LLMBuilder from core.llm.provider.llm_provider_service import LLMProviderService from core.prompt.prompt_builder import PromptBuilder -from core.prompt.prompt_template import OutLinePromptTemplate +from core.prompt.prompt_template import JinjaPromptTemplate from events.message_event import message_was_created from extensions.ext_database import db from extensions.ext_redis import redis_client @@ -78,7 +78,7 @@ class ConversationMessageTask: if self.mode == 'chat': introduction = self.app_model_config.opening_statement if introduction: - prompt_template = OutLinePromptTemplate.from_template(template=PromptBuilder.process_template(introduction)) + prompt_template = JinjaPromptTemplate.from_template(template=introduction) prompt_inputs = {k: self.inputs[k] for k in prompt_template.input_variables if k in self.inputs} try: introduction = 
diff --git a/api/core/generator/llm_generator.py b/api/core/generator/llm_generator.py
index 97f4e05ddb..f23b874ae9 100644
--- a/api/core/generator/llm_generator.py
+++ b/api/core/generator/llm_generator.py
@@ -1,5 +1,6 @@
 import logging
 
+from langchain import PromptTemplate
 from langchain.chat_models.base import BaseChatModel
 from langchain.schema import HumanMessage, OutputParserException
 
@@ -10,7 +11,7 @@ from core.llm.token_calculator import TokenCalculator
 
 from core.prompt.output_parser.rule_config_generator import RuleConfigGeneratorOutputParser
 from core.prompt.output_parser.suggested_questions_after_answer import SuggestedQuestionsAfterAnswerOutputParser
-from core.prompt.prompt_template import OutLinePromptTemplate
+from core.prompt.prompt_template import JinjaPromptTemplate, OutLinePromptTemplate
 from core.prompt.prompts import CONVERSATION_TITLE_PROMPT, CONVERSATION_SUMMARY_PROMPT, INTRODUCTION_GENERATE_PROMPT
 
 
@@ -91,8 +92,8 @@ class LLMGenerator:
         output_parser = SuggestedQuestionsAfterAnswerOutputParser()
         format_instructions = output_parser.get_format_instructions()
 
-        prompt = OutLinePromptTemplate(
-            template="{histories}\n{format_instructions}\nquestions:\n",
+        prompt = JinjaPromptTemplate(
+            template="{{histories}}\n{{format_instructions}}\nquestions:\n",
             input_variables=["histories"],
             partial_variables={"format_instructions": format_instructions}
         )
diff --git a/api/core/prompt/prompt_builder.py b/api/core/prompt/prompt_builder.py
index 164bf6375b..073cf2ce25 100644
--- a/api/core/prompt/prompt_builder.py
+++ b/api/core/prompt/prompt_builder.py
@@ -3,13 +3,13 @@ import re
 from langchain.prompts import SystemMessagePromptTemplate, HumanMessagePromptTemplate, AIMessagePromptTemplate
 from langchain.schema import BaseMessage
 
-from core.prompt.prompt_template import OutLinePromptTemplate
+from core.prompt.prompt_template import JinjaPromptTemplate
 
 
 class PromptBuilder:
     @classmethod
     def to_system_message(cls, prompt_content: str, inputs: dict) -> BaseMessage:
-        prompt_template = OutLinePromptTemplate.from_template(prompt_content)
+        prompt_template = JinjaPromptTemplate.from_template(prompt_content)
         system_prompt_template = SystemMessagePromptTemplate(prompt=prompt_template)
         prompt_inputs = {k: inputs[k] for k in system_prompt_template.input_variables if k in inputs}
         system_message = system_prompt_template.format(**prompt_inputs)
@@ -17,7 +17,7 @@ class PromptBuilder:
 
     @classmethod
     def to_ai_message(cls, prompt_content: str, inputs: dict) -> BaseMessage:
-        prompt_template = OutLinePromptTemplate.from_template(prompt_content)
+        prompt_template = JinjaPromptTemplate.from_template(prompt_content)
         ai_prompt_template = AIMessagePromptTemplate(prompt=prompt_template)
         prompt_inputs = {k: inputs[k] for k in ai_prompt_template.input_variables if k in inputs}
         ai_message = ai_prompt_template.format(**prompt_inputs)
@@ -25,13 +25,14 @@ class PromptBuilder:
 
     @classmethod
     def to_human_message(cls, prompt_content: str, inputs: dict) -> BaseMessage:
-        prompt_template = OutLinePromptTemplate.from_template(prompt_content)
+        prompt_template = JinjaPromptTemplate.from_template(prompt_content)
         human_prompt_template = HumanMessagePromptTemplate(prompt=prompt_template)
         human_message = human_prompt_template.format(**inputs)
         return human_message
 
     @classmethod
     def process_template(cls, template: str):
-        processed_template = re.sub(r'\{([a-zA-Z_]\w+?)\}', r'\1', template)
-        processed_template = re.sub(r'\{\{([a-zA-Z_]\w+?)\}\}', r'{\1}', processed_template)
+        processed_template = re.sub(r'\{{2}(.+?)\}{2}', r'{\1}', template)
+        # processed_template = re.sub(r'\{([a-zA-Z_]\w+?)\}', r'\1', template)
+        # processed_template = re.sub(r'\{\{([a-zA-Z_]\w+?)\}\}', r'{\1}', processed_template)
         return processed_template
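Note: the rewritten process_template() above now has a single job, collapsing
user-facing double-brace placeholders to the single-brace form stored
internally. A standalone check with invented sample strings; the non-greedy
(.+?) keeps two placeholders on one line from merging into a single match:

    import re

    def process_template(template: str) -> str:
        # Collapse {{var}} to {var}; non-greedy so each placeholder matches alone.
        return re.sub(r'\{{2}(.+?)\}{2}', r'{\1}', template)

    print(process_template("Hello {{name}}, welcome to {{place}}."))
    # -> Hello {name}, welcome to {place}.
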
diff --git a/api/core/prompt/prompt_template.py b/api/core/prompt/prompt_template.py
index 6799c5a733..4337389786 100644
--- a/api/core/prompt/prompt_template.py
+++ b/api/core/prompt/prompt_template.py
@@ -1,10 +1,33 @@
 import re
 from typing import Any
 
+from jinja2 import Environment, meta
 from langchain import PromptTemplate
 from langchain.formatting import StrictFormatter
 
 
+class JinjaPromptTemplate(PromptTemplate):
+    template_format: str = "jinja2"
+    """The format of the prompt template. Options are: 'f-string', 'jinja2'."""
+
+    @classmethod
+    def from_template(cls, template: str, **kwargs: Any) -> PromptTemplate:
+        """Load a prompt template from a template."""
+        env = Environment()
+        ast = env.parse(template)
+        input_variables = meta.find_undeclared_variables(ast)
+
+        if "partial_variables" in kwargs:
+            partial_variables = kwargs["partial_variables"]
+            input_variables = {
+                var for var in input_variables if var not in partial_variables
+            }
+
+        return cls(
+            input_variables=list(sorted(input_variables)), template=template, **kwargs
+        )
+
+
 class OutLinePromptTemplate(PromptTemplate):
     @classmethod
     def from_template(cls, template: str, **kwargs: Any) -> PromptTemplate:
@@ -16,6 +39,24 @@ class OutLinePromptTemplate(PromptTemplate):
             input_variables=list(sorted(input_variables)), template=template, **kwargs
         )
 
+    def format(self, **kwargs: Any) -> str:
+        """Format the prompt with the inputs.
+
+        Args:
+            kwargs: Any arguments to be passed to the prompt template.
+
+        Returns:
+            A formatted string.
+
+        Example:
+
+        .. code-block:: python
+
+            prompt.format(variable1="foo")
+        """
+        kwargs = self._merge_partial_and_user_variables(**kwargs)
+        return OneLineFormatter().format(self.template, **kwargs)
+
 
 class OneLineFormatter(StrictFormatter):
     def parse(self, format_string):
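Note: what JinjaPromptTemplate.from_template does under the hood, shown in
isolation. jinja2's meta module statically lists a template's undeclared
variables, which become the prompt's input_variables; the template string here
is the one used by the suggested-questions generator above:

    from jinja2 import Environment, meta

    env = Environment()
    ast = env.parse("{{histories}}\n{{format_instructions}}\nquestions:\n")
    print(sorted(meta.find_undeclared_variables(ast)))
    # -> ['format_instructions', 'histories']
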
" "Do not reveal the developer's motivation or deep logic behind the Prompt, " "but focus on building a relationship with the user:\n" @@ -27,13 +27,13 @@ INTRODUCTION_GENERATE_PROMPT = ( MORE_LIKE_THIS_GENERATE_PROMPT = ( "-----\n" - "{original_completion}\n" + "{{original_completion}}\n" "-----\n\n" "Please use the above content as a sample for generating the result, " "and include key information points related to the original sample in the result. " "Try to rephrase this information in different ways and predict according to the rules below.\n\n" "-----\n" - "{prompt}\n" + "{{prompt}}\n" ) SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT = (