Feat/optimize chat prompt (#158)
parent 7722a7c5cd
commit 90150a6ca9
@@ -39,7 +39,8 @@ class Completion:
             memory = cls.get_memory_from_conversation(
                 tenant_id=app.tenant_id,
                 app_model_config=app_model_config,
-                conversation=conversation
+                conversation=conversation,
+                return_messages=False
             )
 
             inputs = conversation.inputs
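Note: return_messages=False asks the conversation memory to hand chat history back as a single pre-rendered string rather than a list of message objects; the reworked prompt below splices that string into one human message. A minimal illustration of the convention with a plain LangChain buffer memory (dify's ReadOnlyConversationTokenDBBufferSharedMemory is assumed to follow the same convention):

from langchain.memory import ConversationBufferMemory

memory = ConversationBufferMemory(return_messages=True)
memory.save_context({"input": "Hi"}, {"output": "Hello!"})
print(memory.load_memory_variables({}))
# {'history': [HumanMessage(content='Hi'), AIMessage(content='Hello!')]}

memory.return_messages = False
print(memory.load_memory_variables({}))
# {'history': 'Human: Hi\nAI: Hello!'}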
@@ -119,7 +120,8 @@ class Completion:
         return response
 
     @classmethod
-    def get_main_llm_prompt(cls, mode: str, llm: BaseLanguageModel, pre_prompt: str, query: str, inputs: dict, chain_output: Optional[str],
+    def get_main_llm_prompt(cls, mode: str, llm: BaseLanguageModel, pre_prompt: str, query: str, inputs: dict,
+                            chain_output: Optional[str],
                             memory: Optional[ReadOnlyConversationTokenDBBufferSharedMemory]) -> \
             Union[str | List[BaseMessage]]:
         pre_prompt = PromptBuilder.process_template(pre_prompt) if pre_prompt else pre_prompt
@@ -161,11 +163,19 @@ And answer according to the language of the user's question.
                 "query": query
             }
 
-            human_message_prompt = "{query}"
+            human_message_prompt = ""
+
+            if pre_prompt:
+                pre_prompt_inputs = {k: inputs[k] for k in
+                                     OutLinePromptTemplate.from_template(template=pre_prompt).input_variables
+                                     if k in inputs}
+
+                if pre_prompt_inputs:
+                    human_inputs.update(pre_prompt_inputs)
 
             if chain_output:
                 human_inputs['context'] = chain_output
-                human_message_instruction = """Use the following CONTEXT as your learned knowledge.
+                human_message_prompt += """Use the following CONTEXT as your learned knowledge.
 [CONTEXT]
 {context}
 [END CONTEXT]
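The block added above keeps only the caller-supplied inputs that the user's pre_prompt template actually declares before merging them into human_inputs. The same filtering as a standalone snippet, with LangChain's PromptTemplate standing in for dify's OutLinePromptTemplate:

from langchain.prompts import PromptTemplate

pre_prompt = "Translate the user's text into {language}."
inputs = {"language": "French", "tone": "formal"}

wanted = PromptTemplate.from_template(pre_prompt).input_variables  # ['language']
pre_prompt_inputs = {k: inputs[k] for k in wanted if k in inputs}
print(pre_prompt_inputs)  # {'language': 'French'}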
@@ -176,23 +186,27 @@ When answer to user:
 Avoid mentioning that you obtained the information from the context.
 And answer according to the language of the user's question.
 """
-                if pre_prompt:
-                    extra_inputs = {k: inputs[k] for k in
-                                    OutLinePromptTemplate.from_template(template=pre_prompt).input_variables
-                                    if k in inputs}
-                    if extra_inputs:
-                        human_inputs.update(extra_inputs)
-                    human_message_instruction += pre_prompt + "\n"
-
-                human_message_prompt = human_message_instruction + "Q:{query}\nA:"
-            else:
-                if pre_prompt:
-                    extra_inputs = {k: inputs[k] for k in
-                                    OutLinePromptTemplate.from_template(template=pre_prompt).input_variables
-                                    if k in inputs}
-                    if extra_inputs:
-                        human_inputs.update(extra_inputs)
-                    human_message_prompt = pre_prompt + "\n" + human_message_prompt
+
+            if pre_prompt:
+                human_message_prompt += pre_prompt
+
+            query_prompt = "\nHuman: {query}\nAI: "
+
+            if memory:
+                # append chat histories
+                tmp_human_message = PromptBuilder.to_human_message(
+                    prompt_content=human_message_prompt + query_prompt,
+                    inputs=human_inputs
+                )
+
+                curr_message_tokens = memory.llm.get_messages_tokens([tmp_human_message])
+                rest_tokens = llm_constant.max_context_token_length[memory.llm.model_name] \
+                              - memory.llm.max_tokens - curr_message_tokens
+                rest_tokens = max(rest_tokens, 0)
+                history_messages = cls.get_history_messages_from_memory(memory, rest_tokens)
+                human_message_prompt += "\n\n" + history_messages
+
+            human_message_prompt += query_prompt
 
             # construct main prompt
             human_message = PromptBuilder.to_human_message(
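Taken together, the chat branch now builds the whole prompt as one human message: an optional [CONTEXT] instruction, the pre_prompt, as much chat history as the remaining token budget allows, and finally a "\nHuman: {query}\nAI: " scaffold. A self-contained sketch of that flow, using hypothetical helper names rather than the real Completion internals:

from typing import Callable, Optional


def build_chat_prompt(pre_prompt: str, chain_output: Optional[str],
                      fetch_history: Optional[Callable[[int], str]],
                      context_window: int, max_completion_tokens: int,
                      count_tokens: Callable[[str], int]) -> str:
    prompt = ""
    if chain_output:
        # same idea as the [CONTEXT] instruction block in the hunk above
        prompt += ("Use the following CONTEXT as your learned knowledge.\n"
                   "[CONTEXT]\n{context}\n[END CONTEXT]\n")
    if pre_prompt:
        prompt += pre_prompt

    query_prompt = "\nHuman: {query}\nAI: "

    if fetch_history is not None:
        # history only gets the tokens left after the prompt so far and the
        # head-room reserved for the model's completion
        rest_tokens = max(context_window - max_completion_tokens
                          - count_tokens(prompt + query_prompt), 0)
        prompt += "\n\n" + fetch_history(rest_tokens)

    return prompt + query_prompt


# toy usage with a whitespace "tokenizer"
prompt = build_chat_prompt(
    pre_prompt="You are a helpful translator.",
    chain_output=None,
    fetch_history=lambda budget: "Human: Hi\nAI: Hello!",
    context_window=4096,
    max_completion_tokens=512,
    count_tokens=lambda text: len(text.split()),
)
print(prompt.format(query="Translate 'cat' into French."))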
@@ -200,23 +214,14 @@ And answer according to the language of the user's question.
                 inputs=human_inputs
             )
 
-            if memory:
-                # append chat histories
-                tmp_messages = messages.copy() + [human_message]
-                curr_message_tokens = memory.llm.get_messages_tokens(tmp_messages)
-                rest_tokens = llm_constant.max_context_token_length[
-                                  memory.llm.model_name] - memory.llm.max_tokens - curr_message_tokens
-                rest_tokens = max(rest_tokens, 0)
-                history_messages = cls.get_history_messages_from_memory(memory, rest_tokens)
-                messages += history_messages
-
             messages.append(human_message)
 
             return messages
 
     @classmethod
     def get_llm_callback_manager(cls, llm: Union[StreamableOpenAI, StreamableChatOpenAI],
-                                 streaming: bool, conversation_message_task: ConversationMessageTask) -> CallbackManager:
+                                 streaming: bool,
+                                 conversation_message_task: ConversationMessageTask) -> CallbackManager:
         llm_callback_handler = LLMCallbackHandler(llm, conversation_message_task)
         if streaming:
             callback_handlers = [llm_callback_handler, DifyStreamingStdOutCallbackHandler()]
@@ -228,7 +233,7 @@ And answer according to the language of the user's question.
     @classmethod
     def get_history_messages_from_memory(cls, memory: ReadOnlyConversationTokenDBBufferSharedMemory,
                                          max_token_limit: int) -> \
-            List[BaseMessage]:
+            str:
         """Get memory messages."""
         memory.max_token_limit = max_token_limit
         memory_key = memory.memory_variables[0]
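Since the history is now concatenated into the prompt text instead of being appended as separate chat messages, get_history_messages_from_memory returns a str rather than List[BaseMessage]. A rough sketch of rendering messages to such a string (an assumed helper for illustration; the real method reads the shared token-buffer memory):

from typing import List

from langchain.schema import AIMessage, BaseMessage, HumanMessage


def render_history(messages: List[BaseMessage]) -> str:
    lines = []
    for message in messages:
        role = "Human" if isinstance(message, HumanMessage) else "AI"
        lines.append(f"{role}: {message.content}")
    return "\n".join(lines)


print(render_history([HumanMessage(content="Hi"), AIMessage(content="Hello!")]))
# Human: Hi
# AI: Hello!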