mirror of
https://github.com/RockChinQ/QChatGPT.git
synced 2024-11-16 03:32:33 +08:00
feat: 完善session维护代码
This commit is contained in:
parent
da3b964a8c
commit
9ca641a2e5
23
main.py
23
main.py
|
@ -1,5 +1,17 @@
|
|||
import os
|
||||
import shutil
|
||||
import sys
|
||||
|
||||
import pkg.openai.manager
|
||||
import pkg.database.manager
|
||||
import pkg.openai.session
|
||||
|
||||
|
||||
def init_db():
|
||||
import config
|
||||
database = pkg.database.manager.DatabaseManager(**config.mysql_config)
|
||||
|
||||
database.initialize_database()
|
||||
|
||||
|
||||
def main():
|
||||
|
@ -12,8 +24,17 @@ def main():
|
|||
assert os.path.exists('config.py')
|
||||
import config
|
||||
|
||||
# print(config.mirai_http_api_config)
|
||||
# 主启动流程
|
||||
openai_interact = pkg.openai.manager.OpenAIInteract(config.openai_config['api_key'], config.completion_api_params)
|
||||
|
||||
database = pkg.database.manager.DatabaseManager(**config.mysql_config)
|
||||
|
||||
# 加载所有未超时的session
|
||||
pkg.openai.session.load_sessions()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if len(sys.argv) > 1 and sys.argv[1] == 'init_db':
|
||||
init_db()
|
||||
sys.exit(0)
|
||||
main()
|
||||
|
|
|
@ -1,5 +1,9 @@
|
|||
import time
|
||||
|
||||
import pymysql
|
||||
|
||||
import config
|
||||
|
||||
inst = None
|
||||
|
||||
|
||||
|
@ -26,9 +30,70 @@ class DatabaseManager:
|
|||
|
||||
def reconnect(self):
|
||||
self.conn = pymysql.connect(host=self.host, port=self.port, user=self.user, password=self.password,
|
||||
database=self.database)
|
||||
database=self.database, autocommit=True)
|
||||
self.cursor = self.conn.cursor()
|
||||
|
||||
def initialize_database(self):
|
||||
self.cursor.execute("""
|
||||
create table if not exists `sessions` (
|
||||
`id` bigint not null auto_increment primary key,
|
||||
`name` varchar(255) not null,
|
||||
`type` varchar(255) not null,
|
||||
`number` bigint not null,
|
||||
`create_timestamp` bigint not null,
|
||||
`last_interact_timestamp` bigint not null,
|
||||
`prompt` text not null
|
||||
)
|
||||
""")
|
||||
print('Database initialized.')
|
||||
|
||||
def persistence_session(self, subject_type: str, subject_number: int, create_timestamp: int,
|
||||
last_interact_timestamp: int, prompt: str):
|
||||
# 检查是否已经有了此name和create_timestamp的session
|
||||
# 如果有,就更新prompt和last_interact_timestamp
|
||||
# 如果没有,就插入一条新的记录
|
||||
self.cursor.execute("""
|
||||
select count(*) from `sessions` where `type` = '{}' and `number` = {} and `create_timestamp` = {}
|
||||
""".format(subject_type, subject_number, create_timestamp))
|
||||
count = self.cursor.fetchone()[0]
|
||||
if count == 0:
|
||||
self.cursor.execute("""
|
||||
insert into `sessions` (`name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`)
|
||||
values ('{}', '{}', {}, {}, {}, '{}')
|
||||
""".format("{}_{}".format(subject_type, subject_number), subject_type, subject_number, create_timestamp,
|
||||
last_interact_timestamp, prompt))
|
||||
else:
|
||||
self.cursor.execute("""
|
||||
update `sessions` set `last_interact_timestamp` = {}, `prompt` = '{}'
|
||||
where `type` = '{}' and `number` = {} and `create_timestamp` = {}
|
||||
""".format(last_interact_timestamp, prompt, subject_type, subject_number, create_timestamp))
|
||||
|
||||
# 记载还没过期的session数据
|
||||
def load_valid_sessions(self) -> dict:
|
||||
# 从数据库中加载所有还没过期的session
|
||||
self.cursor.execute("""
|
||||
select `name`, `type`, `number`, `create_timestamp`, `last_interact_timestamp`, `prompt`
|
||||
from `sessions` where `last_interact_timestamp` > {}
|
||||
""".format(int(time.time()) - config.session_expire_time))
|
||||
results = self.cursor.fetchall()
|
||||
sessions = {}
|
||||
for result in results:
|
||||
session_name = result[0]
|
||||
subject_type = result[1]
|
||||
subject_number = result[2]
|
||||
create_timestamp = result[3]
|
||||
last_interact_timestamp = result[4]
|
||||
prompt = result[5]
|
||||
|
||||
sessions[session_name] = {
|
||||
'subject_type': subject_type,
|
||||
'subject_number': subject_number,
|
||||
'create_timestamp': create_timestamp,
|
||||
'last_interact_timestamp': last_interact_timestamp,
|
||||
'prompt': prompt
|
||||
}
|
||||
return sessions
|
||||
|
||||
|
||||
def get_inst() -> DatabaseManager:
|
||||
global inst
|
||||
|
|
|
@ -1,9 +1,41 @@
|
|||
import time
|
||||
|
||||
import pkg.openai.manager
|
||||
import pkg.database.manager
|
||||
|
||||
sessions = {}
|
||||
|
||||
|
||||
session = {}
|
||||
def load_sessions():
|
||||
global sessions
|
||||
|
||||
db_inst = pkg.database.manager.get_inst()
|
||||
|
||||
session_data = db_inst.load_valid_sessions()
|
||||
|
||||
for session_name in session_data:
|
||||
temp_session = Session(session_name)
|
||||
temp_session.name = session_name
|
||||
temp_session.create_timestamp = session_data[session_name]['create_timestamp']
|
||||
temp_session.last_interact_timestamp = session_data[session_name]['last_interact_timestamp']
|
||||
temp_session.prompt = session_data[session_name]['prompt']
|
||||
|
||||
sessions[session_name] = temp_session
|
||||
|
||||
|
||||
def get_session(session_name: str):
|
||||
global sessions
|
||||
if session_name not in sessions:
|
||||
sessions[session_name] = Session(session_name)
|
||||
return sessions[session_name]
|
||||
|
||||
|
||||
def dump_session(session_name: str):
|
||||
global sessions
|
||||
if session_name in sessions:
|
||||
assert isinstance(sessions[session_name], Session)
|
||||
sessions[session_name].persistence()
|
||||
del sessions[session_name]
|
||||
|
||||
|
||||
# 通用的OpenAI API交互session
|
||||
|
@ -23,17 +55,14 @@ class Session:
|
|||
self.name = name
|
||||
self.create_timestamp = int(time.time())
|
||||
|
||||
global session
|
||||
session[name] = self
|
||||
|
||||
# 请求回复
|
||||
# 这个函数是阻塞的
|
||||
def append(self, text: str) -> str:
|
||||
self.prompt += self.user_name + ':' + text + '\n'+self.bot_name+':'
|
||||
self.prompt += self.user_name + ':' + text + '\n' + self.bot_name + ':'
|
||||
self.last_interact_timestamp = int(time.time())
|
||||
|
||||
# 向API请求补全
|
||||
response = pkg.openai.manager.get_inst().request_completion(self.prompt, self.user_name+':')
|
||||
response = pkg.openai.manager.get_inst().request_completion(self.prompt, self.user_name + ':')
|
||||
|
||||
# print(response)
|
||||
# 处理回复
|
||||
|
@ -50,4 +79,12 @@ class Session:
|
|||
return res_ans
|
||||
|
||||
def persistence(self):
|
||||
pass
|
||||
db_inst = pkg.database.manager.get_inst()
|
||||
|
||||
name_spt = self.name.split('_')
|
||||
|
||||
subject_type = name_spt[0]
|
||||
subject_number = int(name_spt[1])
|
||||
|
||||
db_inst.persistence_session(subject_type, subject_number, self.create_timestamp, self.last_interact_timestamp,
|
||||
self.prompt)
|
||||
|
|
Loading…
Reference in New Issue
Block a user