Merge pull request #300 from RockChinQ/token-process

[Perf] Tokens相关处理逻辑优化
This commit is contained in:
Rock Chin 2023-03-19 16:35:25 +08:00 committed by GitHub
commit 5f83cc6bb7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 132 additions and 44 deletions

View File

@ -82,7 +82,30 @@ default_prompt = {
# 情景预设格式
# 参考值旧版本方式default | 完整情景full_scenario
# 旧版本的格式为上述default_prompt中的内容或prompts目录下的文件名
# 完整情景预设的格式为JSON在JSON文件中列出对话的每个回合编写方法见scenario/default-template.json
#
# 完整情景预设的格式为JSON在scenario目录下的JSON文件中列出对话的每个回合编写方法见scenario/default-template.json
# 编写方法例如:
# {
# "prompt": [
# {
# "role": "user",
# "content": "之后当我需要帮助时,请说“输入!help获取帮助”"
# },{
# "role": "assistant",
# "content": "好的,当你之后需要帮助时,我会说“输入!help获取帮助”"
# },{
# "role": "user",
# "content": "帮助"
# },{
# "role": "assistant",
# "content": "输入!help获取帮助"
# }
# ]
# }
#
# 您可以按照上述格式编写自己的情景预设在prompt中列出对话的每个回合
# role为user或assistant分别表示用户和机器人的回复
# 每个JSON文件是一个情景预设文件名即为情景预设的名称
preset_mode = "default"
# 群内响应规则
@ -139,7 +162,7 @@ encourage_sponsor_at_start = True
# 每次向OpenAI接口发送对话记录上下文的字符数
# 最大不超过(4096 - max_tokens)个字符max_tokens为下方completion_api_params中的max_tokens
# 注意较大的prompt_submit_length会导致OpenAI账户额度消耗更快
prompt_submit_length = 1024
prompt_submit_length = 2048
# OpenAI补全API的参数
# 请在下方填写模型,程序自动选择接口

View File

@ -54,20 +54,27 @@ class DatabaseManager:
`last_interact_timestamp` bigint not null,
`status` varchar(255) not null default 'on_going',
`default_prompt` text not null default '',
`prompt` text not null
`prompt` text not null,
`token_counts` text not null default '[]'
)
""")
# 检查sessions表是否存在`default_prompt`字段
# 检查sessions表是否存在`default_prompt`字段, 检查是否存在`token_counts`字段
self.__execute__("PRAGMA table_info('sessions')")
columns = self.cursor.fetchall()
has_default_prompt = False
has_token_counts = False
for field in columns:
if field[1] == 'default_prompt':
has_default_prompt = True
if field[1] == 'token_counts':
has_token_counts = True
if has_default_prompt and has_token_counts:
break
if not has_default_prompt:
self.__execute__("alter table `sessions` add column `default_prompt` text not null default ''")
if not has_token_counts:
self.__execute__("alter table `sessions` add column `token_counts` text not null default '[]'")
self.__execute__("""
@ -89,7 +96,7 @@ class DatabaseManager:
# session持久化
def persistence_session(self, subject_type: str, subject_number: int, create_timestamp: int,
last_interact_timestamp: int, prompt: str, default_prompt: str = ''):
last_interact_timestamp: int, prompt: str, default_prompt: str = '', token_counts: str = ''):
"""持久化指定session"""
# 检查是否已经有了此name和create_timestamp的session
@ -102,20 +109,20 @@ class DatabaseManager:
if count == 0:
sql = """
insert into `sessions` (`name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`, `default_prompt`)
values (?, ?, ?, ?, ?, ?, ?)
insert into `sessions` (`name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`, `default_prompt`, `token_counts`)
values (?, ?, ?, ?, ?, ?, ?, ?)
"""
self.__execute__(sql,
("{}_{}".format(subject_type, subject_number), subject_type, subject_number, create_timestamp,
last_interact_timestamp, prompt, default_prompt))
last_interact_timestamp, prompt, default_prompt, token_counts))
else:
sql = """
update `sessions` set `last_interact_timestamp` = ?, `prompt` = ?
update `sessions` set `last_interact_timestamp` = ?, `prompt` = ?, `token_counts` = ?
where `type` = ? and `number` = ? and `create_timestamp` = ?
"""
self.__execute__(sql, (last_interact_timestamp, prompt, subject_type,
self.__execute__(sql, (last_interact_timestamp, prompt, token_counts, subject_type,
subject_number, create_timestamp))
# 显式关闭一个session
@ -140,7 +147,7 @@ class DatabaseManager:
# 从数据库中加载所有还没过期的session
config = pkg.utils.context.get_config()
self.__execute__("""
select `name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`, `status`, `default_prompt`
select `name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`, `status`, `default_prompt`, `token_counts`
from `sessions` where `last_interact_timestamp` > {}
""".format(int(time.time()) - config.session_expire_time))
results = self.cursor.fetchall()
@ -154,6 +161,7 @@ class DatabaseManager:
prompt = result[5]
status = result[6]
default_prompt = result[7]
token_counts = result[8]
# 当且仅当最后一个该对象的会话是on_going状态时才会被加载
if status == 'on_going':
@ -163,7 +171,8 @@ class DatabaseManager:
'create_timestamp': create_timestamp,
'last_interact_timestamp': last_interact_timestamp,
'prompt': prompt,
'default_prompt': default_prompt
'default_prompt': default_prompt,
'token_counts': token_counts
}
else:
if session_name in sessions:
@ -175,7 +184,7 @@ class DatabaseManager:
def last_session(self, session_name: str, cursor_timestamp: int):
self.__execute__("""
select `name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`, `status`, `default_prompt`
select `name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`, `status`, `default_prompt`, `token_counts`
from `sessions` where `name` = '{}' and `last_interact_timestamp` < {} order by `last_interact_timestamp` desc
limit 1
""".format(session_name, cursor_timestamp))
@ -192,6 +201,7 @@ class DatabaseManager:
prompt = result[5]
status = result[6]
default_prompt = result[7]
token_counts = result[8]
return {
'subject_type': subject_type,
@ -199,14 +209,15 @@ class DatabaseManager:
'create_timestamp': create_timestamp,
'last_interact_timestamp': last_interact_timestamp,
'prompt': prompt,
'default_prompt': default_prompt
'default_prompt': default_prompt,
'token_counts': token_counts
}
# 获取此session_name后一个session的数据
def next_session(self, session_name: str, cursor_timestamp: int):
self.__execute__("""
select `name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`, `status`, `default_prompt`
select `name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`, `status`, `default_prompt`, `token_counts`
from `sessions` where `name` = '{}' and `last_interact_timestamp` > {} order by `last_interact_timestamp` asc
limit 1
""".format(session_name, cursor_timestamp))
@ -223,6 +234,7 @@ class DatabaseManager:
prompt = result[5]
status = result[6]
default_prompt = result[7]
token_counts = result[8]
return {
'subject_type': subject_type,
@ -230,13 +242,14 @@ class DatabaseManager:
'create_timestamp': create_timestamp,
'last_interact_timestamp': last_interact_timestamp,
'prompt': prompt,
'default_prompt': default_prompt
'default_prompt': default_prompt,
'token_counts': token_counts
}
# 列出与某个对象的所有对话session
def list_history(self, session_name: str, capacity: int, page: int):
self.__execute__("""
select `name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`, `status`, `default_prompt`
select `name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`, `status`, `default_prompt`, `token_counts`
from `sessions` where `name` = '{}' order by `last_interact_timestamp` desc limit {} offset {}
""".format(session_name, capacity, capacity * page))
results = self.cursor.fetchall()
@ -250,6 +263,7 @@ class DatabaseManager:
prompt = result[5]
status = result[6]
default_prompt = result[7]
token_counts = result[8]
sessions.append({
'subject_type': subject_type,
@ -257,7 +271,8 @@ class DatabaseManager:
'create_timestamp': create_timestamp,
'last_interact_timestamp': last_interact_timestamp,
'prompt': prompt,
'default_prompt': default_prompt
'default_prompt': default_prompt,
'token_counts': token_counts
})
return sessions

View File

@ -34,7 +34,7 @@ class OpenAIInteract:
pkg.utils.context.set_openai_manager(self)
# 请求OpenAI Completion
def request_completion(self, prompts) -> str:
def request_completion(self, prompts) -> tuple[str, int]:
"""请求补全接口回复
Parameters:
@ -60,14 +60,18 @@ class OpenAIInteract:
logging.debug("OpenAI response: %s", response)
# 记录使用量
current_round_token = 0
if 'model' in config.completion_api_params:
self.audit_mgr.report_text_model_usage(config.completion_api_params['model'],
ai.get_total_tokens())
current_round_token = ai.get_total_tokens()
elif 'engine' in config.completion_api_params:
self.audit_mgr.report_text_model_usage(config.completion_api_params['engine'],
response['usage']['total_tokens'])
current_round_token = response['usage']['total_tokens']
return ai.get_message()
return ai.get_message(), current_round_token
def request_image(self, prompt) -> dict:
"""请求图片接口回复

View File

@ -72,6 +72,7 @@ def load_sessions():
temp_session.last_interact_timestamp = session_data[session_name]['last_interact_timestamp']
try:
temp_session.prompt = json.loads(session_data[session_name]['prompt'])
temp_session.token_counts = json.loads(session_data[session_name]['token_counts'])
except Exception:
temp_session.prompt = reset_session_prompt(session_name, session_data[session_name]['prompt'])
temp_session.persistence()
@ -106,6 +107,9 @@ class Session:
prompt = []
"""使用list来保存会话中的回合"""
token_counts = []
"""每个回合的token数量"""
default_prompt = []
"""本session的默认prompt"""
@ -146,6 +150,8 @@ class Session:
self.name = name
self.create_timestamp = int(time.time())
self.last_interact_timestamp = int(time.time())
self.prompt = []
self.token_counts = []
self.schedule()
self.response_lock = threading.Lock()
@ -209,9 +215,16 @@ class Session:
config = pkg.utils.context.get_config()
max_length = config.prompt_submit_length if hasattr(config, "prompt_submit_length") else 1024
prompts, counts = self.cut_out(text, max_length)
# 计算请求前的prompt数量
total_token_before_query = 0
for token_count in counts:
total_token_before_query += token_count
# 向API请求补全
message = pkg.utils.context.get_openai_manager().request_completion(
self.cut_out(text, max_length),
message, total_token = pkg.utils.context.get_openai_manager().request_completion(
prompts,
)
# 成功获取,处理回复
@ -228,6 +241,10 @@ class Session:
self.prompt.append({'role': 'user', 'content': text})
self.prompt.append({'role': 'assistant', 'content': res_ans})
# 向token_counts中添加本回合的token数量
self.token_counts.append(total_token-total_token_before_query)
logging.debug("本回合使用token: {}, session counts: {}".format(total_token-total_token_before_query, self.token_counts))
if self.just_switched_to_exist_session:
self.just_switched_to_exist_session = False
self.set_ongoing()
@ -244,39 +261,65 @@ class Session:
question = self.prompt[-2]['content']
self.prompt = self.prompt[:-2]
self.token_counts = self.token_counts[:-1]
# 返回上一回合的问题
return question
# 构建对话体
def cut_out(self, msg: str, max_tokens: int) -> list:
"""将现有prompt进行切割处理使得新的prompt长度不超过max_tokens"""
# 如果用户消息长度超过max_tokens直接返回
temp_prompt: list = []
temp_prompt += self.default_prompt
temp_prompt.append(
def cut_out(self, msg: str, max_tokens: int) -> tuple[list, list]:
"""将现有prompt进行切割处理使得新的prompt长度不超过max_tokens
:return: (新的prompt, 新的token_counts)
"""
# 最终由三个部分组成
# - default_prompt 情景预设固定值
# - changable_prompts 可变部分, 此会话中的历史对话回合
# - current_question 当前问题
# 包装目前的对话回合内容
changable_prompts = []
changable_counts = []
# 倒着来, 遍历prompt的步长为2, 遍历tokens_counts的步长为1
changable_index = len(self.prompt) - 1
token_count_index = len(self.token_counts) - 1
packed_tokens = 0
print(self.prompt)
while changable_index >= 0 and token_count_index >= 0:
if packed_tokens + self.token_counts[token_count_index] > max_tokens:
break
changable_prompts.insert(0, self.prompt[changable_index])
changable_prompts.insert(0, self.prompt[changable_index - 1])
changable_counts.insert(0, self.token_counts[token_count_index])
packed_tokens += self.token_counts[token_count_index]
changable_index -= 2
token_count_index -= 1
# 将default_prompt和changable_prompts合并
result_prompt = self.default_prompt + changable_prompts
print(changable_prompts)
# 添加当前问题
result_prompt.append(
{
'role': 'user',
'content': msg
}
)
token_count = 0
for item in temp_prompt:
token_count += len(item['content'])
logging.debug('cut_out: {}\nchangable section tokens: {}\npacked counts: {}\nsession counts: {}'.format(json.dumps(result_prompt, ensure_ascii=False, indent=4),
packed_tokens,
changable_counts,
self.token_counts))
# 倒序遍历prompt
for i in range(len(self.prompt) - 1, -1, -1):
if token_count >= max_tokens:
break
# 将prompt加到temp_prompt倒数第二个位置
temp_prompt.insert(len(self.default_prompt), self.prompt[i])
token_count += len(self.prompt[i]['content'])
logging.debug('cut_out: {}'.format(json.dumps(temp_prompt, ensure_ascii=False, indent=4)))
return temp_prompt
return result_prompt, changable_counts
# 持久化session
def persistence(self):
@ -291,7 +334,7 @@ class Session:
subject_number = int(name_spt[1])
db_inst.persistence_session(subject_type, subject_number, self.create_timestamp, self.last_interact_timestamp,
json.dumps(self.prompt), json.dumps(self.default_prompt))
json.dumps(self.prompt), json.dumps(self.default_prompt), json.dumps(self.token_counts))
# 重置session
def reset(self, explicit: bool = False, expired: bool = False, schedule_new: bool = True, use_prompt: str = None):
@ -314,6 +357,7 @@ class Session:
self.default_prompt = self.get_default_prompt(use_prompt)
self.prompt = []
self.token_counts = []
self.create_timestamp = int(time.time())
self.last_interact_timestamp = int(time.time())
self.just_switched_to_exist_session = False
@ -339,6 +383,7 @@ class Session:
self.last_interact_timestamp = last_one['last_interact_timestamp']
try:
self.prompt = json.loads(last_one['prompt'])
self.token_counts = json.loads(last_one['token_counts'])
except json.decoder.JSONDecodeError:
self.prompt = reset_session_prompt(self.name, last_one['prompt'])
self.persistence()
@ -359,6 +404,7 @@ class Session:
self.last_interact_timestamp = next_one['last_interact_timestamp']
try:
self.prompt = json.loads(next_one['prompt'])
self.token_counts = json.loads(next_one['token_counts'])
except json.decoder.JSONDecodeError:
self.prompt = reset_session_prompt(self.name, next_one['prompt'])
self.persistence()

View File

View File