diff --git a/pkg/database/manager.py b/pkg/database/manager.py index a51ab50..90a6dbb 100644 --- a/pkg/database/manager.py +++ b/pkg/database/manager.py @@ -206,7 +206,7 @@ class DatabaseManager: } # 列出与某个对象的所有对话session - def list_history(self, session_name: str, capacity: int, page: int, replace: str = ""): + def list_history(self, session_name: str, capacity: int, page: int): self.execute(""" select `name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`, `status` from `sessions` where `name` = '{}' order by `last_interact_timestamp` desc limit {} offset {} @@ -227,7 +227,7 @@ class DatabaseManager: 'subject_number': subject_number, 'create_timestamp': create_timestamp, 'last_interact_timestamp': last_interact_timestamp, - 'prompt': prompt if replace == "" else prompt.replace(replace, "") + 'prompt': json.loads(prompt) }) return sessions diff --git a/pkg/openai/manager.py b/pkg/openai/manager.py index c64d21d..7320d22 100644 --- a/pkg/openai/manager.py +++ b/pkg/openai/manager.py @@ -5,7 +5,7 @@ import openai import pkg.openai.keymgr import pkg.utils.context import pkg.audit.gatherer - +from pkg.openai.modelmgr import Model, ChatCompletionModel, OpenaiModel # 为其他模块提供与OpenAI交互的接口 class OpenAIInteract: @@ -32,24 +32,26 @@ class OpenAIInteract: pkg.utils.context.set_openai_manager(self) # 请求OpenAI Completion - def request_completion(self, prompt, stop): + def request_completion(self, messages): config = pkg.utils.context.get_config() - response = openai.Completion.create( - prompt=prompt, - stop=stop, + + ai:Model = OpenaiModel(config.completion_api_params['model'], 'user') + ai.request( + messages, **config.completion_api_params ) + response = ai.get_response() logging.debug("OpenAI response: %s", response) if 'model' in config.completion_api_params: self.audit_mgr.report_text_model_usage(config.completion_api_params['model'], - response['usage']['total_tokens']) + ai.get_total_tokens()) elif 'engine' in config.completion_api_params: self.audit_mgr.report_text_model_usage(config.completion_api_params['engine'], response['usage']['total_tokens']) - return response + return ai.get_message() def request_image(self, prompt): diff --git a/pkg/openai/modelmgr.py b/pkg/openai/modelmgr.py index c106824..62540e6 100644 --- a/pkg/openai/modelmgr.py +++ b/pkg/openai/modelmgr.py @@ -1,4 +1,9 @@ # 提供与模型交互的抽象接口 +import openai, logging + +CHAT_COMPLETION_MODELS = { + 'gpt-3.5-turbo' +} COMPLETION_MODELS = { 'text-davinci-003' @@ -12,23 +17,80 @@ IMAGE_MODELS = { } +class Model(): -# ModelManager -# 由session包含 -class ModelMgr(object): + can_chat = False - using_completion_model = "" - using_edit_model = "" - using_image_model = "" + def __init__(self, model_name, user_name, request_fun): + self.model_name = model_name + self.user_name = user_name + self.request_fun = request_fun - def __init__(self): - pass + def request(self, **kwargs): + ret = self.request_fun(**kwargs) + self.ret = self.ret_handle(ret) + self.message = self.ret["choices"][0]["message"] - def get_using_completion_model(self): - return self.using_completion_model + def msg_handle(self, msg): + return msg + + def ret_handle(self, ret): + return ret + + def get_total_tokens(self): + return self.ret['usage']['total_tokens'] + + def get_message(self): + return self.message + + def get_response(self): + return self.ret - def get_using_edit_model(self): - return self.using_edit_model +class ChatCompletionModel(Model): + def __init__(self, model_name, user_name): + request_fun = openai.ChatCompletion.create + self.can_chat = True + super().__init__(model_name, user_name, request_fun) - def get_using_image_model(self): - return self.using_image_model + def request(self, messages, **kwargs): + ret = self.request_fun(messages = self.msg_handle(messages), **kwargs, user=self.user_name) + self.ret = self.ret_handle(ret) + self.message = self.ret["choices"][0]["message"]['content'] + + def get_content(self): + return self.message + +class CompletionModel(Model): + def __init__(self, model_name, user_name): + request_fun = openai.Completion.create + super().__init__(model_name, user_name, request_fun) + + def request(self, prompt, **kwargs): + ret = self.request_fun(prompt = self.msg_handle(prompt), **kwargs) + self.ret = self.ret_handle(ret) + self.message = self.ret["choices"][0]["text"] + + def msg_handle(self, msgs): + prompt = '' + for msg in msgs: + if msg['role'] == '': + prompt = prompt + "{}\n".format(msg['content']) + else: + prompt = prompt + "{}:{}\n".format(msg['role'] if msg['role']!='system' else '你的回答要遵守此规则', msg['content']) + print(prompt) + return prompt + + def get_text(self): + return self.message + +def OpenaiModel(model_name:str, user_name='user'): + if model_name in CHAT_COMPLETION_MODELS: + model = ChatCompletionModel(model_name, user_name) + elif model_name in COMPLETION_MODELS: + model = CompletionModel(model_name, user_name) + else : + log = "找不到模型[{}],请检查配置文件".format(model_name) + logging.error(log) + raise IndexError(log) + + return model diff --git a/pkg/openai/session.py b/pkg/openai/session.py index c04abc4..a534f01 100644 --- a/pkg/openai/session.py +++ b/pkg/openai/session.py @@ -1,8 +1,10 @@ import logging import threading import time +import json import pkg.openai.manager +import pkg.openai.modelmgr import pkg.database.manager import pkg.utils.context @@ -33,7 +35,7 @@ def load_sessions(): temp_session.name = session_name temp_session.create_timestamp = session_data[session_name]['create_timestamp'] temp_session.last_interact_timestamp = session_data[session_name]['last_interact_timestamp'] - temp_session.prompt = session_data[session_name]['prompt'] + temp_session.prompt = json.loads(session_data[session_name]['prompt']) sessions[session_name] = temp_session @@ -60,13 +62,10 @@ def dump_session(session_name: str): class Session: name = '' - prompt = "" + prompt = {} import config - user_name = config.user_name if hasattr(config, 'user_name') and config.user_name != '' else 'You' - bot_name = config.bot_name if hasattr(config, 'bot_name') and config.bot_name != '' else 'Bot' - create_timestamp = 0 last_interact_timestamp = 0 @@ -99,11 +98,10 @@ class Session: else: current_default_prompt = dprompt.get_prompt(use_default) - user_name = config.user_name if hasattr(config, 'user_name') and config.user_name != '' else 'You' - bot_name = config.bot_name if hasattr(config, 'bot_name') and config.bot_name != '' else 'Bot' - - return (user_name + ":{}\n".format(current_default_prompt) + bot_name + ":好的\n") \ - if current_default_prompt != '' else '' + return [{ + 'role': 'system', + 'content': current_default_prompt + }] def __init__(self, name: str): self.name = name @@ -165,22 +163,17 @@ class Session: if event.is_prevented_default(): return None - # max_rounds = config.prompt_submit_round_amount if hasattr(config, 'prompt_submit_round_amount') else 7 config = pkg.utils.context.get_config() - max_rounds = 1000 # 不再限制回合数 max_length = config.prompt_submit_length if hasattr(config, "prompt_submit_length") else 1024 # 向API请求补全 - response = pkg.utils.context.get_openai_manager().request_completion( - self.cut_out(self.prompt + self.user_name + ':' + - text + '\n' + self.bot_name + ':', - max_rounds, max_length), - self.user_name + ':') + self.cut_out(text, max_length) + message = pkg.utils.context.get_openai_manager().request_completion( + self.prompt + ) - self.prompt += self.user_name + ':' + text + '\n' + self.bot_name + ':' - # print(response) # 处理回复 - res_test = response["choices"][0]["text"] + res_test = message res_ans = res_test # 去除开头可能的提示 @@ -189,50 +182,44 @@ class Session: del (res_ans_spt[0]) res_ans = '\n\n'.join(res_ans_spt) - self.prompt += "{}".format(res_ans) + '\n' + if config.completion_api_params['model'] in pkg.openai.modelmgr.CHAT_COMPLETION_MODELS: + self.prompt.append({'role':'assistant', 'content':res_ans}) + elif config.completion_api_params['model'] in pkg.openai.modelmgr.COMPLETION_MODELS: + self.prompt.append({'role':'', 'content':res_ans}) if self.just_switched_to_exist_session: self.just_switched_to_exist_session = False self.set_ongoing() - return res_ans + return res_ans if res_ans[0]!='\n' else res_ans[1:] # 删除上一回合并返回上一回合的问题 def undo(self) -> str: self.last_interact_timestamp = int(time.time()) # 删除上一回合 - to_delete = self.cut_out(self.prompt, 1, 1024) - - self.prompt = self.prompt.replace(to_delete, '') + if self.prompt[-1]['role'] != 'user': + res = self.prompt[-1]['content'] + self.prompt.remove(self.prompt[-2]) + else: + res = self.prompt[-2]['content'] + self.prompt.remove(self.prompt[-1]) # 返回上一回合的问题 - return to_delete.split(self.bot_name + ':')[0].split(self.user_name + ':')[1].strip() + return res - # 从尾部截取prompt里不多于max_rounds个回合,长度不大于max_tokens的字符串 - # 保证都是完整的对话 - def cut_out(self, prompt: str, max_rounds: int, max_tokens: int) -> str: - # 分隔出每个回合 - rounds_spt_by_user_name = prompt.split(self.user_name + ':') + # 构建对话体 + def cut_out(self, msg: str, max_tokens: int) -> str: - result = '' + if len(msg) > max_tokens: + msg = msg[:max_tokens] - checked_rounds = 0 - # 从后往前遍历,加到result前面,检查result是否符合要求 - for i in range(len(rounds_spt_by_user_name) - 1, 0, -1): - result_temp = self.user_name + ':' + rounds_spt_by_user_name[i] + result - checked_rounds += 1 + self.prompt.append({ + 'role': 'user', + 'content': msg + }) - if checked_rounds > max_rounds: - break - - if int((len(result_temp.encode('utf-8')) - len(result_temp)) / 2 + len(result_temp)) > max_tokens: - break - - result = result_temp - - logging.debug('cut_out: {}'.format(result)) - return result + logging.debug('cut_out: {}'.format(msg)) # 持久化session def persistence(self): @@ -247,11 +234,11 @@ class Session: subject_number = int(name_spt[1]) db_inst.persistence_session(subject_type, subject_number, self.create_timestamp, self.last_interact_timestamp, - self.prompt) + json.dumps(self.prompt)) # 重置session def reset(self, explicit: bool = False, expired: bool = False, schedule_new: bool = True, use_prompt: str = None): - if not self.prompt.endswith(':好的\n'): + if self.prompt[-1]['role'] != "system": self.persistence() if explicit: # 触发插件事件 @@ -291,7 +278,7 @@ class Session: self.create_timestamp = last_one['create_timestamp'] self.last_interact_timestamp = last_one['last_interact_timestamp'] - self.prompt = last_one['prompt'] + self.prompt = json.loads(last_one['prompt']) self.just_switched_to_exist_session = True return self @@ -306,14 +293,13 @@ class Session: self.create_timestamp = next_one['create_timestamp'] self.last_interact_timestamp = next_one['last_interact_timestamp'] - self.prompt = next_one['prompt'] + self.prompt = json.loads(next_one['prompt']) self.just_switched_to_exist_session = True return self def list_history(self, capacity: int = 10, page: int = 0): - return pkg.utils.context.get_database_manager().list_history(self.name, capacity, page, - self.get_default_prompt()) + return pkg.utils.context.get_database_manager().list_history(self.name, capacity, page) def draw_image(self, prompt: str): return pkg.utils.context.get_openai_manager().request_image(prompt) diff --git a/pkg/qqbot/command.py b/pkg/qqbot/command.py index 7de7e5a..e2b34a6 100644 --- a/pkg/qqbot/command.py +++ b/pkg/qqbot/command.py @@ -185,11 +185,7 @@ def process_command(session_name: str, text_message: str, mgr, config, else: datetime_str = datetime.datetime.fromtimestamp(result.create_timestamp).strftime( '%Y-%m-%d %H:%M:%S') - reply = ["[bot]已切换到前一次的对话:\n创建时间:{}\n".format( - datetime_str) + result.prompt[ - :min(100, - len(result.prompt))] + \ - ("..." if len(result.prompt) > 100 else "#END#")] + reply = ["[bot]已切换到前一次的对话:\n创建时间:{}\n".format(datetime_str)] elif cmd == 'next': result = pkg.openai.session.get_session(session_name).next_session() if result is None: @@ -197,13 +193,18 @@ def process_command(session_name: str, text_message: str, mgr, config, else: datetime_str = datetime.datetime.fromtimestamp(result.create_timestamp).strftime( '%Y-%m-%d %H:%M:%S') - reply = ["[bot]已切换到后一次的对话:\n创建时间:{}\n".format( - datetime_str) + result.prompt[ - :min(100, - len(result.prompt))] + \ - ("..." if len(result.prompt) > 100 else "#END#")] + reply = ["[bot]已切换到后一次的对话:\n创建时间:{}\n".format(datetime_str)] elif cmd == 'prompt': - reply = ["[bot]当前对话所有内容:\n" + pkg.openai.session.get_session(session_name).prompt] + msgs = "" + session:list = pkg.openai.session.get_session(session_name).prompt + for msg in session: + if len(params) != 0 and params[0] in ['-all', '-a']: + msgs = msgs + "{}: {}\n\n".format(msg['role'], msg['content']) + elif len(msg['content']) > 30: + msgs = msgs + "[{}]: {}...\n\n".format(msg['role'], msg['content'][:30]) + else: + msgs = msgs + "[{}]: {}\n\n".format(msg['role'], msg['content']) + reply = ["[bot]当前对话所有内容:\n{}".format(msgs)] elif cmd == 'list': pkg.openai.session.get_session(session_name).persistence() page = 0 @@ -225,8 +226,7 @@ def process_command(session_name: str, text_message: str, mgr, config, datetime_obj = datetime.datetime.fromtimestamp(results[i]['create_timestamp']) reply_str += "#{} 创建:{} {}\n".format(i + page * 10, datetime_obj.strftime("%Y-%m-%d %H:%M:%S"), - results[i]['prompt'][ - :min(20, len(results[i]['prompt']))]) + results[i]['prompt'][1]['content']) if results[i]['create_timestamp'] == pkg.openai.session.get_session( session_name).create_timestamp: current = i + page * 10 diff --git a/requirements.txt b/requirements.txt index 7d44575..838279e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ requests~=2.28.1 -openai~=0.26.5 +openai~=0.27.0 pip~=22.3.1 dulwich~=0.21.3 colorlog~=6.6.0