feat: 完善session维护代码

This commit is contained in:
Rock Chin 2022-12-08 00:41:35 +08:00
parent da3b964a8c
commit 9ca641a2e5
3 changed files with 132 additions and 9 deletions

23
main.py
View File

@ -1,5 +1,17 @@
import os
import shutil
import sys
import pkg.openai.manager
import pkg.database.manager
import pkg.openai.session
def init_db():
import config
database = pkg.database.manager.DatabaseManager(**config.mysql_config)
database.initialize_database()
def main():
@ -12,8 +24,17 @@ def main():
assert os.path.exists('config.py')
import config
# print(config.mirai_http_api_config)
# 主启动流程
openai_interact = pkg.openai.manager.OpenAIInteract(config.openai_config['api_key'], config.completion_api_params)
database = pkg.database.manager.DatabaseManager(**config.mysql_config)
# 加载所有未超时的session
pkg.openai.session.load_sessions()
if __name__ == '__main__':
if len(sys.argv) > 1 and sys.argv[1] == 'init_db':
init_db()
sys.exit(0)
main()

View File

@ -1,5 +1,9 @@
import time
import pymysql
import config
inst = None
@ -26,9 +30,70 @@ class DatabaseManager:
def reconnect(self):
self.conn = pymysql.connect(host=self.host, port=self.port, user=self.user, password=self.password,
database=self.database)
database=self.database, autocommit=True)
self.cursor = self.conn.cursor()
def initialize_database(self):
self.cursor.execute("""
create table if not exists `sessions` (
`id` bigint not null auto_increment primary key,
`name` varchar(255) not null,
`type` varchar(255) not null,
`number` bigint not null,
`create_timestamp` bigint not null,
`last_interact_timestamp` bigint not null,
`prompt` text not null
)
""")
print('Database initialized.')
def persistence_session(self, subject_type: str, subject_number: int, create_timestamp: int,
last_interact_timestamp: int, prompt: str):
# 检查是否已经有了此name和create_timestamp的session
# 如果有就更新prompt和last_interact_timestamp
# 如果没有,就插入一条新的记录
self.cursor.execute("""
select count(*) from `sessions` where `type` = '{}' and `number` = {} and `create_timestamp` = {}
""".format(subject_type, subject_number, create_timestamp))
count = self.cursor.fetchone()[0]
if count == 0:
self.cursor.execute("""
insert into `sessions` (`name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`)
values ('{}', '{}', {}, {}, {}, '{}')
""".format("{}_{}".format(subject_type, subject_number), subject_type, subject_number, create_timestamp,
last_interact_timestamp, prompt))
else:
self.cursor.execute("""
update `sessions` set `last_interact_timestamp` = {}, `prompt` = '{}'
where `type` = '{}' and `number` = {} and `create_timestamp` = {}
""".format(last_interact_timestamp, prompt, subject_type, subject_number, create_timestamp))
# 记载还没过期的session数据
def load_valid_sessions(self) -> dict:
# 从数据库中加载所有还没过期的session
self.cursor.execute("""
select `name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`
from `sessions` where `last_interact_timestamp` > {}
""".format(int(time.time()) - config.session_expire_time))
results = self.cursor.fetchall()
sessions = {}
for result in results:
session_name = result[0]
subject_type = result[1]
subject_number = result[2]
create_timestamp = result[3]
last_interact_timestamp = result[4]
prompt = result[5]
sessions[session_name] = {
'subject_type': subject_type,
'subject_number': subject_number,
'create_timestamp': create_timestamp,
'last_interact_timestamp': last_interact_timestamp,
'prompt': prompt
}
return sessions
def get_inst() -> DatabaseManager:
global inst

View File

@ -1,9 +1,41 @@
import time
import pkg.openai.manager
import pkg.database.manager
sessions = {}
session = {}
def load_sessions():
global sessions
db_inst = pkg.database.manager.get_inst()
session_data = db_inst.load_valid_sessions()
for session_name in session_data:
temp_session = Session(session_name)
temp_session.name = session_name
temp_session.create_timestamp = session_data[session_name]['create_timestamp']
temp_session.last_interact_timestamp = session_data[session_name]['last_interact_timestamp']
temp_session.prompt = session_data[session_name]['prompt']
sessions[session_name] = temp_session
def get_session(session_name: str):
global sessions
if session_name not in sessions:
sessions[session_name] = Session(session_name)
return sessions[session_name]
def dump_session(session_name: str):
global sessions
if session_name in sessions:
assert isinstance(sessions[session_name], Session)
sessions[session_name].persistence()
del sessions[session_name]
# 通用的OpenAI API交互session
@ -23,17 +55,14 @@ class Session:
self.name = name
self.create_timestamp = int(time.time())
global session
session[name] = self
# 请求回复
# 这个函数是阻塞的
def append(self, text: str) -> str:
self.prompt += self.user_name + ':' + text + '\n'+self.bot_name+':'
self.prompt += self.user_name + ':' + text + '\n' + self.bot_name + ':'
self.last_interact_timestamp = int(time.time())
# 向API请求补全
response = pkg.openai.manager.get_inst().request_completion(self.prompt, self.user_name+':')
response = pkg.openai.manager.get_inst().request_completion(self.prompt, self.user_name + ':')
# print(response)
# 处理回复
@ -50,4 +79,12 @@ class Session:
return res_ans
def persistence(self):
pass
db_inst = pkg.database.manager.get_inst()
name_spt = self.name.split('_')
subject_type = name_spt[0]
subject_number = int(name_spt[1])
db_inst.persistence_session(subject_type, subject_number, self.create_timestamp, self.last_interact_timestamp,
self.prompt)