添加线程控制类,修改main结构,修改启动流程

This commit is contained in:
LINSTCL 2023-03-08 15:21:37 +08:00
parent 2933d4843f
commit 77076f3bdd
7 changed files with 231 additions and 70 deletions

View File

@ -208,7 +208,9 @@ alter_tip_message = '出错了,请稍后再试'
# 机器人线程池大小
# 该参数决定机器人可以同时处理几个人的消息,超出线程池数量的请求会被阻塞,不会被丢弃
# 如果你不清楚该参数的意义,请不要更改
pool_num = 10
sys_pool_num = 8
admin_pool_num = 2
user_pool_num = 3
# 每个会话的过期时间,单位为秒
# 默认值20分钟

104
main.py
View File

@ -23,7 +23,7 @@ import colorlog
import requests
import websockets.exceptions
from urllib3.exceptions import InsecureRequestWarning
import pkg.utils.context
sys.path.append(".")
@ -74,11 +74,8 @@ def init_runtime_log_file():
def reset_logging():
global log_file_name
assert os.path.exists('config.py')
config = importlib.import_module('config')
import pkg.utils.context
config = pkg.utils.context.get_config()
if pkg.utils.context.context['logger_handler'] is not None:
logging.getLogger().removeHandler(pkg.utils.context.context['logger_handler'])
@ -106,12 +103,13 @@ def reset_logging():
return sh
def main(first_time_init=False):
def start(first_time_init=False):
"""启动流程reload之后会被执行"""
global known_exception_caught
import pkg.utils.context
import config
config = pkg.utils.context.get_config()
# 更新openai库到最新版本
if not hasattr(config, 'upgrade_dependencies') or config.upgrade_dependencies:
print("正在更新依赖库,请等待...")
@ -126,31 +124,10 @@ def main(first_time_init=False):
known_exception_caught = False
try:
# 导入config.py
assert os.path.exists('config.py')
config = importlib.import_module('config')
init_runtime_log_file()
sh = reset_logging()
# 配置完整性校验
is_integrity = True
config_template = importlib.import_module('config-template')
for key in dir(config_template):
if not key.startswith("__") and not hasattr(config, key):
setattr(config, key, getattr(config_template, key))
logging.warning("[{}]不存在".format(key))
is_integrity = False
if not is_integrity:
logging.warning("配置文件不完整请依据config-template.py检查config.py")
logging.warning("以上配置已被设为默认值将在5秒后继续启动... ")
time.sleep(5)
import pkg.utils.context
pkg.utils.context.set_config(config)
# 检查是否设置了管理员
if not (hasattr(config, 'admin_qq') and config.admin_qq != 0):
# logging.warning("未设置管理员QQ,管理员权限指令及运行告警将无法使用,如需设置请修改config.py中的admin_qq字段")
@ -197,7 +174,7 @@ def main(first_time_init=False):
# 初始化qq机器人
qqbot = pkg.qqbot.manager.QQBotManager(mirai_http_api_config=config.mirai_http_api_config,
timeout=config.process_message_timeout, retry=config.retry_times,
first_time_init=first_time_init, pool_num=config.pool_num)
first_time_init=first_time_init)
# 加载插件
import pkg.plugin.host
@ -252,10 +229,10 @@ def main(first_time_init=False):
known_exception_caught = True
raise e
qq_bot_thread = threading.Thread(target=run_bot_wrapper, args=(), daemon=True)
qq_bot_thread.start()
pkg.utils.context.get_thread_ctl().submit_sys_task(
run_bot_wrapper
)
finally:
time.sleep(12)
if first_time_init:
if not known_exception_caught:
logging.info('程序启动完成,如长时间未显示 ”成功登录到账号xxxxx“ ,并且不回复消息,请查看 '
@ -294,11 +271,8 @@ def main(first_time_init=False):
except Exception as e:
logging.warning("检查更新失败:{}".format(e))
return qqbot
def stop():
import pkg.utils.context
import pkg.qqbot.manager
import pkg.openai.session
try:
@ -316,14 +290,30 @@ def stop():
if not isinstance(e, KeyboardInterrupt):
raise e
if __name__ == '__main__':
# 检查是否有config.py,如果没有就把config-template.py复制一份,并退出程序
# 临时函数用于加载config和上下文未来统一放在config类
def load_config():
#存在性校验
if not os.path.exists('config.py'):
shutil.copy('config-template.py', 'config.py')
print('请先在config.py中填写配置')
sys.exit(0)
#完整性校验
is_integrity = True
config_template = importlib.import_module('config-template')
config = importlib.import_module('config')
for key in dir(config_template):
if not key.startswith("__") and not hasattr(config, key):
setattr(config, key, getattr(config_template, key))
logging.warning("[{}]不存在".format(key))
is_integrity = False
if not is_integrity:
logging.warning("配置文件不完整请依据config-template.py检查config.py")
logging.warning("以上配置已被设为默认值将在5秒后继续启动... ")
time.sleep(5)
#context配置
pkg.utils.context.set_config(config)
def check_file():
# 检查是否有banlist.py,如果没有就把banlist-template.py复制一份
if not os.path.exists('banlist.py'):
shutil.copy('banlist-template.py', 'banlist.py')
@ -342,6 +332,24 @@ if __name__ == '__main__':
if not os.path.exists(path):
os.mkdir(path)
def main():
# 加载配置
load_config()
config = pkg.utils.context.get_config()
# 初始化相关文件
check_file()
# 配置线程池
from pkg.utils import ThreadCtl
thread_ctl = ThreadCtl(
sys_pool_num = config.sys_pool_num,
admin_pool_num = config.admin_pool_num,
user_pool_num = config.user_pool_num
)
pkg.utils.context.set_thread_ctl(thread_ctl)
# 控制台指令处理
if len(sys.argv) > 1 and sys.argv[1] == 'init_db':
init_db()
sys.exit(0)
@ -352,19 +360,27 @@ if __name__ == '__main__':
updater.update_all(cli=True)
sys.exit(0)
# 不知道干啥的
# import pkg.utils.configmgr
#
# pkg.utils.configmgr.set_config_and_reload("quote_origin", False)
requests.packages.urllib3.disable_warnings(InsecureRequestWarning)
qqbot = main(True)
pkg.utils.context.get_thread_ctl().submit_sys_task(
start,
True
)
import pkg.utils.context
# 主线程循环
while True:
try:
time.sleep(10)
except KeyboardInterrupt:
time.sleep(0xFF)
except:
stop()
print("程序退出")
pkg.utils.context.get_thread_ctl().shutdown()
print("退出")
sys.exit(0)
if __name__ == '__main__':
main()

View File

@ -2,7 +2,6 @@ import asyncio
import json
import os
import threading
from concurrent.futures import ThreadPoolExecutor
import mirai.models.bus
from mirai import At, GroupMessage, MessageEvent, Mirai, StrangerMessage, WebSocketAdapter, HTTPAdapter, \
@ -66,9 +65,6 @@ def random_responding():
class QQBotManager:
retry = 3
#线程池控制
pool = None
bot: Mirai = None
reply_filter = None
@ -78,14 +74,10 @@ class QQBotManager:
ban_person = []
ban_group = []
def __init__(self, mirai_http_api_config: dict, timeout: int = 60, retry: int = 3, pool_num: int = 10, first_time_init=True):
def __init__(self, mirai_http_api_config: dict, timeout: int = 60, retry: int = 3, first_time_init=True):
self.timeout = timeout
self.retry = retry
self.pool_num = pool_num
self.pool = ThreadPoolExecutor(max_workers=self.pool_num)
logging.debug("Registered thread pool Size:{}".format(pool_num))
# 加载禁用列表
if os.path.exists("banlist.py"):
import banlist
@ -138,7 +130,10 @@ class QQBotManager:
self.on_person_message(event)
self.go(friend_message_handler, event)
pkg.utils.context.get_thread_ctl().submit_user_task(
friend_message_handler,
event
)
@self.bot.on(StrangerMessage)
async def on_stranger_message(event: StrangerMessage):
@ -158,7 +153,10 @@ class QQBotManager:
self.on_person_message(event)
self.go(stranger_message_handler, event)
pkg.utils.context.get_thread_ctl().submit_user_task(
stranger_message_handler,
event
)
@self.bot.on(GroupMessage)
async def on_group_message(event: GroupMessage):
@ -178,7 +176,10 @@ class QQBotManager:
self.on_group_message(event)
self.go(group_message_handler, event)
pkg.utils.context.get_thread_ctl().submit_user_task(
group_message_handler,
event
)
def unsubscribe_all():
"""取消所有订阅

View File

@ -0,0 +1 @@
from .threadctl import ThreadCtl

View File

@ -1,50 +1,91 @@
import threading
context = {
'inst': {
'database.manager.DatabaseManager': None,
'openai.manager.OpenAIInteract': None,
'qqbot.manager.QQBotManager': None,
},
'pool_ctl': None,
'logger_handler': None,
'config': None,
'plugin_host': None,
}
context_lock = threading.Lock()
### context耦合度非常高需要大改 ###
def set_config(inst):
context_lock.acquire()
context['config'] = inst
context_lock.release()
def get_config():
return context['config']
context_lock.acquire()
t = context['config']
context_lock.release()
return t
def set_database_manager(inst):
context_lock.acquire()
context['inst']['database.manager.DatabaseManager'] = inst
context_lock.release()
def get_database_manager():
return context['inst']['database.manager.DatabaseManager']
context_lock.acquire()
t = context['inst']['database.manager.DatabaseManager']
context_lock.release()
return t
def set_openai_manager(inst):
context_lock.acquire()
context['inst']['openai.manager.OpenAIInteract'] = inst
context_lock.release()
def get_openai_manager():
return context['inst']['openai.manager.OpenAIInteract']
context_lock.acquire()
t = context['inst']['openai.manager.OpenAIInteract']
context_lock.release()
return t
def set_qqbot_manager(inst):
context_lock.acquire()
context['inst']['qqbot.manager.QQBotManager'] = inst
context_lock.release()
def get_qqbot_manager():
return context['inst']['qqbot.manager.QQBotManager']
context_lock.acquire()
t = context['inst']['qqbot.manager.QQBotManager']
context_lock.release()
return t
def set_plugin_host(inst):
context_lock.acquire()
context['plugin_host'] = inst
context_lock.release()
def get_plugin_host():
return context['plugin_host']
context_lock.acquire()
t = context['plugin_host']
context_lock.release()
return t
def set_thread_ctl(inst):
context_lock.acquire()
context['pool_ctl'] = inst
context_lock.release()
from pkg.utils import ThreadCtl
def get_thread_ctl() -> ThreadCtl:
context_lock.acquire()
t = context['pool_ctl']
context_lock.release()
return t

View File

@ -3,7 +3,7 @@ import threading
import importlib
import pkgutil
import pkg.utils.context
import pkg.utils.context as context
import pkg.plugin.host
@ -22,20 +22,20 @@ def walk(module, prefix='', path_prefix=''):
def reload_all(notify=True):
# 解除bot的事件注册
import pkg
pkg.utils.context.get_qqbot_manager().unsubscribe_all()
context.get_qqbot_manager().unsubscribe_all()
# 执行关闭流程
logging.info("执行程序关闭流程")
import main
main.stop()
# 重载所有模块
pkg.utils.context.context['exceeded_keys'] = pkg.utils.context.get_openai_manager().key_mgr.exceeded
context = pkg.utils.context.context
context.context['exceeded_keys'] = context.get_openai_manager().key_mgr.exceeded
this_context = context.context
walk(pkg)
importlib.reload(__import__('config'))
importlib.reload(__import__('main'))
importlib.reload(__import__('banlist'))
pkg.utils.context.context = context
context.context = this_context
# 重载插件
import plugins
@ -43,8 +43,15 @@ def reload_all(notify=True):
# 执行启动流程
logging.info("执行程序启动流程")
threading.Thread(target=main.main, args=(False,), daemon=False).start()
context.get_thread_ctl().reload(
admin_pool_num=context.get_config().admin_pool_num,
user_pool_num=context.get_config().user_pool_num
)
context.get_thread_ctl().submit_sys_task(
main.start,
False
)
logging.info('程序启动完成')
if notify:
pkg.utils.context.get_qqbot_manager().notify_admin("重载完成")
context.get_qqbot_manager().notify_admin("重载完成")

93
pkg/utils/threadctl.py Normal file
View File

@ -0,0 +1,93 @@
from concurrent.futures import ThreadPoolExecutor, Future
import threading, time
class Pool():
'''
线程池结构
'''
pool_num:int = None
ctl:ThreadPoolExecutor = None
task_list:list = None
task_list_lock:threading.Lock = None
monitor_type = True
def __init__(self, pool_num):
self.pool_num = pool_num
self.ctl = ThreadPoolExecutor(max_workers = self.pool_num)
self.task_list = []
self.task_list_lock = threading.Lock()
def __thread_monitor__(self):
while self.monitor_type:
for t in self.task_list:
if not t.done():
continue
try:
self.task_list.pop(self.task_list.index(t))
except:
continue
time.sleep(1)
class ThreadCtl():
def __init__(self, sys_pool_num, admin_pool_num, user_pool_num):
'''
线程池控制类
sys_pool_num分配系统使用的线程池数量(>=5)
admin_pool_num用于处理管理员消息的线程池数量(>=1)
user_pool_num分配用于处理用户消息的线程池的数量(>=1)
'''
if sys_pool_num < 5:
raise Exception("Too few system threads(sys_pool_num needs >= 8, but received {})".format(sys_pool_num))
if admin_pool_num < 1:
raise Exception("Too few admin threads(admin_pool_num needs >= 1, but received {})".format(admin_pool_num))
if user_pool_num < 1:
raise Exception("Too few user threads(user_pool_num needs >= 1, but received {})".format(admin_pool_num))
self.__sys_pool__ = Pool(sys_pool_num)
self.__admin_pool__ = Pool(admin_pool_num)
self.__user_pool__ = Pool(user_pool_num)
self.submit_sys_task(self.__sys_pool__.__thread_monitor__)
self.submit_sys_task(self.__admin_pool__.__thread_monitor__)
self.submit_sys_task(self.__user_pool__.__thread_monitor__)
def __submit__(self, pool:Pool, fn, /, *args, **kwargs ):
t = pool.ctl.submit(fn, *args, **kwargs)
pool.task_list_lock.acquire()
pool.task_list.append(t)
pool.task_list_lock.release()
return t
def submit_sys_task(self, fn, /, *args, **kwargs):
return self.__submit__(
self.__sys_pool__,
fn, *args, **kwargs
)
def submit_admin_task(self, fn, /, *args, **kwargs):
return self.__submit__(
self.__admin_pool__,
fn, *args, **kwargs
)
def submit_user_task(self, fn, /, *args, **kwargs):
return self.__submit__(
self.__user_pool__,
fn, *args, **kwargs
)
def shutdown(self):
self.__user_pool__.ctl.shutdown(cancel_futures=True)
self.__user_pool__.monitor_type = False
self.__admin_pool__.ctl.shutdown(cancel_futures=True)
self.__admin_pool__.monitor_type = False
self.__sys_pool__.monitor_type = False
self.__sys_pool__.ctl.shutdown(wait=True, cancel_futures=False)
def reload(self, admin_pool_num, user_pool_num):
self.__user_pool__.ctl.shutdown(cancel_futures=True)
self.__user_pool__.monitor_type = False
self.__admin_pool__.ctl.shutdown(cancel_futures=True)
self.__admin_pool__.monitor_type = False
self.__admin_pool__ = Pool(admin_pool_num)
self.__user_pool__ = Pool(user_pool_num)
self.submit_sys_task(self.__admin_pool__.__thread_monitor__)
self.submit_sys_task(self.__user_pool__.__thread_monitor__)