feat: add baichuan prompt (#985)

takatost 2023-08-24 10:22:36 +08:00 committed by GitHub
parent 9b247fccd4
commit 2c30d19cbe
9 changed files with 213 additions and 130 deletions

View File

@@ -130,13 +130,12 @@ class Completion:
             fake_response = agent_execute_result.output
 
         # get llm prompt
-        prompt_messages, stop_words = cls.get_main_llm_prompt(
+        prompt_messages, stop_words = model_instance.get_prompt(
             mode=mode,
-            model=app_model_config.model_dict,
             pre_prompt=app_model_config.pre_prompt,
-            query=query,
             inputs=inputs,
-            agent_execute_result=agent_execute_result,
+            query=query,
+            context=agent_execute_result.output if agent_execute_result else None,
             memory=memory
         )
@@ -154,113 +153,6 @@ class Completion:
         return response
 
-    @classmethod
-    def get_main_llm_prompt(cls, mode: str, model: dict,
-                            pre_prompt: str, query: str, inputs: dict,
-                            agent_execute_result: Optional[AgentExecuteResult],
-                            memory: Optional[ReadOnlyConversationTokenDBBufferSharedMemory]) -> \
-            Tuple[List[PromptMessage], Optional[List[str]]]:
-        if mode == 'completion':
-            prompt_template = JinjaPromptTemplate.from_template(
-                template=("""Use the following context as your learned knowledge, inside <context></context> XML tags.
-
-<context>
-{{context}}
-</context>
-
-When answer to user:
-- If you don't know, just say that you don't know.
-- If you don't know when you are not sure, ask for clarification.
-Avoid mentioning that you obtained the information from the context.
-And answer according to the language of the user's question.
-""" if agent_execute_result else "")
-                         + (pre_prompt + "\n" if pre_prompt else "")
-                         + "{{query}}\n"
-            )
-
-            if agent_execute_result:
-                inputs['context'] = agent_execute_result.output
-
-            prompt_inputs = {k: inputs[k] for k in prompt_template.input_variables if k in inputs}
-            prompt_content = prompt_template.format(
-                query=query,
-                **prompt_inputs
-            )
-
-            return [PromptMessage(content=prompt_content)], None
-        else:
-            messages: List[BaseMessage] = []
-
-            human_inputs = {
-                "query": query
-            }
-
-            human_message_prompt = ""
-
-            if pre_prompt:
-                pre_prompt_inputs = {k: inputs[k] for k in
-                                     JinjaPromptTemplate.from_template(template=pre_prompt).input_variables
-                                     if k in inputs}
-
-                if pre_prompt_inputs:
-                    human_inputs.update(pre_prompt_inputs)
-
-            if agent_execute_result:
-                human_inputs['context'] = agent_execute_result.output
-                human_message_prompt += """Use the following context as your learned knowledge, inside <context></context> XML tags.
-
-<context>
-{{context}}
-</context>
-
-When answer to user:
-- If you don't know, just say that you don't know.
-- If you don't know when you are not sure, ask for clarification.
-Avoid mentioning that you obtained the information from the context.
-And answer according to the language of the user's question.
-"""
-
-            if pre_prompt:
-                human_message_prompt += pre_prompt
-
-            query_prompt = "\n\nHuman: {{query}}\n\nAssistant: "
-
-            if memory:
-                # append chat histories
-                tmp_human_message = PromptBuilder.to_human_message(
-                    prompt_content=human_message_prompt + query_prompt,
-                    inputs=human_inputs
-                )
-
-                if memory.model_instance.model_rules.max_tokens.max:
-                    curr_message_tokens = memory.model_instance.get_num_tokens(to_prompt_messages([tmp_human_message]))
-                    max_tokens = model.get("completion_params").get('max_tokens')
-                    rest_tokens = memory.model_instance.model_rules.max_tokens.max - max_tokens - curr_message_tokens
-                    rest_tokens = max(rest_tokens, 0)
-                else:
-                    rest_tokens = 2000
-
-                histories = cls.get_history_messages_from_memory(memory, rest_tokens)
-                human_message_prompt += "\n\n" if human_message_prompt else ""
-                human_message_prompt += "Here is the chat histories between human and assistant, " \
-                                        "inside <histories></histories> XML tags.\n\n<histories>\n"
-                human_message_prompt += histories + "\n</histories>"
-
-            human_message_prompt += query_prompt
-
-            # construct main prompt
-            human_message = PromptBuilder.to_human_message(
-                prompt_content=human_message_prompt,
-                inputs=human_inputs
-            )
-
-            messages.append(human_message)
-
-            for message in messages:
-                message.content = re.sub(r'<\|.*?\|>', '', message.content)
-
-            return to_prompt_messages(messages), ['\nHuman:', '</histories>']
-
     @classmethod
     def get_history_messages_from_memory(cls, memory: ReadOnlyConversationTokenDBBufferSharedMemory,
                                          max_token_limit: int) -> str:
@@ -307,13 +199,12 @@ And answer according to the language of the user's question.
             max_tokens = 0
 
         # get prompt without memory and context
-        prompt_messages, _ = cls.get_main_llm_prompt(
+        prompt_messages, _ = model_instance.get_prompt(
             mode=mode,
-            model=app_model_config.model_dict,
             pre_prompt=app_model_config.pre_prompt,
-            query=query,
             inputs=inputs,
-            agent_execute_result=None,
+            query=query,
+            context=None,
             memory=None
         )
@@ -358,13 +249,12 @@ And answer according to the language of the user's question.
         )
 
         # get llm prompt
-        old_prompt_messages, _ = cls.get_main_llm_prompt(
-            mode="completion",
-            model=app_model_config.model_dict,
+        old_prompt_messages, _ = final_model_instance.get_prompt(
+            mode='completion',
             pre_prompt=pre_prompt,
-            query=message.query,
             inputs=message.inputs,
-            agent_execute_result=None,
+            query=message.query,
+            context=None,
             memory=None
         )
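Net effect of the Completion changes above: prompt assembly moves out of Completion.get_main_llm_prompt and into the model instance, callers pass a plain optional `context` string instead of the whole agent_execute_result, and the model dict is no longer needed. A minimal standalone sketch of the new call shape; DummyModel is a stand-in, not a Dify class, and all values are invented:

from typing import List, Optional, Tuple


class DummyModel:
    """Stand-in that only mirrors the get_prompt() signature used in the diff above."""

    def get_prompt(self, mode: str, pre_prompt: str, inputs: dict, query: str,
                   context: Optional[str], memory=None) -> Tuple[List[str], Optional[List[str]]]:
        # The real logic lives in BaseLLM.get_prompt() (see the base.py diff below);
        # this only shows that context is now an optional plain string.
        prompt = (f"<context>\n{context}\n</context>\n\n" if context else "")
        prompt += (pre_prompt + "\n" if pre_prompt else "") + query
        return [prompt], None


prompt_messages, stop_words = DummyModel().get_prompt(
    mode='completion',
    pre_prompt='You are a helpful assistant.',
    inputs={},
    query='What is Dify?',
    context='Dify is an LLM app development platform.',
    memory=None
)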

View File

@@ -1,17 +1,24 @@
+import json
+import os
+import re
 from abc import abstractmethod
-from typing import List, Optional, Any, Union
+from typing import List, Optional, Any, Union, Tuple
 import decimal
 
 from langchain.callbacks.manager import Callbacks
+from langchain.memory.chat_memory import BaseChatMemory
 from langchain.schema import LLMResult, SystemMessage, AIMessage, HumanMessage, BaseMessage, ChatGeneration
 
 from core.callback_handler.std_out_callback_handler import DifyStreamingStdOutCallbackHandler, DifyStdOutCallbackHandler
 from core.model_providers.models.base import BaseProviderModel
-from core.model_providers.models.entity.message import PromptMessage, MessageType, LLMRunResult
+from core.model_providers.models.entity.message import PromptMessage, MessageType, LLMRunResult, to_prompt_messages
 from core.model_providers.models.entity.model_params import ModelType, ModelKwargs, ModelMode, ModelKwargsRules
 from core.model_providers.providers.base import BaseModelProvider
+from core.prompt.prompt_builder import PromptBuilder
+from core.prompt.prompt_template import JinjaPromptTemplate
 from core.third_party.langchain.llms.fake import FakeLLM
 
 import logging
 
 logger = logging.getLogger(__name__)
@@ -76,13 +83,14 @@ class BaseLLM(BaseProviderModel):
     def price_config(self) -> dict:
         def get_or_default():
             default_price_config = {
                 'prompt': decimal.Decimal('0'),
                 'completion': decimal.Decimal('0'),
                 'unit': decimal.Decimal('0'),
                 'currency': 'USD'
             }
             rules = self.model_provider.get_rules()
-            price_config = rules['price_config'][self.base_model_name] if 'price_config' in rules else default_price_config
+            price_config = rules['price_config'][
+                self.base_model_name] if 'price_config' in rules else default_price_config
             price_config = {
                 'prompt': decimal.Decimal(price_config['prompt']),
                 'completion': decimal.Decimal(price_config['completion']),
@@ -90,7 +98,7 @@ class BaseLLM(BaseProviderModel):
                 'currency': price_config['currency']
             }
             return price_config
 
         self._price_config = self._price_config if hasattr(self, '_price_config') else get_or_default()
 
         logger.debug(f"model: {self.name} price_config: {self._price_config}")
@@ -158,7 +166,8 @@ class BaseLLM(BaseProviderModel):
             total_tokens = result.llm_output['token_usage']['total_tokens']
         else:
             prompt_tokens = self.get_num_tokens(messages)
-            completion_tokens = self.get_num_tokens([PromptMessage(content=completion_content, type=MessageType.ASSISTANT)])
+            completion_tokens = self.get_num_tokens(
+                [PromptMessage(content=completion_content, type=MessageType.ASSISTANT)])
             total_tokens = prompt_tokens + completion_tokens
 
         self.model_provider.update_last_used()
@@ -293,6 +302,119 @@ class BaseLLM(BaseProviderModel):
     def support_streaming(cls):
         return False
 
+    def get_prompt(self, mode: str,
+                   pre_prompt: str, inputs: dict,
+                   query: str,
+                   context: Optional[str],
+                   memory: Optional[BaseChatMemory]) -> \
+            Tuple[List[PromptMessage], Optional[List[str]]]:
+        prompt_rules = self._read_prompt_rules_from_file(self.prompt_file_name(mode))
+        prompt, stops = self._get_prompt_and_stop(prompt_rules, pre_prompt, inputs, query, context, memory)
+        return [PromptMessage(content=prompt)], stops
+
+    def prompt_file_name(self, mode: str) -> str:
+        if mode == 'completion':
+            return 'common_completion'
+        else:
+            return 'common_chat'
+
+    def _get_prompt_and_stop(self, prompt_rules: dict, pre_prompt: str, inputs: dict,
+                             query: str,
+                             context: Optional[str],
+                             memory: Optional[BaseChatMemory]) -> Tuple[str, Optional[list]]:
+        context_prompt_content = ''
+        if context and 'context_prompt' in prompt_rules:
+            prompt_template = JinjaPromptTemplate.from_template(template=prompt_rules['context_prompt'])
+            context_prompt_content = prompt_template.format(
+                context=context
+            )
+
+        pre_prompt_content = ''
+        if pre_prompt:
+            prompt_template = JinjaPromptTemplate.from_template(template=pre_prompt)
+            prompt_inputs = {k: inputs[k] for k in prompt_template.input_variables if k in inputs}
+            pre_prompt_content = prompt_template.format(
+                **prompt_inputs
+            )
+
+        prompt = ''
+        for order in prompt_rules['system_prompt_orders']:
+            if order == 'context_prompt':
+                prompt += context_prompt_content
+            elif order == 'pre_prompt':
+                prompt += (pre_prompt_content + '\n\n') if pre_prompt_content else ''
+
+        query_prompt = prompt_rules['query_prompt'] if 'query_prompt' in prompt_rules else '{{query}}'
+
+        if memory and 'histories_prompt' in prompt_rules:
+            # append chat histories
+            tmp_human_message = PromptBuilder.to_human_message(
+                prompt_content=prompt + query_prompt,
+                inputs={
+                    'query': query
+                }
+            )
+
+            if self.model_rules.max_tokens.max:
+                curr_message_tokens = self.get_num_tokens(to_prompt_messages([tmp_human_message]))
+                max_tokens = self.model_kwargs.max_tokens
+                rest_tokens = self.model_rules.max_tokens.max - max_tokens - curr_message_tokens
+                rest_tokens = max(rest_tokens, 0)
+            else:
+                rest_tokens = 2000
+
+            memory.human_prefix = prompt_rules['human_prefix'] if 'human_prefix' in prompt_rules else 'Human'
+            memory.ai_prefix = prompt_rules['assistant_prefix'] if 'assistant_prefix' in prompt_rules else 'Assistant'
+
+            histories = self._get_history_messages_from_memory(memory, rest_tokens)
+            prompt_template = JinjaPromptTemplate.from_template(template=prompt_rules['histories_prompt'])
+            histories_prompt_content = prompt_template.format(
+                histories=histories
+            )
+
+            prompt = ''
+            for order in prompt_rules['system_prompt_orders']:
+                if order == 'context_prompt':
+                    prompt += context_prompt_content
+                elif order == 'pre_prompt':
+                    prompt += (pre_prompt_content + '\n') if pre_prompt_content else ''
+                elif order == 'histories_prompt':
+                    prompt += histories_prompt_content
+
+        prompt_template = JinjaPromptTemplate.from_template(template=query_prompt)
+        query_prompt_content = prompt_template.format(
+            query=query
+        )
+
+        prompt += query_prompt_content
+
+        prompt = re.sub(r'<\|.*?\|>', '', prompt)
+
+        stops = prompt_rules.get('stops')
+        if stops is not None and len(stops) == 0:
+            stops = None
+
+        return prompt, stops
+
+    def _read_prompt_rules_from_file(self, prompt_name: str) -> dict:
+        # Get the absolute path of the subdirectory
+        prompt_path = os.path.join(
+            os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.realpath(__file__))))),
+            'prompt/generate_prompts')
+
+        json_file_path = os.path.join(prompt_path, f'{prompt_name}.json')
+
+        # Open the JSON file and read its content
+        with open(json_file_path, 'r') as json_file:
+            return json.load(json_file)
+
+    def _get_history_messages_from_memory(self, memory: BaseChatMemory,
+                                          max_token_limit: int) -> str:
+        """Get memory messages."""
+        memory.max_token_limit = max_token_limit
+        memory_key = memory.memory_variables[0]
+        external_context = memory.load_memory_variables({})
+        return external_context[memory_key]
+
     def _get_prompt_from_messages(self, messages: List[PromptMessage],
                                   model_mode: Optional[ModelMode] = None) -> Union[str | List[BaseMessage]]:
         if not model_mode:
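The new BaseLLM.get_prompt() path is rule-driven: prompt_file_name() picks a JSON rule file, _read_prompt_rules_from_file() loads it from prompt/generate_prompts/, and _get_prompt_and_stop() renders the context, pre_prompt, and histories pieces, concatenates them in the order given by system_prompt_orders, appends query_prompt, and returns the configured stop words. A condensed, standard-library-only sketch of that assembly; the real code renders each piece with JinjaPromptTemplate and trims histories to the token budget, while plain str.replace() stands in here and the example values are invented:

import re
from typing import List, Optional, Tuple


def assemble_prompt(rules: dict, pre_prompt: str, context: Optional[str],
                    histories: Optional[str], query: str) -> Tuple[str, Optional[List[str]]]:
    # Render the optional pieces (Jinja-style {{var}} placeholders, substituted naively here).
    context_part = rules['context_prompt'].replace('{{context}}', context) \
        if context and 'context_prompt' in rules else ''
    histories_part = rules['histories_prompt'].replace('{{histories}}', histories) \
        if histories and 'histories_prompt' in rules else ''

    # Concatenate in the configured order.
    prompt = ''
    for order in rules['system_prompt_orders']:
        if order == 'context_prompt':
            prompt += context_part
        elif order == 'pre_prompt':
            prompt += (pre_prompt + '\n') if pre_prompt else ''
        elif order == 'histories_prompt':
            prompt += histories_part

    # Append the query template and strip special tokens, as in the diff above.
    prompt += rules.get('query_prompt', '{{query}}').replace('{{query}}', query)
    prompt = re.sub(r'<\|.*?\|>', '', prompt)

    stops = rules.get('stops') or None  # an empty or null stops list becomes None
    return prompt, stops


# Example with a minimal rules dict shaped like common_completion.json:
prompt, stops = assemble_prompt(
    {'system_prompt_orders': ['context_prompt', 'pre_prompt'], 'query_prompt': '{{query}}', 'stops': None},
    pre_prompt='You are a helpful assistant.', context=None, histories=None, query='Hello!'
)

Fed the common_chat rules further below, the same ordering roughly reproduces the "Human: ... Assistant:" layout that get_main_llm_prompt used to hard-code.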

View File

@@ -60,6 +60,15 @@ class HuggingfaceHubModel(BaseLLM):
         prompts = self._get_prompt_from_messages(messages)
         return self._client.get_num_tokens(prompts)
 
+    def prompt_file_name(self, mode: str) -> str:
+        if 'baichuan' in self.name.lower():
+            if mode == 'completion':
+                return 'baichuan_completion'
+            else:
+                return 'baichuan_chat'
+        else:
+            return super().prompt_file_name(mode)
+
     def _set_model_kwargs(self, model_kwargs: ModelKwargs):
         provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
         self.client.model_kwargs = provider_model_kwargs
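The prompt_file_name() override above routes any model whose name contains "baichuan" to the Baichuan-specific rule files; the identical override is repeated in the OpenLLM and Xinference models below. A compressed restatement of that routing as a free function, with hypothetical model names:

def pick_rule_file(model_name: str, mode: str) -> str:
    # Mirrors the override plus the BaseLLM default.
    if 'baichuan' in model_name.lower():
        return 'baichuan_completion' if mode == 'completion' else 'baichuan_chat'
    return 'common_completion' if mode == 'completion' else 'common_chat'


assert pick_rule_file('Baichuan-13B-Chat', 'chat') == 'baichuan_chat'
assert pick_rule_file('falcon-7b-instruct', 'completion') == 'common_completion'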

View File

@@ -49,6 +49,15 @@ class OpenLLMModel(BaseLLM):
         prompts = self._get_prompt_from_messages(messages)
         return max(self._client.get_num_tokens(prompts), 0)
 
+    def prompt_file_name(self, mode: str) -> str:
+        if 'baichuan' in self.name.lower():
+            if mode == 'completion':
+                return 'baichuan_completion'
+            else:
+                return 'baichuan_chat'
+        else:
+            return super().prompt_file_name(mode)
+
     def _set_model_kwargs(self, model_kwargs: ModelKwargs):
         pass

View File

@@ -59,6 +59,15 @@ class XinferenceModel(BaseLLM):
         prompts = self._get_prompt_from_messages(messages)
         return max(self._client.get_num_tokens(prompts), 0)
 
+    def prompt_file_name(self, mode: str) -> str:
+        if 'baichuan' in self.name.lower():
+            if mode == 'completion':
+                return 'baichuan_completion'
+            else:
+                return 'baichuan_chat'
+        else:
+            return super().prompt_file_name(mode)
+
     def _set_model_kwargs(self, model_kwargs: ModelKwargs):
         pass

View File

@@ -0,0 +1,13 @@
{
"human_prefix": "用户",
"assistant_prefix": "助手",
"context_prompt": "用户在与一个客观的助手对话。助手会尊重找到的材料,给出全面专业的解释,但不会过度演绎。同时回答中不会暴露引用的材料:\n\n```\n引用材料\n{{context}}\n```\n\n",
"histories_prompt": "用户和助手的历史对话内容如下:\n```\n{{histories}}\n```\n\n",
"system_prompt_orders": [
"context_prompt",
"pre_prompt",
"histories_prompt"
],
"query_prompt": "用户:{{query}}\n助手",
"stops": ["用户:"]
}
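For one chat turn, the Baichuan chat rules above (presumably baichuan_chat.json, matching the names returned by prompt_file_name) assemble the prompt in the order context, pre_prompt, histories, query, and tell the model to stop at "用户:". A rough end-to-end rendering; the three templates are copied from the JSON above with the Jinja-style {{var}} placeholders rewritten as {var} so plain str.format() can stand in for JinjaPromptTemplate, and the sample pre_prompt, history, and query values are invented:

context_prompt = "用户在与一个客观的助手对话。助手会尊重找到的材料,给出全面专业的解释,但不会过度演绎。同时回答中不会暴露引用的材料:\n\n```\n引用材料\n{context}\n```\n\n"
histories_prompt = "用户和助手的历史对话内容如下:\n```\n{histories}\n```\n\n"
query_prompt = "用户:{query}\n助手"

prompt = (
    context_prompt.format(context="Dify 是一个开源的 LLM 应用开发平台。")
    + "你是 Dify 的客服助手。\n"  # pre_prompt (invented), followed by '\n' as in _get_prompt_and_stop
    + histories_prompt.format(histories="用户:你好\n助手:你好,有什么可以帮你?")
    + query_prompt.format(query="Dify 支持哪些模型?")
)
print(prompt)  # generation then ends at the configured stop word "用户:"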

View File

@@ -0,0 +1,9 @@
{
"context_prompt": "用户在与一个客观的助手对话。助手会尊重找到的材料,给出全面专业的解释,但不会过度演绎。同时回答中不会暴露引用的材料:\n\n```\n引用材料\n{{context}}\n```\n",
"system_prompt_orders": [
"context_prompt",
"pre_prompt"
],
"query_prompt": "{{query}}",
"stops": null
}

View File

@@ -0,0 +1,13 @@
{
"human_prefix": "Human",
"assistant_prefix": "Assistant",
"context_prompt": "Use the following context as your learned knowledge, inside <context></context> XML tags.\n\n<context>\n{{context}}\n</context>\n\nWhen answer to user:\n- If you don't know, just say that you don't know.\n- If you don't know when you are not sure, ask for clarification.\nAvoid mentioning that you obtained the information from the context.\nAnd answer according to the language of the user's question.\n\n",
"histories_prompt": "Here is the chat histories between human and assistant, inside <histories></histories> XML tags.\n\n<histories>\n{{histories}}\n</histories>\n\n",
"system_prompt_orders": [
"context_prompt",
"pre_prompt",
"histories_prompt"
],
"query_prompt": "Human: {{query}}\n\nAssistant: ",
"stops": ["\nHuman:", "</histories>"]
}

View File

@@ -0,0 +1,9 @@
{
"context_prompt": "Use the following context as your learned knowledge, inside <context></context> XML tags.\n\n<context>\n{{context}}\n</context>\n\nWhen answer to user:\n- If you don't know, just say that you don't know.\n- If you don't know when you are not sure, ask for clarification.\nAvoid mentioning that you obtained the information from the context.\nAnd answer according to the language of the user's question.\n\n",
"system_prompt_orders": [
"context_prompt",
"pre_prompt"
],
"query_prompt": "{{query}}",
"stops": null
}