From 0490ad92079a7312530f4f3f9068b1e9053786e6 Mon Sep 17 00:00:00 2001 From: Rock Chin <1010553892@qq.com> Date: Sat, 18 Mar 2023 11:26:18 +0000 Subject: [PATCH 1/4] =?UTF-8?q?test:=20token=E8=AE=A1=E6=95=B0=E6=B5=8B?= =?UTF-8?q?=E8=AF=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tests/token_test/__init__.py | 0 tests/token_test/token_count.py | 0 2 files changed, 0 insertions(+), 0 deletions(-) create mode 100644 tests/token_test/__init__.py create mode 100644 tests/token_test/token_count.py diff --git a/tests/token_test/__init__.py b/tests/token_test/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/token_test/token_count.py b/tests/token_test/token_count.py new file mode 100644 index 0000000..e69de29 From d056cb6769ef1870e5f8c42b787a64a03ce662da Mon Sep 17 00:00:00 2001 From: Rock Chin <1010553892@qq.com> Date: Sat, 18 Mar 2023 12:57:36 +0000 Subject: [PATCH 2/4] =?UTF-8?q?feat:=20=E6=95=B0=E6=8D=AE=E5=BA=93?= =?UTF-8?q?=E6=8E=A5=E5=8F=A3=E6=94=AF=E6=8C=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pkg/database/manager.py | 47 +++++++++++++++++++++++++++-------------- 1 file changed, 31 insertions(+), 16 deletions(-) diff --git a/pkg/database/manager.py b/pkg/database/manager.py index 999d731..d76dc0c 100644 --- a/pkg/database/manager.py +++ b/pkg/database/manager.py @@ -54,20 +54,27 @@ class DatabaseManager: `last_interact_timestamp` bigint not null, `status` varchar(255) not null default 'on_going', `default_prompt` text not null default '', - `prompt` text not null + `prompt` text not null, + `token_counts` text not null default '[]', ) """) - # 检查sessions表是否存在`default_prompt`字段 + # 检查sessions表是否存在`default_prompt`字段, 检查是否存在`token_counts`字段 self.__execute__("PRAGMA table_info('sessions')") columns = self.cursor.fetchall() has_default_prompt = False + has_token_counts = False for field in columns: if field[1] == 'default_prompt': has_default_prompt = True + if field[1] == 'token_counts': + has_token_counts = True + if has_default_prompt and has_token_counts: break if not has_default_prompt: self.__execute__("alter table `sessions` add column `default_prompt` text not null default ''") + if not has_token_counts: + self.__execute__("alter table `sessions` add column `token_counts` text not null default '[]'") self.__execute__(""" @@ -89,7 +96,7 @@ class DatabaseManager: # session持久化 def persistence_session(self, subject_type: str, subject_number: int, create_timestamp: int, - last_interact_timestamp: int, prompt: str, default_prompt: str = ''): + last_interact_timestamp: int, prompt: str, default_prompt: str = '', token_counts: list = []): """持久化指定session""" # 检查是否已经有了此name和create_timestamp的session @@ -102,20 +109,20 @@ class DatabaseManager: if count == 0: sql = """ - insert into `sessions` (`name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`, `default_prompt`) - values (?, ?, ?, ?, ?, ?, ?) + insert into `sessions` (`name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`, `default_prompt`, `token_counts`) + values (?, ?, ?, ?, ?, ?, ?, ?) """ self.__execute__(sql, ("{}_{}".format(subject_type, subject_number), subject_type, subject_number, create_timestamp, - last_interact_timestamp, prompt, default_prompt)) + last_interact_timestamp, prompt, default_prompt, json.dumps(token_counts))) else: sql = """ - update `sessions` set `last_interact_timestamp` = ?, `prompt` = ? + update `sessions` set `last_interact_timestamp` = ?, `prompt` = ?, `token_counts` = ? where `type` = ? and `number` = ? and `create_timestamp` = ? """ - self.__execute__(sql, (last_interact_timestamp, prompt, subject_type, + self.__execute__(sql, (last_interact_timestamp, prompt, json.dumps(token_counts), subject_type, subject_number, create_timestamp)) # 显式关闭一个session @@ -140,7 +147,7 @@ class DatabaseManager: # 从数据库中加载所有还没过期的session config = pkg.utils.context.get_config() self.__execute__(""" - select `name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`, `status`, `default_prompt` + select `name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`, `status`, `default_prompt`, `token_counts` from `sessions` where `last_interact_timestamp` > {} """.format(int(time.time()) - config.session_expire_time)) results = self.cursor.fetchall() @@ -154,6 +161,7 @@ class DatabaseManager: prompt = result[5] status = result[6] default_prompt = result[7] + token_counts = result[8] # 当且仅当最后一个该对象的会话是on_going状态时,才会被加载 if status == 'on_going': @@ -163,7 +171,8 @@ class DatabaseManager: 'create_timestamp': create_timestamp, 'last_interact_timestamp': last_interact_timestamp, 'prompt': prompt, - 'default_prompt': default_prompt + 'default_prompt': default_prompt, + 'token_counts': json.loads(token_counts) } else: if session_name in sessions: @@ -175,7 +184,7 @@ class DatabaseManager: def last_session(self, session_name: str, cursor_timestamp: int): self.__execute__(""" - select `name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`, `status`, `default_prompt` + select `name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`, `status`, `default_prompt`, `token_counts` from `sessions` where `name` = '{}' and `last_interact_timestamp` < {} order by `last_interact_timestamp` desc limit 1 """.format(session_name, cursor_timestamp)) @@ -192,6 +201,7 @@ class DatabaseManager: prompt = result[5] status = result[6] default_prompt = result[7] + token_counts = result[8] return { 'subject_type': subject_type, @@ -199,14 +209,15 @@ class DatabaseManager: 'create_timestamp': create_timestamp, 'last_interact_timestamp': last_interact_timestamp, 'prompt': prompt, - 'default_prompt': default_prompt + 'default_prompt': default_prompt, + 'token_counts': json.loads(token_counts) } # 获取此session_name后一个session的数据 def next_session(self, session_name: str, cursor_timestamp: int): self.__execute__(""" - select `name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`, `status`, `default_prompt` + select `name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`, `status`, `default_prompt`, `token_counts` from `sessions` where `name` = '{}' and `last_interact_timestamp` > {} order by `last_interact_timestamp` asc limit 1 """.format(session_name, cursor_timestamp)) @@ -223,6 +234,7 @@ class DatabaseManager: prompt = result[5] status = result[6] default_prompt = result[7] + token_counts = result[8] return { 'subject_type': subject_type, @@ -230,13 +242,14 @@ class DatabaseManager: 'create_timestamp': create_timestamp, 'last_interact_timestamp': last_interact_timestamp, 'prompt': prompt, - 'default_prompt': default_prompt + 'default_prompt': default_prompt, + 'token_counts': json.loads(token_counts) } # 列出与某个对象的所有对话session def list_history(self, session_name: str, capacity: int, page: int): self.__execute__(""" - select `name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`, `status`, `default_prompt` + select `name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`, `status`, `default_prompt`, `token_counts` from `sessions` where `name` = '{}' order by `last_interact_timestamp` desc limit {} offset {} """.format(session_name, capacity, capacity * page)) results = self.cursor.fetchall() @@ -250,6 +263,7 @@ class DatabaseManager: prompt = result[5] status = result[6] default_prompt = result[7] + token_counts = result[8] sessions.append({ 'subject_type': subject_type, @@ -257,7 +271,8 @@ class DatabaseManager: 'create_timestamp': create_timestamp, 'last_interact_timestamp': last_interact_timestamp, 'prompt': prompt, - 'default_prompt': default_prompt + 'default_prompt': default_prompt, + 'token_counts': json.loads(token_counts) }) return sessions From ca349e33fcbf9ac9366172b7f6412949bfd4958c Mon Sep 17 00:00:00 2001 From: Rock Chin <1010553892@qq.com> Date: Sat, 18 Mar 2023 15:57:28 +0000 Subject: [PATCH 3/4] =?UTF-8?q?feat:=20=E5=AE=9E=E7=8E=B0=E6=96=B0?= =?UTF-8?q?=E7=9A=84=E5=89=8D=E6=96=87=E5=89=AA=E5=88=87=E6=A8=A1=E5=BC=8F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pkg/database/manager.py | 16 +++---- pkg/openai/manager.py | 8 +++- pkg/openai/session.py | 94 ++++++++++++++++++++++++++++++----------- 3 files changed, 84 insertions(+), 34 deletions(-) diff --git a/pkg/database/manager.py b/pkg/database/manager.py index d76dc0c..33d6cfb 100644 --- a/pkg/database/manager.py +++ b/pkg/database/manager.py @@ -55,7 +55,7 @@ class DatabaseManager: `status` varchar(255) not null default 'on_going', `default_prompt` text not null default '', `prompt` text not null, - `token_counts` text not null default '[]', + `token_counts` text not null default '[]' ) """) @@ -96,7 +96,7 @@ class DatabaseManager: # session持久化 def persistence_session(self, subject_type: str, subject_number: int, create_timestamp: int, - last_interact_timestamp: int, prompt: str, default_prompt: str = '', token_counts: list = []): + last_interact_timestamp: int, prompt: str, default_prompt: str = '', token_counts: str = ''): """持久化指定session""" # 检查是否已经有了此name和create_timestamp的session @@ -115,14 +115,14 @@ class DatabaseManager: self.__execute__(sql, ("{}_{}".format(subject_type, subject_number), subject_type, subject_number, create_timestamp, - last_interact_timestamp, prompt, default_prompt, json.dumps(token_counts))) + last_interact_timestamp, prompt, default_prompt, token_counts)) else: sql = """ update `sessions` set `last_interact_timestamp` = ?, `prompt` = ?, `token_counts` = ? where `type` = ? and `number` = ? and `create_timestamp` = ? """ - self.__execute__(sql, (last_interact_timestamp, prompt, json.dumps(token_counts), subject_type, + self.__execute__(sql, (last_interact_timestamp, prompt, token_counts, subject_type, subject_number, create_timestamp)) # 显式关闭一个session @@ -172,7 +172,7 @@ class DatabaseManager: 'last_interact_timestamp': last_interact_timestamp, 'prompt': prompt, 'default_prompt': default_prompt, - 'token_counts': json.loads(token_counts) + 'token_counts': token_counts } else: if session_name in sessions: @@ -210,7 +210,7 @@ class DatabaseManager: 'last_interact_timestamp': last_interact_timestamp, 'prompt': prompt, 'default_prompt': default_prompt, - 'token_counts': json.loads(token_counts) + 'token_counts': token_counts } # 获取此session_name后一个session的数据 @@ -243,7 +243,7 @@ class DatabaseManager: 'last_interact_timestamp': last_interact_timestamp, 'prompt': prompt, 'default_prompt': default_prompt, - 'token_counts': json.loads(token_counts) + 'token_counts': token_counts } # 列出与某个对象的所有对话session @@ -272,7 +272,7 @@ class DatabaseManager: 'last_interact_timestamp': last_interact_timestamp, 'prompt': prompt, 'default_prompt': default_prompt, - 'token_counts': json.loads(token_counts) + 'token_counts': token_counts }) return sessions diff --git a/pkg/openai/manager.py b/pkg/openai/manager.py index 4a3ceab..2d64e9a 100644 --- a/pkg/openai/manager.py +++ b/pkg/openai/manager.py @@ -34,7 +34,7 @@ class OpenAIInteract: pkg.utils.context.set_openai_manager(self) # 请求OpenAI Completion - def request_completion(self, prompts) -> str: + def request_completion(self, prompts) -> tuple[str, int]: """请求补全接口回复 Parameters: @@ -60,14 +60,18 @@ class OpenAIInteract: logging.debug("OpenAI response: %s", response) + # 记录使用量 + current_round_token = 0 if 'model' in config.completion_api_params: self.audit_mgr.report_text_model_usage(config.completion_api_params['model'], ai.get_total_tokens()) + current_round_token = ai.get_total_tokens() elif 'engine' in config.completion_api_params: self.audit_mgr.report_text_model_usage(config.completion_api_params['engine'], response['usage']['total_tokens']) + current_round_token = response['usage']['total_tokens'] - return ai.get_message() + return ai.get_message(), current_round_token def request_image(self, prompt) -> dict: """请求图片接口回复 diff --git a/pkg/openai/session.py b/pkg/openai/session.py index 56b9b32..9c7afb8 100644 --- a/pkg/openai/session.py +++ b/pkg/openai/session.py @@ -72,6 +72,7 @@ def load_sessions(): temp_session.last_interact_timestamp = session_data[session_name]['last_interact_timestamp'] try: temp_session.prompt = json.loads(session_data[session_name]['prompt']) + temp_session.token_counts = json.loads(session_data[session_name]['token_counts']) except Exception: temp_session.prompt = reset_session_prompt(session_name, session_data[session_name]['prompt']) temp_session.persistence() @@ -106,6 +107,9 @@ class Session: prompt = [] """使用list来保存会话中的回合""" + token_counts = [] + """每个回合的token数量""" + default_prompt = [] """本session的默认prompt""" @@ -146,6 +150,8 @@ class Session: self.name = name self.create_timestamp = int(time.time()) self.last_interact_timestamp = int(time.time()) + self.prompt = [] + self.token_counts = [] self.schedule() self.response_lock = threading.Lock() @@ -209,9 +215,16 @@ class Session: config = pkg.utils.context.get_config() max_length = config.prompt_submit_length if hasattr(config, "prompt_submit_length") else 1024 + prompts, counts = self.cut_out(text, max_length) + + # 计算请求前的prompt数量 + total_token_before_query = 0 + for token_count in counts: + total_token_before_query += token_count + # 向API请求补全 - message = pkg.utils.context.get_openai_manager().request_completion( - self.cut_out(text, max_length), + message, total_token = pkg.utils.context.get_openai_manager().request_completion( + prompts, ) # 成功获取,处理回复 @@ -228,6 +241,10 @@ class Session: self.prompt.append({'role': 'user', 'content': text}) self.prompt.append({'role': 'assistant', 'content': res_ans}) + # 向token_counts中添加本回合的token数量 + self.token_counts.append(total_token-total_token_before_query) + logging.debug("本回合使用token: {}, session counts: {}".format(total_token-total_token_before_query, self.token_counts)) + if self.just_switched_to_exist_session: self.just_switched_to_exist_session = False self.set_ongoing() @@ -244,39 +261,65 @@ class Session: question = self.prompt[-2]['content'] self.prompt = self.prompt[:-2] + self.token_counts = self.token_counts[:-1] # 返回上一回合的问题 return question # 构建对话体 - def cut_out(self, msg: str, max_tokens: int) -> list: - """将现有prompt进行切割处理,使得新的prompt长度不超过max_tokens""" - # 如果用户消息长度超过max_tokens,直接返回 - temp_prompt: list = [] - temp_prompt += self.default_prompt - temp_prompt.append( + def cut_out(self, msg: str, max_tokens: int) -> tuple[list, list]: + """将现有prompt进行切割处理,使得新的prompt长度不超过max_tokens + + :return: (新的prompt, 新的token_counts) + """ + + # 最终由三个部分组成 + # - default_prompt 情景预设固定值 + # - changable_prompts 可变部分, 此会话中的历史对话回合 + # - current_question 当前问题 + + # 包装目前的对话回合内容 + changable_prompts = [] + changable_counts = [] + # 倒着来, 遍历prompt的步长为2, 遍历tokens_counts的步长为1 + changable_index = len(self.prompt) - 1 + token_count_index = len(self.token_counts) - 1 + + packed_tokens = 0 + + print(self.prompt) + + while changable_index >= 0 and token_count_index >= 0: + if packed_tokens + self.token_counts[token_count_index] > max_tokens: + break + + changable_prompts.insert(0, self.prompt[changable_index]) + changable_prompts.insert(0, self.prompt[changable_index - 1]) + changable_counts.insert(0, self.token_counts[token_count_index]) + packed_tokens += self.token_counts[token_count_index] + + changable_index -= 2 + token_count_index -= 1 + + # 将default_prompt和changable_prompts合并 + result_prompt = self.default_prompt + changable_prompts + + print(changable_prompts) + + # 添加当前问题 + result_prompt.append( { 'role': 'user', 'content': msg } ) - token_count = 0 - for item in temp_prompt: - token_count += len(item['content']) + logging.debug('cut_out: {}\nchangable section tokens: {}\npacked counts: {}\nsession counts: {}'.format(json.dumps(result_prompt, ensure_ascii=False, indent=4), + packed_tokens, + changable_counts, + self.token_counts)) - # 倒序遍历prompt - for i in range(len(self.prompt) - 1, -1, -1): - if token_count >= max_tokens: - break - - # 将prompt加到temp_prompt倒数第二个位置 - temp_prompt.insert(len(self.default_prompt), self.prompt[i]) - token_count += len(self.prompt[i]['content']) - - logging.debug('cut_out: {}'.format(json.dumps(temp_prompt, ensure_ascii=False, indent=4))) - - return temp_prompt + return result_prompt, changable_counts # 持久化session def persistence(self): @@ -291,7 +334,7 @@ class Session: subject_number = int(name_spt[1]) db_inst.persistence_session(subject_type, subject_number, self.create_timestamp, self.last_interact_timestamp, - json.dumps(self.prompt), json.dumps(self.default_prompt)) + json.dumps(self.prompt), json.dumps(self.default_prompt), json.dumps(self.token_counts)) # 重置session def reset(self, explicit: bool = False, expired: bool = False, schedule_new: bool = True, use_prompt: str = None): @@ -314,6 +357,7 @@ class Session: self.default_prompt = self.get_default_prompt(use_prompt) self.prompt = [] + self.token_counts = [] self.create_timestamp = int(time.time()) self.last_interact_timestamp = int(time.time()) self.just_switched_to_exist_session = False @@ -339,6 +383,7 @@ class Session: self.last_interact_timestamp = last_one['last_interact_timestamp'] try: self.prompt = json.loads(last_one['prompt']) + self.token_counts = json.loads(last_one['token_counts']) except json.decoder.JSONDecodeError: self.prompt = reset_session_prompt(self.name, last_one['prompt']) self.persistence() @@ -359,6 +404,7 @@ class Session: self.last_interact_timestamp = next_one['last_interact_timestamp'] try: self.prompt = json.loads(next_one['prompt']) + self.token_counts = json.loads(next_one['token_counts']) except json.decoder.JSONDecodeError: self.prompt = reset_session_prompt(self.name, next_one['prompt']) self.persistence() From cde168c93c5aea9abcb42a0c536964c0cc49921a Mon Sep 17 00:00:00 2001 From: Rock Chin <1010553892@qq.com> Date: Sun, 19 Mar 2023 08:32:34 +0000 Subject: [PATCH 4/4] =?UTF-8?q?doc:=20full=5Fscenario=E7=9A=84=E7=BC=96?= =?UTF-8?q?=E5=86=99=E6=95=99=E7=A8=8B=20(#301)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- config-template.py | 27 +++++++++++++++++++++++++-- 1 file changed, 25 insertions(+), 2 deletions(-) diff --git a/config-template.py b/config-template.py index e1558d6..e35e366 100644 --- a/config-template.py +++ b/config-template.py @@ -82,7 +82,30 @@ default_prompt = { # 情景预设格式 # 参考值:旧版本方式:default | 完整情景:full_scenario # 旧版本的格式为上述default_prompt中的内容,或prompts目录下的文件名 -# 完整情景预设的格式为JSON,在JSON文件中列出对话的每个回合,编写方法见scenario/default-template.json +# +# 完整情景预设的格式为JSON,在scenario目录下的JSON文件中列出对话的每个回合,编写方法见scenario/default-template.json +# 编写方法例如: +# { +# "prompt": [ +# { +# "role": "user", +# "content": "之后当我需要帮助时,请说“输入!help获取帮助”" +# },{ +# "role": "assistant", +# "content": "好的,当你之后需要帮助时,我会说“输入!help获取帮助”" +# },{ +# "role": "user", +# "content": "帮助" +# },{ +# "role": "assistant", +# "content": "输入!help获取帮助" +# } +# ] +# } +# +# 您可以按照上述格式编写自己的情景预设,在prompt中列出对话的每个回合, +# role为user或assistant,分别表示用户和机器人的回复 +# 每个JSON文件是一个情景预设,文件名即为情景预设的名称 preset_mode = "default" # 群内响应规则 @@ -139,7 +162,7 @@ encourage_sponsor_at_start = True # 每次向OpenAI接口发送对话记录上下文的字符数 # 最大不超过(4096 - max_tokens)个字符,max_tokens为下方completion_api_params中的max_tokens # 注意:较大的prompt_submit_length会导致OpenAI账户额度消耗更快 -prompt_submit_length = 1024 +prompt_submit_length = 2048 # OpenAI补全API的参数 # 请在下方填写模型,程序自动选择接口