mirror of
https://github.com/RockChinQ/QChatGPT.git
synced 2024-11-16 11:42:44 +08:00
添加线程控制类,修改main结构,修改启动流程
This commit is contained in:
parent
2933d4843f
commit
77076f3bdd
|
@ -208,7 +208,9 @@ alter_tip_message = '出错了,请稍后再试'
|
||||||
# 机器人线程池大小
|
# 机器人线程池大小
|
||||||
# 该参数决定机器人可以同时处理几个人的消息,超出线程池数量的请求会被阻塞,不会被丢弃
|
# 该参数决定机器人可以同时处理几个人的消息,超出线程池数量的请求会被阻塞,不会被丢弃
|
||||||
# 如果你不清楚该参数的意义,请不要更改
|
# 如果你不清楚该参数的意义,请不要更改
|
||||||
pool_num = 10
|
sys_pool_num = 8
|
||||||
|
admin_pool_num = 2
|
||||||
|
user_pool_num = 3
|
||||||
|
|
||||||
# 每个会话的过期时间,单位为秒
|
# 每个会话的过期时间,单位为秒
|
||||||
# 默认值20分钟
|
# 默认值20分钟
|
||||||
|
|
104
main.py
104
main.py
|
@ -23,7 +23,7 @@ import colorlog
|
||||||
import requests
|
import requests
|
||||||
import websockets.exceptions
|
import websockets.exceptions
|
||||||
from urllib3.exceptions import InsecureRequestWarning
|
from urllib3.exceptions import InsecureRequestWarning
|
||||||
|
import pkg.utils.context
|
||||||
|
|
||||||
sys.path.append(".")
|
sys.path.append(".")
|
||||||
|
|
||||||
|
@ -74,11 +74,8 @@ def init_runtime_log_file():
|
||||||
|
|
||||||
def reset_logging():
|
def reset_logging():
|
||||||
global log_file_name
|
global log_file_name
|
||||||
assert os.path.exists('config.py')
|
|
||||||
|
|
||||||
config = importlib.import_module('config')
|
config = pkg.utils.context.get_config()
|
||||||
|
|
||||||
import pkg.utils.context
|
|
||||||
|
|
||||||
if pkg.utils.context.context['logger_handler'] is not None:
|
if pkg.utils.context.context['logger_handler'] is not None:
|
||||||
logging.getLogger().removeHandler(pkg.utils.context.context['logger_handler'])
|
logging.getLogger().removeHandler(pkg.utils.context.context['logger_handler'])
|
||||||
|
@ -106,12 +103,13 @@ def reset_logging():
|
||||||
return sh
|
return sh
|
||||||
|
|
||||||
|
|
||||||
def main(first_time_init=False):
|
def start(first_time_init=False):
|
||||||
"""启动流程,reload之后会被执行"""
|
"""启动流程,reload之后会被执行"""
|
||||||
|
|
||||||
global known_exception_caught
|
global known_exception_caught
|
||||||
|
import pkg.utils.context
|
||||||
|
|
||||||
import config
|
config = pkg.utils.context.get_config()
|
||||||
# 更新openai库到最新版本
|
# 更新openai库到最新版本
|
||||||
if not hasattr(config, 'upgrade_dependencies') or config.upgrade_dependencies:
|
if not hasattr(config, 'upgrade_dependencies') or config.upgrade_dependencies:
|
||||||
print("正在更新依赖库,请等待...")
|
print("正在更新依赖库,请等待...")
|
||||||
|
@ -126,31 +124,10 @@ def main(first_time_init=False):
|
||||||
|
|
||||||
known_exception_caught = False
|
known_exception_caught = False
|
||||||
try:
|
try:
|
||||||
# 导入config.py
|
|
||||||
assert os.path.exists('config.py')
|
|
||||||
|
|
||||||
config = importlib.import_module('config')
|
|
||||||
|
|
||||||
init_runtime_log_file()
|
init_runtime_log_file()
|
||||||
|
|
||||||
sh = reset_logging()
|
sh = reset_logging()
|
||||||
|
|
||||||
# 配置完整性校验
|
|
||||||
is_integrity = True
|
|
||||||
config_template = importlib.import_module('config-template')
|
|
||||||
for key in dir(config_template):
|
|
||||||
if not key.startswith("__") and not hasattr(config, key):
|
|
||||||
setattr(config, key, getattr(config_template, key))
|
|
||||||
logging.warning("[{}]不存在".format(key))
|
|
||||||
is_integrity = False
|
|
||||||
if not is_integrity:
|
|
||||||
logging.warning("配置文件不完整,请依据config-template.py检查config.py")
|
|
||||||
logging.warning("以上配置已被设为默认值,将在5秒后继续启动... ")
|
|
||||||
time.sleep(5)
|
|
||||||
|
|
||||||
import pkg.utils.context
|
|
||||||
pkg.utils.context.set_config(config)
|
|
||||||
|
|
||||||
# 检查是否设置了管理员
|
# 检查是否设置了管理员
|
||||||
if not (hasattr(config, 'admin_qq') and config.admin_qq != 0):
|
if not (hasattr(config, 'admin_qq') and config.admin_qq != 0):
|
||||||
# logging.warning("未设置管理员QQ,管理员权限指令及运行告警将无法使用,如需设置请修改config.py中的admin_qq字段")
|
# logging.warning("未设置管理员QQ,管理员权限指令及运行告警将无法使用,如需设置请修改config.py中的admin_qq字段")
|
||||||
|
@ -197,7 +174,7 @@ def main(first_time_init=False):
|
||||||
# 初始化qq机器人
|
# 初始化qq机器人
|
||||||
qqbot = pkg.qqbot.manager.QQBotManager(mirai_http_api_config=config.mirai_http_api_config,
|
qqbot = pkg.qqbot.manager.QQBotManager(mirai_http_api_config=config.mirai_http_api_config,
|
||||||
timeout=config.process_message_timeout, retry=config.retry_times,
|
timeout=config.process_message_timeout, retry=config.retry_times,
|
||||||
first_time_init=first_time_init, pool_num=config.pool_num)
|
first_time_init=first_time_init)
|
||||||
|
|
||||||
# 加载插件
|
# 加载插件
|
||||||
import pkg.plugin.host
|
import pkg.plugin.host
|
||||||
|
@ -252,10 +229,10 @@ def main(first_time_init=False):
|
||||||
known_exception_caught = True
|
known_exception_caught = True
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
qq_bot_thread = threading.Thread(target=run_bot_wrapper, args=(), daemon=True)
|
pkg.utils.context.get_thread_ctl().submit_sys_task(
|
||||||
qq_bot_thread.start()
|
run_bot_wrapper
|
||||||
|
)
|
||||||
finally:
|
finally:
|
||||||
time.sleep(12)
|
|
||||||
if first_time_init:
|
if first_time_init:
|
||||||
if not known_exception_caught:
|
if not known_exception_caught:
|
||||||
logging.info('程序启动完成,如长时间未显示 ”成功登录到账号xxxxx“ ,并且不回复消息,请查看 '
|
logging.info('程序启动完成,如长时间未显示 ”成功登录到账号xxxxx“ ,并且不回复消息,请查看 '
|
||||||
|
@ -294,11 +271,8 @@ def main(first_time_init=False):
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.warning("检查更新失败:{}".format(e))
|
logging.warning("检查更新失败:{}".format(e))
|
||||||
|
|
||||||
return qqbot
|
|
||||||
|
|
||||||
|
|
||||||
def stop():
|
def stop():
|
||||||
import pkg.utils.context
|
|
||||||
import pkg.qqbot.manager
|
import pkg.qqbot.manager
|
||||||
import pkg.openai.session
|
import pkg.openai.session
|
||||||
try:
|
try:
|
||||||
|
@ -316,14 +290,30 @@ def stop():
|
||||||
if not isinstance(e, KeyboardInterrupt):
|
if not isinstance(e, KeyboardInterrupt):
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
|
# 临时函数,用于加载config和上下文,未来统一放在config类
|
||||||
if __name__ == '__main__':
|
def load_config():
|
||||||
# 检查是否有config.py,如果没有就把config-template.py复制一份,并退出程序
|
#存在性校验
|
||||||
if not os.path.exists('config.py'):
|
if not os.path.exists('config.py'):
|
||||||
shutil.copy('config-template.py', 'config.py')
|
shutil.copy('config-template.py', 'config.py')
|
||||||
print('请先在config.py中填写配置')
|
print('请先在config.py中填写配置')
|
||||||
sys.exit(0)
|
sys.exit(0)
|
||||||
|
#完整性校验
|
||||||
|
is_integrity = True
|
||||||
|
config_template = importlib.import_module('config-template')
|
||||||
|
config = importlib.import_module('config')
|
||||||
|
for key in dir(config_template):
|
||||||
|
if not key.startswith("__") and not hasattr(config, key):
|
||||||
|
setattr(config, key, getattr(config_template, key))
|
||||||
|
logging.warning("[{}]不存在".format(key))
|
||||||
|
is_integrity = False
|
||||||
|
if not is_integrity:
|
||||||
|
logging.warning("配置文件不完整,请依据config-template.py检查config.py")
|
||||||
|
logging.warning("以上配置已被设为默认值,将在5秒后继续启动... ")
|
||||||
|
time.sleep(5)
|
||||||
|
#context配置
|
||||||
|
pkg.utils.context.set_config(config)
|
||||||
|
|
||||||
|
def check_file():
|
||||||
# 检查是否有banlist.py,如果没有就把banlist-template.py复制一份
|
# 检查是否有banlist.py,如果没有就把banlist-template.py复制一份
|
||||||
if not os.path.exists('banlist.py'):
|
if not os.path.exists('banlist.py'):
|
||||||
shutil.copy('banlist-template.py', 'banlist.py')
|
shutil.copy('banlist-template.py', 'banlist.py')
|
||||||
|
@ -342,6 +332,24 @@ if __name__ == '__main__':
|
||||||
if not os.path.exists(path):
|
if not os.path.exists(path):
|
||||||
os.mkdir(path)
|
os.mkdir(path)
|
||||||
|
|
||||||
|
def main():
|
||||||
|
# 加载配置
|
||||||
|
load_config()
|
||||||
|
config = pkg.utils.context.get_config()
|
||||||
|
|
||||||
|
# 初始化相关文件
|
||||||
|
check_file()
|
||||||
|
|
||||||
|
# 配置线程池
|
||||||
|
from pkg.utils import ThreadCtl
|
||||||
|
thread_ctl = ThreadCtl(
|
||||||
|
sys_pool_num = config.sys_pool_num,
|
||||||
|
admin_pool_num = config.admin_pool_num,
|
||||||
|
user_pool_num = config.user_pool_num
|
||||||
|
)
|
||||||
|
pkg.utils.context.set_thread_ctl(thread_ctl)
|
||||||
|
|
||||||
|
# 控制台指令处理
|
||||||
if len(sys.argv) > 1 and sys.argv[1] == 'init_db':
|
if len(sys.argv) > 1 and sys.argv[1] == 'init_db':
|
||||||
init_db()
|
init_db()
|
||||||
sys.exit(0)
|
sys.exit(0)
|
||||||
|
@ -352,19 +360,27 @@ if __name__ == '__main__':
|
||||||
updater.update_all(cli=True)
|
updater.update_all(cli=True)
|
||||||
sys.exit(0)
|
sys.exit(0)
|
||||||
|
|
||||||
|
# 不知道干啥的
|
||||||
# import pkg.utils.configmgr
|
# import pkg.utils.configmgr
|
||||||
#
|
#
|
||||||
# pkg.utils.configmgr.set_config_and_reload("quote_origin", False)
|
# pkg.utils.configmgr.set_config_and_reload("quote_origin", False)
|
||||||
requests.packages.urllib3.disable_warnings(InsecureRequestWarning)
|
requests.packages.urllib3.disable_warnings(InsecureRequestWarning)
|
||||||
|
|
||||||
qqbot = main(True)
|
pkg.utils.context.get_thread_ctl().submit_sys_task(
|
||||||
|
start,
|
||||||
|
True
|
||||||
|
)
|
||||||
|
|
||||||
import pkg.utils.context
|
# 主线程循环
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
time.sleep(10)
|
time.sleep(0xFF)
|
||||||
except KeyboardInterrupt:
|
except:
|
||||||
stop()
|
stop()
|
||||||
|
pkg.utils.context.get_thread_ctl().shutdown()
|
||||||
print("程序退出")
|
print("退出")
|
||||||
sys.exit(0)
|
sys.exit(0)
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
main()
|
||||||
|
|
||||||
|
|
|
@ -2,7 +2,6 @@ import asyncio
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import threading
|
import threading
|
||||||
from concurrent.futures import ThreadPoolExecutor
|
|
||||||
|
|
||||||
import mirai.models.bus
|
import mirai.models.bus
|
||||||
from mirai import At, GroupMessage, MessageEvent, Mirai, StrangerMessage, WebSocketAdapter, HTTPAdapter, \
|
from mirai import At, GroupMessage, MessageEvent, Mirai, StrangerMessage, WebSocketAdapter, HTTPAdapter, \
|
||||||
|
@ -66,9 +65,6 @@ def random_responding():
|
||||||
class QQBotManager:
|
class QQBotManager:
|
||||||
retry = 3
|
retry = 3
|
||||||
|
|
||||||
#线程池控制
|
|
||||||
pool = None
|
|
||||||
|
|
||||||
bot: Mirai = None
|
bot: Mirai = None
|
||||||
|
|
||||||
reply_filter = None
|
reply_filter = None
|
||||||
|
@ -78,14 +74,10 @@ class QQBotManager:
|
||||||
ban_person = []
|
ban_person = []
|
||||||
ban_group = []
|
ban_group = []
|
||||||
|
|
||||||
def __init__(self, mirai_http_api_config: dict, timeout: int = 60, retry: int = 3, pool_num: int = 10, first_time_init=True):
|
def __init__(self, mirai_http_api_config: dict, timeout: int = 60, retry: int = 3, first_time_init=True):
|
||||||
self.timeout = timeout
|
self.timeout = timeout
|
||||||
self.retry = retry
|
self.retry = retry
|
||||||
|
|
||||||
self.pool_num = pool_num
|
|
||||||
self.pool = ThreadPoolExecutor(max_workers=self.pool_num)
|
|
||||||
logging.debug("Registered thread pool Size:{}".format(pool_num))
|
|
||||||
|
|
||||||
# 加载禁用列表
|
# 加载禁用列表
|
||||||
if os.path.exists("banlist.py"):
|
if os.path.exists("banlist.py"):
|
||||||
import banlist
|
import banlist
|
||||||
|
@ -138,7 +130,10 @@ class QQBotManager:
|
||||||
|
|
||||||
self.on_person_message(event)
|
self.on_person_message(event)
|
||||||
|
|
||||||
self.go(friend_message_handler, event)
|
pkg.utils.context.get_thread_ctl().submit_user_task(
|
||||||
|
friend_message_handler,
|
||||||
|
event
|
||||||
|
)
|
||||||
|
|
||||||
@self.bot.on(StrangerMessage)
|
@self.bot.on(StrangerMessage)
|
||||||
async def on_stranger_message(event: StrangerMessage):
|
async def on_stranger_message(event: StrangerMessage):
|
||||||
|
@ -158,7 +153,10 @@ class QQBotManager:
|
||||||
|
|
||||||
self.on_person_message(event)
|
self.on_person_message(event)
|
||||||
|
|
||||||
self.go(stranger_message_handler, event)
|
pkg.utils.context.get_thread_ctl().submit_user_task(
|
||||||
|
stranger_message_handler,
|
||||||
|
event
|
||||||
|
)
|
||||||
|
|
||||||
@self.bot.on(GroupMessage)
|
@self.bot.on(GroupMessage)
|
||||||
async def on_group_message(event: GroupMessage):
|
async def on_group_message(event: GroupMessage):
|
||||||
|
@ -178,7 +176,10 @@ class QQBotManager:
|
||||||
|
|
||||||
self.on_group_message(event)
|
self.on_group_message(event)
|
||||||
|
|
||||||
self.go(group_message_handler, event)
|
pkg.utils.context.get_thread_ctl().submit_user_task(
|
||||||
|
group_message_handler,
|
||||||
|
event
|
||||||
|
)
|
||||||
|
|
||||||
def unsubscribe_all():
|
def unsubscribe_all():
|
||||||
"""取消所有订阅
|
"""取消所有订阅
|
||||||
|
|
|
@ -0,0 +1 @@
|
||||||
|
from .threadctl import ThreadCtl
|
|
@ -1,50 +1,91 @@
|
||||||
|
import threading
|
||||||
|
|
||||||
context = {
|
context = {
|
||||||
'inst': {
|
'inst': {
|
||||||
'database.manager.DatabaseManager': None,
|
'database.manager.DatabaseManager': None,
|
||||||
'openai.manager.OpenAIInteract': None,
|
'openai.manager.OpenAIInteract': None,
|
||||||
'qqbot.manager.QQBotManager': None,
|
'qqbot.manager.QQBotManager': None,
|
||||||
},
|
},
|
||||||
|
'pool_ctl': None,
|
||||||
'logger_handler': None,
|
'logger_handler': None,
|
||||||
'config': None,
|
'config': None,
|
||||||
'plugin_host': None,
|
'plugin_host': None,
|
||||||
}
|
}
|
||||||
|
context_lock = threading.Lock()
|
||||||
|
|
||||||
|
### context耦合度非常高,需要大改 ###
|
||||||
def set_config(inst):
|
def set_config(inst):
|
||||||
|
context_lock.acquire()
|
||||||
context['config'] = inst
|
context['config'] = inst
|
||||||
|
context_lock.release()
|
||||||
|
|
||||||
|
|
||||||
def get_config():
|
def get_config():
|
||||||
return context['config']
|
context_lock.acquire()
|
||||||
|
t = context['config']
|
||||||
|
context_lock.release()
|
||||||
|
return t
|
||||||
|
|
||||||
|
|
||||||
def set_database_manager(inst):
|
def set_database_manager(inst):
|
||||||
|
context_lock.acquire()
|
||||||
context['inst']['database.manager.DatabaseManager'] = inst
|
context['inst']['database.manager.DatabaseManager'] = inst
|
||||||
|
context_lock.release()
|
||||||
|
|
||||||
|
|
||||||
def get_database_manager():
|
def get_database_manager():
|
||||||
return context['inst']['database.manager.DatabaseManager']
|
context_lock.acquire()
|
||||||
|
t = context['inst']['database.manager.DatabaseManager']
|
||||||
|
context_lock.release()
|
||||||
|
return t
|
||||||
|
|
||||||
|
|
||||||
def set_openai_manager(inst):
|
def set_openai_manager(inst):
|
||||||
|
context_lock.acquire()
|
||||||
context['inst']['openai.manager.OpenAIInteract'] = inst
|
context['inst']['openai.manager.OpenAIInteract'] = inst
|
||||||
|
context_lock.release()
|
||||||
|
|
||||||
|
|
||||||
def get_openai_manager():
|
def get_openai_manager():
|
||||||
return context['inst']['openai.manager.OpenAIInteract']
|
context_lock.acquire()
|
||||||
|
t = context['inst']['openai.manager.OpenAIInteract']
|
||||||
|
context_lock.release()
|
||||||
|
return t
|
||||||
|
|
||||||
|
|
||||||
def set_qqbot_manager(inst):
|
def set_qqbot_manager(inst):
|
||||||
|
context_lock.acquire()
|
||||||
context['inst']['qqbot.manager.QQBotManager'] = inst
|
context['inst']['qqbot.manager.QQBotManager'] = inst
|
||||||
|
context_lock.release()
|
||||||
|
|
||||||
|
|
||||||
def get_qqbot_manager():
|
def get_qqbot_manager():
|
||||||
return context['inst']['qqbot.manager.QQBotManager']
|
context_lock.acquire()
|
||||||
|
t = context['inst']['qqbot.manager.QQBotManager']
|
||||||
|
context_lock.release()
|
||||||
|
return t
|
||||||
|
|
||||||
|
|
||||||
def set_plugin_host(inst):
|
def set_plugin_host(inst):
|
||||||
|
context_lock.acquire()
|
||||||
context['plugin_host'] = inst
|
context['plugin_host'] = inst
|
||||||
|
context_lock.release()
|
||||||
|
|
||||||
|
|
||||||
def get_plugin_host():
|
def get_plugin_host():
|
||||||
return context['plugin_host']
|
context_lock.acquire()
|
||||||
|
t = context['plugin_host']
|
||||||
|
context_lock.release()
|
||||||
|
return t
|
||||||
|
|
||||||
|
def set_thread_ctl(inst):
|
||||||
|
context_lock.acquire()
|
||||||
|
context['pool_ctl'] = inst
|
||||||
|
context_lock.release()
|
||||||
|
|
||||||
|
from pkg.utils import ThreadCtl
|
||||||
|
def get_thread_ctl() -> ThreadCtl:
|
||||||
|
context_lock.acquire()
|
||||||
|
t = context['pool_ctl']
|
||||||
|
context_lock.release()
|
||||||
|
return t
|
|
@ -3,7 +3,7 @@ import threading
|
||||||
|
|
||||||
import importlib
|
import importlib
|
||||||
import pkgutil
|
import pkgutil
|
||||||
import pkg.utils.context
|
import pkg.utils.context as context
|
||||||
import pkg.plugin.host
|
import pkg.plugin.host
|
||||||
|
|
||||||
|
|
||||||
|
@ -22,20 +22,20 @@ def walk(module, prefix='', path_prefix=''):
|
||||||
def reload_all(notify=True):
|
def reload_all(notify=True):
|
||||||
# 解除bot的事件注册
|
# 解除bot的事件注册
|
||||||
import pkg
|
import pkg
|
||||||
pkg.utils.context.get_qqbot_manager().unsubscribe_all()
|
context.get_qqbot_manager().unsubscribe_all()
|
||||||
# 执行关闭流程
|
# 执行关闭流程
|
||||||
logging.info("执行程序关闭流程")
|
logging.info("执行程序关闭流程")
|
||||||
import main
|
import main
|
||||||
main.stop()
|
main.stop()
|
||||||
|
|
||||||
# 重载所有模块
|
# 重载所有模块
|
||||||
pkg.utils.context.context['exceeded_keys'] = pkg.utils.context.get_openai_manager().key_mgr.exceeded
|
context.context['exceeded_keys'] = context.get_openai_manager().key_mgr.exceeded
|
||||||
context = pkg.utils.context.context
|
this_context = context.context
|
||||||
walk(pkg)
|
walk(pkg)
|
||||||
importlib.reload(__import__('config'))
|
importlib.reload(__import__('config'))
|
||||||
importlib.reload(__import__('main'))
|
importlib.reload(__import__('main'))
|
||||||
importlib.reload(__import__('banlist'))
|
importlib.reload(__import__('banlist'))
|
||||||
pkg.utils.context.context = context
|
context.context = this_context
|
||||||
|
|
||||||
# 重载插件
|
# 重载插件
|
||||||
import plugins
|
import plugins
|
||||||
|
@ -43,8 +43,15 @@ def reload_all(notify=True):
|
||||||
|
|
||||||
# 执行启动流程
|
# 执行启动流程
|
||||||
logging.info("执行程序启动流程")
|
logging.info("执行程序启动流程")
|
||||||
threading.Thread(target=main.main, args=(False,), daemon=False).start()
|
context.get_thread_ctl().reload(
|
||||||
|
admin_pool_num=context.get_config().admin_pool_num,
|
||||||
|
user_pool_num=context.get_config().user_pool_num
|
||||||
|
)
|
||||||
|
context.get_thread_ctl().submit_sys_task(
|
||||||
|
main.start,
|
||||||
|
False
|
||||||
|
)
|
||||||
|
|
||||||
logging.info('程序启动完成')
|
logging.info('程序启动完成')
|
||||||
if notify:
|
if notify:
|
||||||
pkg.utils.context.get_qqbot_manager().notify_admin("重载完成")
|
context.get_qqbot_manager().notify_admin("重载完成")
|
||||||
|
|
93
pkg/utils/threadctl.py
Normal file
93
pkg/utils/threadctl.py
Normal file
|
@ -0,0 +1,93 @@
|
||||||
|
from concurrent.futures import ThreadPoolExecutor, Future
|
||||||
|
import threading, time
|
||||||
|
|
||||||
|
class Pool():
|
||||||
|
'''
|
||||||
|
线程池结构
|
||||||
|
'''
|
||||||
|
pool_num:int = None
|
||||||
|
ctl:ThreadPoolExecutor = None
|
||||||
|
task_list:list = None
|
||||||
|
task_list_lock:threading.Lock = None
|
||||||
|
monitor_type = True
|
||||||
|
|
||||||
|
def __init__(self, pool_num):
|
||||||
|
self.pool_num = pool_num
|
||||||
|
self.ctl = ThreadPoolExecutor(max_workers = self.pool_num)
|
||||||
|
self.task_list = []
|
||||||
|
self.task_list_lock = threading.Lock()
|
||||||
|
|
||||||
|
def __thread_monitor__(self):
|
||||||
|
while self.monitor_type:
|
||||||
|
for t in self.task_list:
|
||||||
|
if not t.done():
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
self.task_list.pop(self.task_list.index(t))
|
||||||
|
except:
|
||||||
|
continue
|
||||||
|
time.sleep(1)
|
||||||
|
|
||||||
|
class ThreadCtl():
|
||||||
|
def __init__(self, sys_pool_num, admin_pool_num, user_pool_num):
|
||||||
|
'''
|
||||||
|
线程池控制类
|
||||||
|
sys_pool_num:分配系统使用的线程池数量(>=5)
|
||||||
|
admin_pool_num:用于处理管理员消息的线程池数量(>=1)
|
||||||
|
user_pool_num:分配用于处理用户消息的线程池的数量(>=1)
|
||||||
|
'''
|
||||||
|
if sys_pool_num < 5:
|
||||||
|
raise Exception("Too few system threads(sys_pool_num needs >= 8, but received {})".format(sys_pool_num))
|
||||||
|
if admin_pool_num < 1:
|
||||||
|
raise Exception("Too few admin threads(admin_pool_num needs >= 1, but received {})".format(admin_pool_num))
|
||||||
|
if user_pool_num < 1:
|
||||||
|
raise Exception("Too few user threads(user_pool_num needs >= 1, but received {})".format(admin_pool_num))
|
||||||
|
self.__sys_pool__ = Pool(sys_pool_num)
|
||||||
|
self.__admin_pool__ = Pool(admin_pool_num)
|
||||||
|
self.__user_pool__ = Pool(user_pool_num)
|
||||||
|
self.submit_sys_task(self.__sys_pool__.__thread_monitor__)
|
||||||
|
self.submit_sys_task(self.__admin_pool__.__thread_monitor__)
|
||||||
|
self.submit_sys_task(self.__user_pool__.__thread_monitor__)
|
||||||
|
|
||||||
|
def __submit__(self, pool:Pool, fn, /, *args, **kwargs ):
|
||||||
|
t = pool.ctl.submit(fn, *args, **kwargs)
|
||||||
|
pool.task_list_lock.acquire()
|
||||||
|
pool.task_list.append(t)
|
||||||
|
pool.task_list_lock.release()
|
||||||
|
return t
|
||||||
|
|
||||||
|
def submit_sys_task(self, fn, /, *args, **kwargs):
|
||||||
|
return self.__submit__(
|
||||||
|
self.__sys_pool__,
|
||||||
|
fn, *args, **kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
def submit_admin_task(self, fn, /, *args, **kwargs):
|
||||||
|
return self.__submit__(
|
||||||
|
self.__admin_pool__,
|
||||||
|
fn, *args, **kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
def submit_user_task(self, fn, /, *args, **kwargs):
|
||||||
|
return self.__submit__(
|
||||||
|
self.__user_pool__,
|
||||||
|
fn, *args, **kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
def shutdown(self):
|
||||||
|
self.__user_pool__.ctl.shutdown(cancel_futures=True)
|
||||||
|
self.__user_pool__.monitor_type = False
|
||||||
|
self.__admin_pool__.ctl.shutdown(cancel_futures=True)
|
||||||
|
self.__admin_pool__.monitor_type = False
|
||||||
|
self.__sys_pool__.monitor_type = False
|
||||||
|
self.__sys_pool__.ctl.shutdown(wait=True, cancel_futures=False)
|
||||||
|
|
||||||
|
def reload(self, admin_pool_num, user_pool_num):
|
||||||
|
self.__user_pool__.ctl.shutdown(cancel_futures=True)
|
||||||
|
self.__user_pool__.monitor_type = False
|
||||||
|
self.__admin_pool__.ctl.shutdown(cancel_futures=True)
|
||||||
|
self.__admin_pool__.monitor_type = False
|
||||||
|
self.__admin_pool__ = Pool(admin_pool_num)
|
||||||
|
self.__user_pool__ = Pool(user_pool_num)
|
||||||
|
self.submit_sys_task(self.__admin_pool__.__thread_monitor__)
|
||||||
|
self.submit_sys_task(self.__user_pool__.__thread_monitor__)
|
Loading…
Reference in New Issue
Block a user