添加线程控制类,修改main结构,修改启动流程

This commit is contained in:
LINSTCL 2023-03-08 15:21:37 +08:00
parent 2933d4843f
commit 77076f3bdd
7 changed files with 231 additions and 70 deletions

View File

@ -208,7 +208,9 @@ alter_tip_message = '出错了,请稍后再试'
# 机器人线程池大小 # 机器人线程池大小
# 该参数决定机器人可以同时处理几个人的消息,超出线程池数量的请求会被阻塞,不会被丢弃 # 该参数决定机器人可以同时处理几个人的消息,超出线程池数量的请求会被阻塞,不会被丢弃
# 如果你不清楚该参数的意义,请不要更改 # 如果你不清楚该参数的意义,请不要更改
pool_num = 10 sys_pool_num = 8
admin_pool_num = 2
user_pool_num = 3
# 每个会话的过期时间,单位为秒 # 每个会话的过期时间,单位为秒
# 默认值20分钟 # 默认值20分钟

104
main.py
View File

@ -23,7 +23,7 @@ import colorlog
import requests import requests
import websockets.exceptions import websockets.exceptions
from urllib3.exceptions import InsecureRequestWarning from urllib3.exceptions import InsecureRequestWarning
import pkg.utils.context
sys.path.append(".") sys.path.append(".")
@ -74,11 +74,8 @@ def init_runtime_log_file():
def reset_logging(): def reset_logging():
global log_file_name global log_file_name
assert os.path.exists('config.py')
config = importlib.import_module('config') config = pkg.utils.context.get_config()
import pkg.utils.context
if pkg.utils.context.context['logger_handler'] is not None: if pkg.utils.context.context['logger_handler'] is not None:
logging.getLogger().removeHandler(pkg.utils.context.context['logger_handler']) logging.getLogger().removeHandler(pkg.utils.context.context['logger_handler'])
@ -106,12 +103,13 @@ def reset_logging():
return sh return sh
def main(first_time_init=False): def start(first_time_init=False):
"""启动流程reload之后会被执行""" """启动流程reload之后会被执行"""
global known_exception_caught global known_exception_caught
import pkg.utils.context
import config config = pkg.utils.context.get_config()
# 更新openai库到最新版本 # 更新openai库到最新版本
if not hasattr(config, 'upgrade_dependencies') or config.upgrade_dependencies: if not hasattr(config, 'upgrade_dependencies') or config.upgrade_dependencies:
print("正在更新依赖库,请等待...") print("正在更新依赖库,请等待...")
@ -126,31 +124,10 @@ def main(first_time_init=False):
known_exception_caught = False known_exception_caught = False
try: try:
# 导入config.py
assert os.path.exists('config.py')
config = importlib.import_module('config')
init_runtime_log_file() init_runtime_log_file()
sh = reset_logging() sh = reset_logging()
# 配置完整性校验
is_integrity = True
config_template = importlib.import_module('config-template')
for key in dir(config_template):
if not key.startswith("__") and not hasattr(config, key):
setattr(config, key, getattr(config_template, key))
logging.warning("[{}]不存在".format(key))
is_integrity = False
if not is_integrity:
logging.warning("配置文件不完整请依据config-template.py检查config.py")
logging.warning("以上配置已被设为默认值将在5秒后继续启动... ")
time.sleep(5)
import pkg.utils.context
pkg.utils.context.set_config(config)
# 检查是否设置了管理员 # 检查是否设置了管理员
if not (hasattr(config, 'admin_qq') and config.admin_qq != 0): if not (hasattr(config, 'admin_qq') and config.admin_qq != 0):
# logging.warning("未设置管理员QQ,管理员权限指令及运行告警将无法使用,如需设置请修改config.py中的admin_qq字段") # logging.warning("未设置管理员QQ,管理员权限指令及运行告警将无法使用,如需设置请修改config.py中的admin_qq字段")
@ -197,7 +174,7 @@ def main(first_time_init=False):
# 初始化qq机器人 # 初始化qq机器人
qqbot = pkg.qqbot.manager.QQBotManager(mirai_http_api_config=config.mirai_http_api_config, qqbot = pkg.qqbot.manager.QQBotManager(mirai_http_api_config=config.mirai_http_api_config,
timeout=config.process_message_timeout, retry=config.retry_times, timeout=config.process_message_timeout, retry=config.retry_times,
first_time_init=first_time_init, pool_num=config.pool_num) first_time_init=first_time_init)
# 加载插件 # 加载插件
import pkg.plugin.host import pkg.plugin.host
@ -252,10 +229,10 @@ def main(first_time_init=False):
known_exception_caught = True known_exception_caught = True
raise e raise e
qq_bot_thread = threading.Thread(target=run_bot_wrapper, args=(), daemon=True) pkg.utils.context.get_thread_ctl().submit_sys_task(
qq_bot_thread.start() run_bot_wrapper
)
finally: finally:
time.sleep(12)
if first_time_init: if first_time_init:
if not known_exception_caught: if not known_exception_caught:
logging.info('程序启动完成,如长时间未显示 ”成功登录到账号xxxxx“ ,并且不回复消息,请查看 ' logging.info('程序启动完成,如长时间未显示 ”成功登录到账号xxxxx“ ,并且不回复消息,请查看 '
@ -294,11 +271,8 @@ def main(first_time_init=False):
except Exception as e: except Exception as e:
logging.warning("检查更新失败:{}".format(e)) logging.warning("检查更新失败:{}".format(e))
return qqbot
def stop(): def stop():
import pkg.utils.context
import pkg.qqbot.manager import pkg.qqbot.manager
import pkg.openai.session import pkg.openai.session
try: try:
@ -316,14 +290,30 @@ def stop():
if not isinstance(e, KeyboardInterrupt): if not isinstance(e, KeyboardInterrupt):
raise e raise e
# 临时函数用于加载config和上下文未来统一放在config类
if __name__ == '__main__': def load_config():
# 检查是否有config.py,如果没有就把config-template.py复制一份,并退出程序 #存在性校验
if not os.path.exists('config.py'): if not os.path.exists('config.py'):
shutil.copy('config-template.py', 'config.py') shutil.copy('config-template.py', 'config.py')
print('请先在config.py中填写配置') print('请先在config.py中填写配置')
sys.exit(0) sys.exit(0)
#完整性校验
is_integrity = True
config_template = importlib.import_module('config-template')
config = importlib.import_module('config')
for key in dir(config_template):
if not key.startswith("__") and not hasattr(config, key):
setattr(config, key, getattr(config_template, key))
logging.warning("[{}]不存在".format(key))
is_integrity = False
if not is_integrity:
logging.warning("配置文件不完整请依据config-template.py检查config.py")
logging.warning("以上配置已被设为默认值将在5秒后继续启动... ")
time.sleep(5)
#context配置
pkg.utils.context.set_config(config)
def check_file():
# 检查是否有banlist.py,如果没有就把banlist-template.py复制一份 # 检查是否有banlist.py,如果没有就把banlist-template.py复制一份
if not os.path.exists('banlist.py'): if not os.path.exists('banlist.py'):
shutil.copy('banlist-template.py', 'banlist.py') shutil.copy('banlist-template.py', 'banlist.py')
@ -342,6 +332,24 @@ if __name__ == '__main__':
if not os.path.exists(path): if not os.path.exists(path):
os.mkdir(path) os.mkdir(path)
def main():
# 加载配置
load_config()
config = pkg.utils.context.get_config()
# 初始化相关文件
check_file()
# 配置线程池
from pkg.utils import ThreadCtl
thread_ctl = ThreadCtl(
sys_pool_num = config.sys_pool_num,
admin_pool_num = config.admin_pool_num,
user_pool_num = config.user_pool_num
)
pkg.utils.context.set_thread_ctl(thread_ctl)
# 控制台指令处理
if len(sys.argv) > 1 and sys.argv[1] == 'init_db': if len(sys.argv) > 1 and sys.argv[1] == 'init_db':
init_db() init_db()
sys.exit(0) sys.exit(0)
@ -352,19 +360,27 @@ if __name__ == '__main__':
updater.update_all(cli=True) updater.update_all(cli=True)
sys.exit(0) sys.exit(0)
# 不知道干啥的
# import pkg.utils.configmgr # import pkg.utils.configmgr
# #
# pkg.utils.configmgr.set_config_and_reload("quote_origin", False) # pkg.utils.configmgr.set_config_and_reload("quote_origin", False)
requests.packages.urllib3.disable_warnings(InsecureRequestWarning) requests.packages.urllib3.disable_warnings(InsecureRequestWarning)
qqbot = main(True) pkg.utils.context.get_thread_ctl().submit_sys_task(
start,
True
)
import pkg.utils.context # 主线程循环
while True: while True:
try: try:
time.sleep(10) time.sleep(0xFF)
except KeyboardInterrupt: except:
stop() stop()
pkg.utils.context.get_thread_ctl().shutdown()
print("程序退出") print("退出")
sys.exit(0) sys.exit(0)
if __name__ == '__main__':
main()

View File

@ -2,7 +2,6 @@ import asyncio
import json import json
import os import os
import threading import threading
from concurrent.futures import ThreadPoolExecutor
import mirai.models.bus import mirai.models.bus
from mirai import At, GroupMessage, MessageEvent, Mirai, StrangerMessage, WebSocketAdapter, HTTPAdapter, \ from mirai import At, GroupMessage, MessageEvent, Mirai, StrangerMessage, WebSocketAdapter, HTTPAdapter, \
@ -66,9 +65,6 @@ def random_responding():
class QQBotManager: class QQBotManager:
retry = 3 retry = 3
#线程池控制
pool = None
bot: Mirai = None bot: Mirai = None
reply_filter = None reply_filter = None
@ -78,14 +74,10 @@ class QQBotManager:
ban_person = [] ban_person = []
ban_group = [] ban_group = []
def __init__(self, mirai_http_api_config: dict, timeout: int = 60, retry: int = 3, pool_num: int = 10, first_time_init=True): def __init__(self, mirai_http_api_config: dict, timeout: int = 60, retry: int = 3, first_time_init=True):
self.timeout = timeout self.timeout = timeout
self.retry = retry self.retry = retry
self.pool_num = pool_num
self.pool = ThreadPoolExecutor(max_workers=self.pool_num)
logging.debug("Registered thread pool Size:{}".format(pool_num))
# 加载禁用列表 # 加载禁用列表
if os.path.exists("banlist.py"): if os.path.exists("banlist.py"):
import banlist import banlist
@ -138,7 +130,10 @@ class QQBotManager:
self.on_person_message(event) self.on_person_message(event)
self.go(friend_message_handler, event) pkg.utils.context.get_thread_ctl().submit_user_task(
friend_message_handler,
event
)
@self.bot.on(StrangerMessage) @self.bot.on(StrangerMessage)
async def on_stranger_message(event: StrangerMessage): async def on_stranger_message(event: StrangerMessage):
@ -158,7 +153,10 @@ class QQBotManager:
self.on_person_message(event) self.on_person_message(event)
self.go(stranger_message_handler, event) pkg.utils.context.get_thread_ctl().submit_user_task(
stranger_message_handler,
event
)
@self.bot.on(GroupMessage) @self.bot.on(GroupMessage)
async def on_group_message(event: GroupMessage): async def on_group_message(event: GroupMessage):
@ -178,7 +176,10 @@ class QQBotManager:
self.on_group_message(event) self.on_group_message(event)
self.go(group_message_handler, event) pkg.utils.context.get_thread_ctl().submit_user_task(
group_message_handler,
event
)
def unsubscribe_all(): def unsubscribe_all():
"""取消所有订阅 """取消所有订阅

View File

@ -0,0 +1 @@
from .threadctl import ThreadCtl

View File

@ -1,50 +1,91 @@
import threading
context = { context = {
'inst': { 'inst': {
'database.manager.DatabaseManager': None, 'database.manager.DatabaseManager': None,
'openai.manager.OpenAIInteract': None, 'openai.manager.OpenAIInteract': None,
'qqbot.manager.QQBotManager': None, 'qqbot.manager.QQBotManager': None,
}, },
'pool_ctl': None,
'logger_handler': None, 'logger_handler': None,
'config': None, 'config': None,
'plugin_host': None, 'plugin_host': None,
} }
context_lock = threading.Lock()
### context耦合度非常高需要大改 ###
def set_config(inst): def set_config(inst):
context_lock.acquire()
context['config'] = inst context['config'] = inst
context_lock.release()
def get_config(): def get_config():
return context['config'] context_lock.acquire()
t = context['config']
context_lock.release()
return t
def set_database_manager(inst): def set_database_manager(inst):
context_lock.acquire()
context['inst']['database.manager.DatabaseManager'] = inst context['inst']['database.manager.DatabaseManager'] = inst
context_lock.release()
def get_database_manager(): def get_database_manager():
return context['inst']['database.manager.DatabaseManager'] context_lock.acquire()
t = context['inst']['database.manager.DatabaseManager']
context_lock.release()
return t
def set_openai_manager(inst): def set_openai_manager(inst):
context_lock.acquire()
context['inst']['openai.manager.OpenAIInteract'] = inst context['inst']['openai.manager.OpenAIInteract'] = inst
context_lock.release()
def get_openai_manager(): def get_openai_manager():
return context['inst']['openai.manager.OpenAIInteract'] context_lock.acquire()
t = context['inst']['openai.manager.OpenAIInteract']
context_lock.release()
return t
def set_qqbot_manager(inst): def set_qqbot_manager(inst):
context_lock.acquire()
context['inst']['qqbot.manager.QQBotManager'] = inst context['inst']['qqbot.manager.QQBotManager'] = inst
context_lock.release()
def get_qqbot_manager(): def get_qqbot_manager():
return context['inst']['qqbot.manager.QQBotManager'] context_lock.acquire()
t = context['inst']['qqbot.manager.QQBotManager']
context_lock.release()
return t
def set_plugin_host(inst): def set_plugin_host(inst):
context_lock.acquire()
context['plugin_host'] = inst context['plugin_host'] = inst
context_lock.release()
def get_plugin_host(): def get_plugin_host():
return context['plugin_host'] context_lock.acquire()
t = context['plugin_host']
context_lock.release()
return t
def set_thread_ctl(inst):
context_lock.acquire()
context['pool_ctl'] = inst
context_lock.release()
from pkg.utils import ThreadCtl
def get_thread_ctl() -> ThreadCtl:
context_lock.acquire()
t = context['pool_ctl']
context_lock.release()
return t

View File

@ -3,7 +3,7 @@ import threading
import importlib import importlib
import pkgutil import pkgutil
import pkg.utils.context import pkg.utils.context as context
import pkg.plugin.host import pkg.plugin.host
@ -22,20 +22,20 @@ def walk(module, prefix='', path_prefix=''):
def reload_all(notify=True): def reload_all(notify=True):
# 解除bot的事件注册 # 解除bot的事件注册
import pkg import pkg
pkg.utils.context.get_qqbot_manager().unsubscribe_all() context.get_qqbot_manager().unsubscribe_all()
# 执行关闭流程 # 执行关闭流程
logging.info("执行程序关闭流程") logging.info("执行程序关闭流程")
import main import main
main.stop() main.stop()
# 重载所有模块 # 重载所有模块
pkg.utils.context.context['exceeded_keys'] = pkg.utils.context.get_openai_manager().key_mgr.exceeded context.context['exceeded_keys'] = context.get_openai_manager().key_mgr.exceeded
context = pkg.utils.context.context this_context = context.context
walk(pkg) walk(pkg)
importlib.reload(__import__('config')) importlib.reload(__import__('config'))
importlib.reload(__import__('main')) importlib.reload(__import__('main'))
importlib.reload(__import__('banlist')) importlib.reload(__import__('banlist'))
pkg.utils.context.context = context context.context = this_context
# 重载插件 # 重载插件
import plugins import plugins
@ -43,8 +43,15 @@ def reload_all(notify=True):
# 执行启动流程 # 执行启动流程
logging.info("执行程序启动流程") logging.info("执行程序启动流程")
threading.Thread(target=main.main, args=(False,), daemon=False).start() context.get_thread_ctl().reload(
admin_pool_num=context.get_config().admin_pool_num,
user_pool_num=context.get_config().user_pool_num
)
context.get_thread_ctl().submit_sys_task(
main.start,
False
)
logging.info('程序启动完成') logging.info('程序启动完成')
if notify: if notify:
pkg.utils.context.get_qqbot_manager().notify_admin("重载完成") context.get_qqbot_manager().notify_admin("重载完成")

93
pkg/utils/threadctl.py Normal file
View File

@ -0,0 +1,93 @@
from concurrent.futures import ThreadPoolExecutor, Future
import threading, time
class Pool():
'''
线程池结构
'''
pool_num:int = None
ctl:ThreadPoolExecutor = None
task_list:list = None
task_list_lock:threading.Lock = None
monitor_type = True
def __init__(self, pool_num):
self.pool_num = pool_num
self.ctl = ThreadPoolExecutor(max_workers = self.pool_num)
self.task_list = []
self.task_list_lock = threading.Lock()
def __thread_monitor__(self):
while self.monitor_type:
for t in self.task_list:
if not t.done():
continue
try:
self.task_list.pop(self.task_list.index(t))
except:
continue
time.sleep(1)
class ThreadCtl():
def __init__(self, sys_pool_num, admin_pool_num, user_pool_num):
'''
线程池控制类
sys_pool_num分配系统使用的线程池数量(>=5)
admin_pool_num用于处理管理员消息的线程池数量(>=1)
user_pool_num分配用于处理用户消息的线程池的数量(>=1)
'''
if sys_pool_num < 5:
raise Exception("Too few system threads(sys_pool_num needs >= 8, but received {})".format(sys_pool_num))
if admin_pool_num < 1:
raise Exception("Too few admin threads(admin_pool_num needs >= 1, but received {})".format(admin_pool_num))
if user_pool_num < 1:
raise Exception("Too few user threads(user_pool_num needs >= 1, but received {})".format(admin_pool_num))
self.__sys_pool__ = Pool(sys_pool_num)
self.__admin_pool__ = Pool(admin_pool_num)
self.__user_pool__ = Pool(user_pool_num)
self.submit_sys_task(self.__sys_pool__.__thread_monitor__)
self.submit_sys_task(self.__admin_pool__.__thread_monitor__)
self.submit_sys_task(self.__user_pool__.__thread_monitor__)
def __submit__(self, pool:Pool, fn, /, *args, **kwargs ):
t = pool.ctl.submit(fn, *args, **kwargs)
pool.task_list_lock.acquire()
pool.task_list.append(t)
pool.task_list_lock.release()
return t
def submit_sys_task(self, fn, /, *args, **kwargs):
return self.__submit__(
self.__sys_pool__,
fn, *args, **kwargs
)
def submit_admin_task(self, fn, /, *args, **kwargs):
return self.__submit__(
self.__admin_pool__,
fn, *args, **kwargs
)
def submit_user_task(self, fn, /, *args, **kwargs):
return self.__submit__(
self.__user_pool__,
fn, *args, **kwargs
)
def shutdown(self):
self.__user_pool__.ctl.shutdown(cancel_futures=True)
self.__user_pool__.monitor_type = False
self.__admin_pool__.ctl.shutdown(cancel_futures=True)
self.__admin_pool__.monitor_type = False
self.__sys_pool__.monitor_type = False
self.__sys_pool__.ctl.shutdown(wait=True, cancel_futures=False)
def reload(self, admin_pool_num, user_pool_num):
self.__user_pool__.ctl.shutdown(cancel_futures=True)
self.__user_pool__.monitor_type = False
self.__admin_pool__.ctl.shutdown(cancel_futures=True)
self.__admin_pool__.monitor_type = False
self.__admin_pool__ = Pool(admin_pool_num)
self.__user_pool__ = Pool(user_pool_num)
self.submit_sys_task(self.__admin_pool__.__thread_monitor__)
self.submit_sys_task(self.__user_pool__.__thread_monitor__)