2022-12-09 00:01:49 +08:00
|
|
|
|
import logging
|
2022-12-09 16:17:50 +08:00
|
|
|
|
import threading
|
2022-12-07 22:27:05 +08:00
|
|
|
|
import time
|
2023-03-02 15:31:12 +08:00
|
|
|
|
import json
|
2022-12-07 22:27:05 +08:00
|
|
|
|
|
|
|
|
|
import pkg.openai.manager
|
2023-03-02 15:31:12 +08:00
|
|
|
|
import pkg.openai.modelmgr
|
2022-12-08 00:41:35 +08:00
|
|
|
|
import pkg.database.manager
|
2023-01-01 23:18:32 +08:00
|
|
|
|
import pkg.utils.context
|
2022-12-07 22:27:05 +08:00
|
|
|
|
|
2023-01-14 22:36:48 +08:00
|
|
|
|
import pkg.plugin.host as plugin_host
|
|
|
|
|
import pkg.plugin.models as plugin_models
|
|
|
|
|
|
2022-12-11 16:10:12 +08:00
|
|
|
|
# Registry of every session held in memory at runtime, keyed by session name.
sessions = dict()
|
2022-12-07 22:27:05 +08:00
|
|
|
|
|
2022-12-08 00:41:35 +08:00
|
|
|
|
|
2022-12-08 13:22:54 +08:00
|
|
|
|
class SessionOfflineStatus:
    """String status constants for sessions (on-going vs. explicitly closed)."""

    ON_GOING, EXPLICITLY_CLOSED = 'on_going', 'explicitly_closed'
|
|
|
|
|
|
|
|
|
|
|
2022-12-11 16:10:12 +08:00
|
|
|
|
# Load the sessions stored in the database into memory.
def load_sessions():
    """Restore every valid session from the database into the ``sessions`` map.

    For each stored record a new ``Session`` is created. Its creation and
    last-interaction timestamps and its JSON-encoded prompt are then
    overwritten with the stored values. Finally its expiry timer is re-armed.
    """
    global sessions

    db_inst = pkg.utils.context.get_database_manager()

    session_data = db_inst.load_valid_sessions()

    for session_name in session_data:
        logging.info('加载session: {}'.format(session_name))

        record = session_data[session_name]

        # Session(name) already sets .name.
        temp_session = Session(session_name)
        temp_session.create_timestamp = record['create_timestamp']
        temp_session.last_interact_timestamp = record['last_interact_timestamp']
        temp_session.prompt = json.loads(record['prompt'])

        sessions[session_name] = temp_session

        # The timer armed in Session.__init__ is bound to the creation time
        # generated there. Overwriting create_timestamp above makes that timer
        # exit on its first tick, so arm a new one for the restored timestamp.
        # Without this, restored sessions would never expire.
        temp_session.schedule()
|
|
|
|
|
|
|
|
|
|
|
2022-12-11 16:10:12 +08:00
|
|
|
|
# Get the session with the given name, creating a new one if it does not exist.
def get_session(session_name: str):
    """Return the registered session named ``session_name``.

    If no such session is registered, a new ``Session`` is created, stored in
    ``sessions`` under that name, and returned.
    """
    global sessions

    if session_name in sessions:
        return sessions[session_name]

    created = Session(session_name)
    sessions[session_name] = created
    return created
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def dump_session(session_name: str):
    """Persist the named session and remove it from ``sessions``.

    Does nothing when no session with that name is registered.
    """
    global sessions

    if session_name not in sessions:
        return

    target = sessions[session_name]
    assert isinstance(target, Session)
    target.persistence()
    del sessions[session_name]
