Refactored the model abstraction to better support gpt-3.5-turbo

LINSTCL 2023-03-02 15:31:12 +08:00
parent 6f5802551f
commit fd25d61b56
6 changed files with 140 additions and 90 deletions
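In short, this refactor replaces direct openai.Completion calls with a model abstraction: the OpenaiModel factory picks a backend class from the configured model name, and callers pass a chat-style message list. A minimal sketch of the new flow, assembled from the diff below (message contents and config values are placeholders, not project code):

from pkg.openai.modelmgr import OpenaiModel, Model

messages = [
    {'role': 'system', 'content': 'You are a helpful assistant.'},  # placeholder content
    {'role': 'user', 'content': 'Hello'},
]

ai: Model = OpenaiModel('gpt-3.5-turbo', 'user')   # returns a ChatCompletionModel for chat models
ai.request(messages, model='gpt-3.5-turbo')        # forwarded to openai.ChatCompletion.create
reply = ai.get_message()                           # assistant reply text
tokens = ai.get_total_tokens()                     # usage figure reported to the audit gatherer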

View File

@@ -206,7 +206,7 @@ class DatabaseManager:
}
# List all conversation sessions with a given subject
def list_history(self, session_name: str, capacity: int, page: int, replace: str = ""):
def list_history(self, session_name: str, capacity: int, page: int):
self.execute("""
select `name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`, `status`
from `sessions` where `name` = '{}' order by `last_interact_timestamp` desc limit {} offset {}
@@ -227,7 +227,7 @@ class DatabaseManager:
'subject_number': subject_number,
'create_timestamp': create_timestamp,
'last_interact_timestamp': last_interact_timestamp,
'prompt': prompt if replace == "" else prompt.replace(replace, "")
'prompt': json.loads(prompt)
})
return sessions
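With this change the prompt column stores a JSON-serialized message list instead of a plain string, so the database round trip looks roughly like this (an illustrative sketch; the persistence side is shown in session.py further down):

import json

prompt = [
    {'role': 'system', 'content': 'default prompt'},   # placeholder content
    {'role': 'user', 'content': 'hello'},
]
stored = json.dumps(prompt)     # written by Session.persistence()
restored = json.loads(stored)   # read back here and in load_sessions()
assert restored == prompt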

View File

@@ -5,7 +5,7 @@ import openai
import pkg.openai.keymgr
import pkg.utils.context
import pkg.audit.gatherer
from pkg.openai.modelmgr import Model, ChatCompletionModel, OpenaiModel
# Provides the interface for other modules to interact with OpenAI
class OpenAIInteract:
@@ -32,24 +32,26 @@ class OpenAIInteract:
pkg.utils.context.set_openai_manager(self)
# Request an OpenAI completion
def request_completion(self, prompt, stop):
def request_completion(self, messages):
config = pkg.utils.context.get_config()
response = openai.Completion.create(
prompt=prompt,
stop=stop,
ai:Model = OpenaiModel(config.completion_api_params['model'], 'user')
ai.request(
messages,
**config.completion_api_params
)
response = ai.get_response()
logging.debug("OpenAI response: %s", response)
if 'model' in config.completion_api_params:
self.audit_mgr.report_text_model_usage(config.completion_api_params['model'],
response['usage']['total_tokens'])
ai.get_total_tokens())
elif 'engine' in config.completion_api_params:
self.audit_mgr.report_text_model_usage(config.completion_api_params['engine'],
response['usage']['total_tokens'])
return response
return ai.get_message()
def request_image(self, prompt):

View File

@@ -1,4 +1,9 @@
# Provides an abstract interface for interacting with models
import openai, logging
CHAT_COMPLETION_MODELS = {
'gpt-3.5-turbo'
}
COMPLETION_MODELS = {
'text-davinci-003'
@@ -12,23 +17,80 @@ IMAGE_MODELS = {
}
class Model():
# ModelManager
# Held by a session
class ModelMgr(object):
can_chat = False
using_completion_model = ""
using_edit_model = ""
using_image_model = ""
def __init__(self, model_name, user_name, request_fun):
self.model_name = model_name
self.user_name = user_name
self.request_fun = request_fun
def __init__(self):
pass
def request(self, **kwargs):
ret = self.request_fun(**kwargs)
self.ret = self.ret_handle(ret)
self.message = self.ret["choices"][0]["message"]
def get_using_completion_model(self):
return self.using_completion_model
def msg_handle(self, msg):
return msg
def ret_handle(self, ret):
return ret
def get_total_tokens(self):
return self.ret['usage']['total_tokens']
def get_message(self):
return self.message
def get_response(self):
return self.ret
def get_using_edit_model(self):
return self.using_edit_model
class ChatCompletionModel(Model):
def __init__(self, model_name, user_name):
request_fun = openai.ChatCompletion.create
self.can_chat = True
super().__init__(model_name, user_name, request_fun)
def get_using_image_model(self):
return self.using_image_model
def request(self, messages, **kwargs):
ret = self.request_fun(messages = self.msg_handle(messages), **kwargs, user=self.user_name)
self.ret = self.ret_handle(ret)
self.message = self.ret["choices"][0]["message"]['content']
def get_content(self):
return self.message
class CompletionModel(Model):
def __init__(self, model_name, user_name):
request_fun = openai.Completion.create
super().__init__(model_name, user_name, request_fun)
def request(self, prompt, **kwargs):
ret = self.request_fun(prompt = self.msg_handle(prompt), **kwargs)
self.ret = self.ret_handle(ret)
self.message = self.ret["choices"][0]["text"]
def msg_handle(self, msgs):
prompt = ''
for msg in msgs:
if msg['role'] == '':
prompt = prompt + "{}\n".format(msg['content'])
else:
prompt = prompt + "{}:{}\n".format(msg['role'] if msg['role']!='system' else '你的回答要遵守此规则', msg['content'])
print(prompt)
return prompt
def get_text(self):
return self.message
def OpenaiModel(model_name:str, user_name='user'):
if model_name in CHAT_COMPLETION_MODELS:
model = ChatCompletionModel(model_name, user_name)
elif model_name in COMPLETION_MODELS:
model = CompletionModel(model_name, user_name)
else :
log = "找不到模型[{}],请检查配置文件".format(model_name)
logging.error(log)
raise IndexError(log)
return model
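For completion-only models such as text-davinci-003, CompletionModel.msg_handle flattens the chat-style message list into a single prompt string before calling openai.Completion.create. An illustrative example (message contents are made up):

msgs = [
    {'role': 'system', 'content': 'Answer politely.'},
    {'role': 'user', 'content': 'Hi'},
    {'role': '', 'content': 'Hello!'},   # completion models store assistant replies with an empty role
]
# msg_handle(msgs) yields the following prompt ('system' is rewritten to the fixed
# rule header, an empty role emits the content alone):
# 你的回答要遵守此规则:Answer politely.
# user:Hi
# Hello!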

View File

@@ -1,8 +1,10 @@
import logging
import threading
import time
import json
import pkg.openai.manager
import pkg.openai.modelmgr
import pkg.database.manager
import pkg.utils.context
@@ -33,7 +35,7 @@ def load_sessions():
temp_session.name = session_name
temp_session.create_timestamp = session_data[session_name]['create_timestamp']
temp_session.last_interact_timestamp = session_data[session_name]['last_interact_timestamp']
temp_session.prompt = session_data[session_name]['prompt']
temp_session.prompt = json.loads(session_data[session_name]['prompt'])
sessions[session_name] = temp_session
@@ -60,13 +62,10 @@ def dump_session(session_name: str):
class Session:
name = ''
prompt = ""
prompt = {}
import config
user_name = config.user_name if hasattr(config, 'user_name') and config.user_name != '' else 'You'
bot_name = config.bot_name if hasattr(config, 'bot_name') and config.bot_name != '' else 'Bot'
create_timestamp = 0
last_interact_timestamp = 0
@@ -99,11 +98,10 @@ class Session:
else:
current_default_prompt = dprompt.get_prompt(use_default)
user_name = config.user_name if hasattr(config, 'user_name') and config.user_name != '' else 'You'
bot_name = config.bot_name if hasattr(config, 'bot_name') and config.bot_name != '' else 'Bot'
return (user_name + ":{}\n".format(current_default_prompt) + bot_name + ":好的\n") \
if current_default_prompt != '' else ''
return [{
'role': 'system',
'content': current_default_prompt
}]
def __init__(self, name: str):
self.name = name
@@ -165,22 +163,17 @@ class Session:
if event.is_prevented_default():
return None
# max_rounds = config.prompt_submit_round_amount if hasattr(config, 'prompt_submit_round_amount') else 7
config = pkg.utils.context.get_config()
max_rounds = 1000 # rounds are no longer limited
max_length = config.prompt_submit_length if hasattr(config, "prompt_submit_length") else 1024
# Request a completion from the API
response = pkg.utils.context.get_openai_manager().request_completion(
self.cut_out(self.prompt + self.user_name + ':' +
text + '\n' + self.bot_name + ':',
max_rounds, max_length),
self.user_name + ':')
self.cut_out(text, max_length)
message = pkg.utils.context.get_openai_manager().request_completion(
self.prompt
)
self.prompt += self.user_name + ':' + text + '\n' + self.bot_name + ':'
# print(response)
# Process the reply
res_test = response["choices"][0]["text"]
res_test = message
res_ans = res_test
# Strip the prompt hint that may appear at the beginning
@@ -189,50 +182,44 @@ class Session:
del (res_ans_spt[0])
res_ans = '\n\n'.join(res_ans_spt)
self.prompt += "{}".format(res_ans) + '\n'
if config.completion_api_params['model'] in pkg.openai.modelmgr.CHAT_COMPLETION_MODELS:
self.prompt.append({'role':'assistant', 'content':res_ans})
elif config.completion_api_params['model'] in pkg.openai.modelmgr.COMPLETION_MODELS:
self.prompt.append({'role':'', 'content':res_ans})
if self.just_switched_to_exist_session:
self.just_switched_to_exist_session = False
self.set_ongoing()
return res_ans
return res_ans if res_ans[0]!='\n' else res_ans[1:]
# Delete the previous round and return its question
def undo(self) -> str:
self.last_interact_timestamp = int(time.time())
# Delete the previous round
to_delete = self.cut_out(self.prompt, 1, 1024)
self.prompt = self.prompt.replace(to_delete, '')
if self.prompt[-1]['role'] != 'user':
res = self.prompt[-1]['content']
self.prompt.remove(self.prompt[-2])
else:
res = self.prompt[-2]['content']
self.prompt.remove(self.prompt[-1])
# Return the question from the previous round
return to_delete.split(self.bot_name + ':')[0].split(self.user_name + ':')[1].strip()
return res
# Take from the tail of the prompt a string of at most max_rounds rounds and no more than max_tokens in length
# Ensuring every round included is a complete exchange
def cut_out(self, prompt: str, max_rounds: int, max_tokens: int) -> str:
# Split out each round
rounds_spt_by_user_name = prompt.split(self.user_name + ':')
# Build the conversation body
def cut_out(self, msg: str, max_tokens: int) -> str:
result = ''
if len(msg) > max_tokens:
msg = msg[:max_tokens]
checked_rounds = 0
# Iterate from the end, prepending each round to result and checking that result still fits
for i in range(len(rounds_spt_by_user_name) - 1, 0, -1):
result_temp = self.user_name + ':' + rounds_spt_by_user_name[i] + result
checked_rounds += 1
self.prompt.append({
'role': 'user',
'content': msg
})
if checked_rounds > max_rounds:
break
if int((len(result_temp.encode('utf-8')) - len(result_temp)) / 2 + len(result_temp)) > max_tokens:
break
result = result_temp
logging.debug('cut_out: {}'.format(result))
return result
logging.debug('cut_out: {}'.format(msg))
# Persist the session
def persistence(self):
@@ -247,11 +234,11 @@ class Session:
subject_number = int(name_spt[1])
db_inst.persistence_session(subject_type, subject_number, self.create_timestamp, self.last_interact_timestamp,
self.prompt)
json.dumps(self.prompt))
# Reset the session
def reset(self, explicit: bool = False, expired: bool = False, schedule_new: bool = True, use_prompt: str = None):
if not self.prompt.endswith(':好的\n'):
if self.prompt[-1]['role'] != "system":
self.persistence()
if explicit:
# Trigger the plugin event
@@ -291,7 +278,7 @@ class Session:
self.create_timestamp = last_one['create_timestamp']
self.last_interact_timestamp = last_one['last_interact_timestamp']
self.prompt = last_one['prompt']
self.prompt = json.loads(last_one['prompt'])
self.just_switched_to_exist_session = True
return self
@@ -306,14 +293,13 @@ class Session:
self.create_timestamp = next_one['create_timestamp']
self.last_interact_timestamp = next_one['last_interact_timestamp']
self.prompt = next_one['prompt']
self.prompt = json.loads(next_one['prompt'])
self.just_switched_to_exist_session = True
return self
def list_history(self, capacity: int = 10, page: int = 0):
return pkg.utils.context.get_database_manager().list_history(self.name, capacity, page,
self.get_default_prompt())
return pkg.utils.context.get_database_manager().list_history(self.name, capacity, page)
def draw_image(self, prompt: str):
return pkg.utils.context.get_openai_manager().request_image(prompt)
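Taken together, Session.prompt is now a list of role/content dicts rather than a concatenated string. After one exchange with a chat model it would look roughly like this (contents are illustrative):

session.prompt == [
    {'role': 'system', 'content': 'default prompt'},   # from get_default_prompt()
    {'role': 'user', 'content': 'hello'},              # appended by cut_out()
    {'role': 'assistant', 'content': 'Hi there!'},     # appended after the API reply
]
# reset() now checks the last message's role ('system' means an untouched session)
# instead of testing for a string suffix.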

View File

@@ -185,11 +185,7 @@ def process_command(session_name: str, text_message: str, mgr, config,
else:
datetime_str = datetime.datetime.fromtimestamp(result.create_timestamp).strftime(
'%Y-%m-%d %H:%M:%S')
reply = ["[bot]已切换到前一次的对话:\n创建时间:{}\n".format(
datetime_str) + result.prompt[
:min(100,
len(result.prompt))] + \
("..." if len(result.prompt) > 100 else "#END#")]
reply = ["[bot]已切换到前一次的对话:\n创建时间:{}\n".format(datetime_str)]
elif cmd == 'next':
result = pkg.openai.session.get_session(session_name).next_session()
if result is None:
@@ -197,13 +193,18 @@ def process_command(session_name: str, text_message: str, mgr, config,
else:
datetime_str = datetime.datetime.fromtimestamp(result.create_timestamp).strftime(
'%Y-%m-%d %H:%M:%S')
reply = ["[bot]已切换到后一次的对话:\n创建时间:{}\n".format(
datetime_str) + result.prompt[
:min(100,
len(result.prompt))] + \
("..." if len(result.prompt) > 100 else "#END#")]
reply = ["[bot]已切换到后一次的对话:\n创建时间:{}\n".format(datetime_str)]
elif cmd == 'prompt':
reply = ["[bot]当前对话所有内容:\n" + pkg.openai.session.get_session(session_name).prompt]
msgs = ""
session:list = pkg.openai.session.get_session(session_name).prompt
for msg in session:
if len(params) != 0 and params[0] in ['-all', '-a']:
msgs = msgs + "{}: {}\n\n".format(msg['role'], msg['content'])
elif len(msg['content']) > 30:
msgs = msgs + "[{}]: {}...\n\n".format(msg['role'], msg['content'][:30])
else:
msgs = msgs + "[{}]: {}\n\n".format(msg['role'], msg['content'])
reply = ["[bot]当前对话所有内容:\n{}".format(msgs)]
elif cmd == 'list':
pkg.openai.session.get_session(session_name).persistence()
page = 0
@@ -225,8 +226,7 @@ def process_command(session_name: str, text_message: str, mgr, config,
datetime_obj = datetime.datetime.fromtimestamp(results[i]['create_timestamp'])
reply_str += "#{} 创建:{} {}\n".format(i + page * 10,
datetime_obj.strftime("%Y-%m-%d %H:%M:%S"),
results[i]['prompt'][
:min(20, len(results[i]['prompt']))])
results[i]['prompt'][1]['content'])
if results[i]['create_timestamp'] == pkg.openai.session.get_session(
session_name).create_timestamp:
current = i + page * 10
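Because each stored prompt is now a message list, the !list preview shows the content of the message at index 1 (the first user message after the system prompt) rather than slicing a prefix of a string. Roughly (a sketch, assuming every persisted session has at least one user message):

record = results[i]                        # one row returned by list_history()
preview = record['prompt'][1]['content']   # [0] is the system prompt, [1] the first user message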

View File

@@ -1,5 +1,5 @@
requests~=2.28.1
openai~=0.26.5
openai~=0.27.0
pip~=22.3.1
dulwich~=0.21.3
colorlog~=6.6.0