|
2022-12-07 22:27:05 +08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# General-purpose OpenAI API conversation session.
# A session keeps the conversation context; when a user message arrives,
# the context is submitted to the OpenAI API to generate a reply.
class Session:
    """Conversation context for one session name, with persistence and expiry."""

    name = ''

    # Conversation context: a list of {'role': ..., 'content': ...} messages.
    # Class-level default only; __init__ assigns a per-instance list.
    prompt = []

    # Importing inside the class body binds the module as the class attribute
    # ``Session.config``; kept unchanged for compatibility.
    import config

    create_timestamp = 0

    last_interact_timestamp = 0

    # Set by last_session()/next_session(); the next append() clears it and
    # marks the session on_going in the database.
    just_switched_to_exist_session = False

    response_lock = None

    # Acquire the response lock.
    def acquire_response_lock(self):
        """Block until this session's response lock is acquired."""
        logging.debug('{},lock acquire,{}'.format(self.name, self.response_lock))
        self.response_lock.acquire()
        logging.debug('{},lock acquire successfully,{}'.format(self.name, self.response_lock))

    # Release the response lock.
    def release_response_lock(self):
        """Release the response lock if it is currently locked; otherwise do nothing."""
        if self.response_lock.locked():
            logging.debug('{},lock release,{}'.format(self.name, self.response_lock))
            self.response_lock.release()
            logging.debug('{},lock release successfully,{}'.format(self.name, self.response_lock))

    # Build the initial context from the configured preset.
    def get_default_prompt(self, use_default: str = None):
        """Return a fresh context consisting of one preset ``system`` message.

        :param use_default: preset name; ``None`` selects the current preset
            (``dprompt.get_current()``).
        :return: ``[{'role': 'system', 'content': <preset text>}]``
        """
        import pkg.openai.dprompt as dprompt

        preset_name = dprompt.get_current() if use_default is None else use_default

        return [{
            'role': 'system',
            'content': dprompt.get_prompt(preset_name)
        }]

    def __init__(self, name: str):
        self.name = name
        self.create_timestamp = int(time.time())
        self.last_interact_timestamp = int(time.time())
        self.schedule()

        self.response_lock = threading.Lock()
        self.prompt = self.get_default_prompt()

    def _is_fresh(self) -> bool:
        """Return True if the context holds no conversation yet (only system messages).

        This does not depend on which preset is currently active. So a session
        reset with a non-default preset (``reset(use_prompt=...)``) still counts
        as fresh until its first exchange.
        """
        return all(message['role'] == 'system' for message in self.prompt)

    # Start the timer that checks whether the last interaction exceeded the expiry time.
    def schedule(self):
        """Start a thread running ``expire_check_timer_loop`` for the current ``create_timestamp``."""
        threading.Thread(target=self.expire_check_timer_loop, args=(self.create_timestamp,)).start()

    # Check whether the session has expired.
    def expire_check_timer_loop(self, create_timestamp: int):
        """Every 60 seconds, expire this session if idle longer than ``config.session_expire_time``.

        The loop returns in any of these cases:
        - the session's ``create_timestamp`` no longer matches (reset/switched);
        - the session is no longer in ``sessions``;
        - a ``SessionExpired`` plugin handler prevented the default;
        - the session has been expired and removed.
        """
        global sessions

        while True:
            time.sleep(60)

            # The session was replaced (reset/switched) or dropped: this timer is stale.
            if self.create_timestamp != create_timestamp or self not in sessions.values():
                return

            config = pkg.utils.context.get_config()
            if int(time.time()) - self.last_interact_timestamp > config.session_expire_time:
                logging.info('session {} 已过期'.format(self.name))

                # Fire the plugin event; a handler may prevent the expiry.
                args = {
                    'session_name': self.name,
                    'session': self,
                    'session_expire_time': config.session_expire_time
                }
                event = pkg.plugin.host.emit(plugin_models.SessionExpired, **args)
                if event.is_prevented_default():
                    return

                self.reset(expired=True, schedule_new=False)

                # Remove this session. pop() tolerates it having been removed
                # concurrently (e.g. by dump_session) after the check above.
                sessions.pop(self.name, None)
                return

    @staticmethod
    def _strip_reply_prefix(message: str, model: str) -> str:
        """Drop a possible leading segment (text before the first blank line) from a reply.

        For non-chat (completion) models the first segment is always dropped
        when there is more than one; this is the legacy behaviour.

        For chat models only a *blank* leading segment (e.g. a reply starting
        with "\\n\\n") is dropped. Dropping unconditionally would delete the
        first paragraph of any multi-paragraph answer.
        """
        segments = message.split("\n\n")
        if len(segments) > 1 and (
                model not in pkg.openai.modelmgr.CHAT_COMPLETION_MODELS or not segments[0].strip()):
            return '\n\n'.join(segments[1:])
        return message

    # Request a reply.
    # This call blocks.
    def append(self, text: str) -> str:
        """Add ``text`` as a user message, request a completion, and record the reply.

        :param text: user message, truncated to ``config.prompt_submit_length``
            characters (1024 when that setting is absent).
        :return: the reply with one leading newline removed, or ``None`` if a
            ``SessionFirstMessageReceived`` plugin handler prevented the default.
        """
        self.last_interact_timestamp = int(time.time())

        # Fire the plugin event on the first message of a session.
        if self._is_fresh():
            args = {
                'session_name': self.name,
                'session': self,
                'default_prompt': self.prompt,
            }
            event = pkg.plugin.host.emit(plugin_models.SessionFirstMessageReceived, **args)
            if event.is_prevented_default():
                return None

        config = pkg.utils.context.get_config()
        max_length = getattr(config, "prompt_submit_length", 1024)

        # Add the user message, then request a completion for the whole context.
        self.cut_out(text, max_length)
        message = pkg.utils.context.get_openai_manager().request_completion(
            self.prompt
        )

        model = config.completion_api_params['model']
        reply = self._strip_reply_prefix(message, model)

        # Record the reply in the context; the role depends on the model family.
        if model in pkg.openai.modelmgr.CHAT_COMPLETION_MODELS:
            self.prompt.append({'role': 'assistant', 'content': reply})
        elif model in pkg.openai.modelmgr.COMPLETION_MODELS:
            self.prompt.append({'role': '', 'content': reply})

        if self.just_switched_to_exist_session:
            self.just_switched_to_exist_session = False
            self.set_ongoing()

        # startswith() also handles an empty reply, where indexing [0] would raise.
        return reply[1:] if reply.startswith('\n') else reply

    # Delete the previous round and return that round's question.
    def undo(self) -> str:
        """Remove the most recent round: the last user message and everything after it.

        :return: the content of that user message.
        :raises IndexError: if the context contains no user message.
        """
        self.last_interact_timestamp = int(time.time())

        for index in reversed(range(len(self.prompt))):
            if self.prompt[index]['role'] == 'user':
                question = self.prompt[index]['content']
                # Truncate by position: removes the question and any reply after it.
                del self.prompt[index:]
                return question

        raise IndexError('no previous conversation round to undo')

    # Build the message body.
    def cut_out(self, msg: str, max_tokens: int) -> str:
        """Append ``msg`` to the context as a ``user`` message, truncated to ``max_tokens``.

        Note: truncation counts characters, not model tokens.

        :return: the (possibly truncated) text that was appended.
        """
        msg = msg[:max_tokens]

        self.prompt.append({
            'role': 'user',
            'content': msg
        })

        logging.debug('cut_out: {}'.format(msg))
        return msg

    # Persist the session.
    def persistence(self):
        """Save this session to the database; skipped while it holds no conversation."""
        if self._is_fresh():
            return

        db_inst = pkg.utils.context.get_database_manager()

        # The name is parsed as "<subject_type>_<subject_number>" (second part must be an int).
        name_spt = self.name.split('_')
        subject_type = name_spt[0]
        subject_number = int(name_spt[1])

        db_inst.persistence_session(subject_type, subject_number, self.create_timestamp, self.last_interact_timestamp,
                                    json.dumps(self.prompt))

    # Reset the session.
    def reset(self, explicit: bool = False, expired: bool = False, schedule_new: bool = True, use_prompt: str = None):
        """Save the current conversation (if any) and start a fresh context.

        :param explicit: user-initiated reset. Fires ``SessionExplicitReset``
            and calls the database's ``explicit_close_session``.
        :param expired: reset caused by expiry. Calls the database's
            ``set_session_expired``.
        :param schedule_new: start a new expiry timer for the fresh context.
        :param use_prompt: preset name for the new context; ``None`` uses the current preset.
        """
        # persistence() itself skips sessions without any conversation.
        self.persistence()
        if explicit:
            # Fire the plugin event.
            args = {
                'session_name': self.name,
                'session': self
            }

            # This event does not support preventing the default behaviour.
            _ = pkg.plugin.host.emit(plugin_models.SessionExplicitReset, **args)

            pkg.utils.context.get_database_manager().explicit_close_session(self.name, self.create_timestamp)

        if expired:
            pkg.utils.context.get_database_manager().set_session_expired(self.name, self.create_timestamp)
        self.prompt = self.get_default_prompt(use_prompt)
        self.create_timestamp = int(time.time())
        self.last_interact_timestamp = int(time.time())
        self.just_switched_to_exist_session = False

        if schedule_new:
            self.schedule()

    # Set this session's database status to on_going.
    def set_ongoing(self):
        """Mark the stored record for this session/create_timestamp as on-going."""
        pkg.utils.context.get_database_manager().set_session_ongoing(self.name, self.create_timestamp)

    def _switch_to(self, record):
        """Save the current conversation and replace it with the stored ``record``.

        :param record: mapping with ``create_timestamp``, ``last_interact_timestamp``
            and a JSON-encoded ``prompt``.
        :return: ``self``
        """
        self.persistence()

        self.create_timestamp = record['create_timestamp']
        self.last_interact_timestamp = record['last_interact_timestamp']
        self.prompt = json.loads(record['prompt'])

        self.just_switched_to_exist_session = True

        # The running expiry timer is bound to the old create_timestamp and stops
        # on its next tick; arm a new one so the switched-to session can expire.
        self.schedule()
        return self

    # Switch to the previous session.
    def last_session(self):
        """Switch to the stored session before this one; return ``self``, or ``None`` if none exists."""
        last_one = pkg.utils.context.get_database_manager().last_session(self.name, self.last_interact_timestamp)
        if last_one is None:
            return None
        return self._switch_to(last_one)

    # Switch to the next session.
    def next_session(self):
        """Switch to the stored session after this one; return ``self``, or ``None`` if none exists."""
        next_one = pkg.utils.context.get_database_manager().next_session(self.name, self.last_interact_timestamp)
        if next_one is None:
            return None
        return self._switch_to(next_one)

    def list_history(self, capacity: int = 10, page: int = 0):
        """Return one page (``capacity`` entries, page ``page``) of this session name's stored history."""
        return pkg.utils.context.get_database_manager().list_history(self.name, capacity, page)

    def draw_image(self, prompt: str):
        """Request an image for ``prompt`` from the OpenAI manager and return its result."""
        return pkg.utils.context.get_openai_manager().request_image(prompt)
|