refactor: 完成MessageSource适配器解耦

This commit is contained in:
Rock Chin 2023-04-21 17:51:58 +08:00
parent 016391c976
commit 160086feb9
7 changed files with 185 additions and 150 deletions

View File

@ -1,6 +1,12 @@
# 配置文件: 注释里标[必需]的参数必须修改, 其他参数根据需要修改, 但请勿删除 # 配置文件: 注释里标[必需]的参数必须修改, 其他参数根据需要修改, 但请勿删除
import logging import logging
# 消息处理协议适配器
# 目前支持以下适配器:
# - "yirimirai": YiriMirai框架适配器, 请填写mirai_http_api_config
# - "nonebot2": NoneBot2框架适配器, 请填写nonebot2_config
msg_source_adapter = "yirimirai"
# [必需] Mirai的配置 # [必需] Mirai的配置
# 请到配置mirai的步骤中的教程查看每个字段的信息 # 请到配置mirai的步骤中的教程查看每个字段的信息
# adapter: 选择适配器目前支持HTTPAdapter和WebSocketAdapter # adapter: 选择适配器目前支持HTTPAdapter和WebSocketAdapter
@ -18,6 +24,9 @@ mirai_http_api_config = {
"qq": 1234567890 "qq": 1234567890
} }
# NoneBot2的配置
nonebot2_config = {}
# [必需] OpenAI的配置 # [必需] OpenAI的配置
# api_key: OpenAI的API Key # api_key: OpenAI的API Key
# http_proxy: 请求OpenAI时使用的代理None为不使用https和socks5暂不能使用 # http_proxy: 请求OpenAI时使用的代理None为不使用https和socks5暂不能使用

View File

@ -208,7 +208,7 @@ def start(first_time_init=False):
def run_bot_wrapper(): def run_bot_wrapper():
global known_exception_caught global known_exception_caught
try: try:
qqbot.bot.run() qqbot.adapter.run_sync()
except TypeError as e: except TypeError as e:
if str(e).__contains__("argument 'debug'"): if str(e).__contains__("argument 'debug'"):
logging.error( logging.error(

View File

@ -11,6 +11,7 @@ import traceback
import pkg.utils.context as context import pkg.utils.context as context
import pkg.plugin.switch as switch import pkg.plugin.switch as switch
import pkg.plugin.settings as settings import pkg.plugin.settings as settings
import pkg.qqbot.adapter as msadapter
from mirai import Mirai from mirai import Mirai
@ -276,6 +277,10 @@ class PluginHost:
"""获取机器人对象""" """获取机器人对象"""
return context.get_qqbot_manager().bot return context.get_qqbot_manager().bot
def get_bot_adapter(self) -> msadapter.MessageSourceAdapter:
"""获取消息源适配器"""
return context.get_qqbot_manager().adapter
def send_person_message(self, person, message): def send_person_message(self, person, message):
"""发送私聊消息""" """发送私聊消息"""
asyncio.run(self.get_bot().send_friend_message(person, message)) asyncio.run(self.get_bot().send_friend_message(person, message))

View File

@ -3,6 +3,7 @@ import typing
import mirai import mirai
class MessageSourceAdapter: class MessageSourceAdapter:
def __init__(self, config: dict): def __init__(self, config: dict):
pass pass
@ -25,16 +26,22 @@ class MessageSourceAdapter:
def reply_message( def reply_message(
self, self,
message_source: mirai.MessageEvent, message_source: mirai.MessageEvent,
message: mirai.MessageChain message: mirai.MessageChain,
quote_origin: bool = False
): ):
"""回复消息 """回复消息
Args: Args:
message_source (mirai.MessageEvent): YiriMirai消息源事件 message_source (mirai.MessageEvent): YiriMirai消息源事件
message (mirai.MessageChain): YiriMirai库的消息链 message (mirai.MessageChain): YiriMirai库的消息链
quote_origin (bool, optional): 是否引用原消息. Defaults to False.
""" """
raise NotImplementedError raise NotImplementedError
def is_muted(self, group_id: int) -> bool:
"""获取账号是否在指定群被禁言"""
raise NotImplementedError
def register_listener( def register_listener(
self, self,
event_type: typing.Type[mirai.Event], event_type: typing.Type[mirai.Event],

View File

@ -3,9 +3,9 @@ import json
import os import os
import threading import threading
import mirai.models.bus
from mirai import At, GroupMessage, MessageEvent, Mirai, StrangerMessage, WebSocketAdapter, HTTPAdapter, \ from mirai import At, GroupMessage, MessageEvent, Mirai, StrangerMessage, WebSocketAdapter, HTTPAdapter, \
FriendMessage, Image FriendMessage, Image, MessageChain, Plain
from func_timeout import func_set_timeout from func_timeout import func_set_timeout
import pkg.openai.session import pkg.openai.session
@ -21,6 +21,8 @@ import pkg.plugin.host as plugin_host
import pkg.plugin.models as plugin_models import pkg.plugin.models as plugin_models
import tips as tips_custom import tips as tips_custom
import pkg.qqbot.adapter as msadapter
# 检查消息是否符合泛响应匹配机制 # 检查消息是否符合泛响应匹配机制
def check_response_rule(text: str): def check_response_rule(text: str):
@ -64,7 +66,9 @@ def random_responding():
class QQBotManager: class QQBotManager:
retry = 3 retry = 3
bot: Mirai = None adapter: msadapter.MessageSourceAdapter = None
bot_account_id: int = 0
reply_filter = None reply_filter = None
@ -80,6 +84,119 @@ class QQBotManager:
self.timeout = config.process_message_timeout self.timeout = config.process_message_timeout
self.retry = config.retry_times self.retry = config.retry_times
# 由于YiriMirai的bot对象是单例的且shutdown方法暂时无法使用
# 故只在第一次初始化时创建bot对象重载之后使用原bot对象
# 因此bot的配置不支持热重载
if first_time_init:
if config.msg_source_adapter == 'yirimirai':
from pkg.qqbot.sources.yirimirai import YiriMiraiAdapter
self.bot_account_id = config.mirai_http_api_config['qq']
self.adapter = YiriMiraiAdapter(mirai_http_api_config)
elif config.msg_source_adapter == 'nonebot2':
pass
else:
self.adapter = pkg.utils.context.get_qqbot_manager().adapter
pkg.utils.context.set_qqbot_manager(self)
# 注册诸事件
# Caution: 注册新的事件处理器之后请务必在unsubscribe_all中编写相应的取消订阅代码
def on_friend_message(event: FriendMessage):
def friend_message_handler():
# 触发事件
args = {
"launcher_type": "person",
"launcher_id": event.sender.id,
"sender_id": event.sender.id,
"message_chain": event.message_chain,
}
plugin_event = plugin_host.emit(plugin_models.PersonMessageReceived, **args)
if plugin_event.is_prevented_default():
return
self.on_person_message(event)
pkg.utils.context.get_thread_ctl().submit_user_task(
friend_message_handler,
)
self.adapter.register_listener(
FriendMessage,
on_friend_message
)
def on_stranger_message(event: StrangerMessage):
def stranger_message_handler():
# 触发事件
args = {
"launcher_type": "person",
"launcher_id": event.sender.id,
"sender_id": event.sender.id,
"message_chain": event.message_chain,
}
plugin_event = plugin_host.emit(plugin_models.PersonMessageReceived, **args)
if plugin_event.is_prevented_default():
return
self.on_person_message(event)
pkg.utils.context.get_thread_ctl().submit_user_task(
stranger_message_handler,
)
self.adapter.register_listener(
StrangerMessage,
on_stranger_message
)
def on_group_message(event: GroupMessage):
def group_message_handler(event: GroupMessage):
# 触发事件
args = {
"launcher_type": "group",
"launcher_id": event.group.id,
"sender_id": event.sender.id,
"message_chain": event.message_chain,
}
plugin_event = plugin_host.emit(plugin_models.GroupMessageReceived, **args)
if plugin_event.is_prevented_default():
return
self.on_group_message(event)
pkg.utils.context.get_thread_ctl().submit_user_task(
group_message_handler,
event
)
self.adapter.register_listener(
GroupMessage,
on_group_message
)
def unsubscribe_all():
"""取消所有订阅
用于在热重载流程中卸载所有事件处理器
"""
self.adapter.unregister_listener(
FriendMessage,
on_friend_message
)
self.adapter.unregister_listener(
StrangerMessage,
on_stranger_message
)
self.adapter.unregister_listener(
GroupMessage,
on_group_message
)
self.unsubscribe_all = unsubscribe_all
# 加载禁用列表 # 加载禁用列表
if os.path.exists("banlist.py"): if os.path.exists("banlist.py"):
import banlist import banlist
@ -102,139 +219,20 @@ class QQBotManager:
else: else:
self.reply_filter = pkg.qqbot.filter.ReplyFilter([]) self.reply_filter = pkg.qqbot.filter.ReplyFilter([])
# 由于YiriMirai的bot对象是单例的且shutdown方法暂时无法使用
# 故只在第一次初始化时创建bot对象重载之后使用原bot对象
# 因此bot的配置不支持热重载
if first_time_init:
self.first_time_init(mirai_http_api_config)
else:
self.bot = pkg.utils.context.get_qqbot_manager().bot
pkg.utils.context.set_qqbot_manager(self)
# Caution: 注册新的事件处理器之后请务必在unsubscribe_all中编写相应的取消订阅代码
@self.bot.on(FriendMessage)
async def on_friend_message(event: FriendMessage):
def friend_message_handler(event: FriendMessage):
# 触发事件
args = {
"launcher_type": "person",
"launcher_id": event.sender.id,
"sender_id": event.sender.id,
"message_chain": event.message_chain,
}
plugin_event = plugin_host.emit(plugin_models.PersonMessageReceived, **args)
if plugin_event.is_prevented_default():
return
self.on_person_message(event)
pkg.utils.context.get_thread_ctl().submit_user_task(
friend_message_handler,
event
)
@self.bot.on(StrangerMessage)
async def on_stranger_message(event: StrangerMessage):
def stranger_message_handler(event: StrangerMessage):
# 触发事件
args = {
"launcher_type": "person",
"launcher_id": event.sender.id,
"sender_id": event.sender.id,
"message_chain": event.message_chain,
}
plugin_event = plugin_host.emit(plugin_models.PersonMessageReceived, **args)
if plugin_event.is_prevented_default():
return
self.on_person_message(event)
pkg.utils.context.get_thread_ctl().submit_user_task(
stranger_message_handler,
event
)
@self.bot.on(GroupMessage)
async def on_group_message(event: GroupMessage):
def group_message_handler(event: GroupMessage):
# 触发事件
args = {
"launcher_type": "group",
"launcher_id": event.group.id,
"sender_id": event.sender.id,
"message_chain": event.message_chain,
}
plugin_event = plugin_host.emit(plugin_models.GroupMessageReceived, **args)
if plugin_event.is_prevented_default():
return
self.on_group_message(event)
pkg.utils.context.get_thread_ctl().submit_user_task(
group_message_handler,
event
)
def unsubscribe_all():
"""取消所有订阅
用于在热重载流程中卸载所有事件处理器
"""
assert isinstance(self.bot, Mirai)
bus = self.bot.bus
assert isinstance(bus, mirai.models.bus.ModelEventBus)
bus.unsubscribe(FriendMessage, on_friend_message)
bus.unsubscribe(StrangerMessage, on_stranger_message)
bus.unsubscribe(GroupMessage, on_group_message)
self.unsubscribe_all = unsubscribe_all
def first_time_init(self, mirai_http_api_config: dict):
"""热重载后不再运行此函数"""
if 'adapter' not in mirai_http_api_config or mirai_http_api_config['adapter'] == "WebSocketAdapter":
bot = Mirai(
qq=mirai_http_api_config['qq'],
adapter=WebSocketAdapter(
verify_key=mirai_http_api_config['verifyKey'],
host=mirai_http_api_config['host'],
port=mirai_http_api_config['port']
)
)
elif mirai_http_api_config['adapter'] == "HTTPAdapter":
bot = Mirai(
qq=mirai_http_api_config['qq'],
adapter=HTTPAdapter(
verify_key=mirai_http_api_config['verifyKey'],
host=mirai_http_api_config['host'],
port=mirai_http_api_config['port']
)
)
else:
raise Exception("未知的适配器类型")
self.bot = bot
def send(self, event, msg, check_quote=True): def send(self, event, msg, check_quote=True):
config = pkg.utils.context.get_config() config = pkg.utils.context.get_config()
asyncio.run( self.adapter.reply_message(
self.bot.send(event, msg, quote=True if config.quote_origin and check_quote else False)) event,
msg,
quote_origin=True if config.quote_origin and check_quote else False
)
# 私聊消息处理 # 私聊消息处理
def on_person_message(self, event: MessageEvent): def on_person_message(self, event: MessageEvent):
import config import config
reply = '' reply = ''
if event.sender.id == self.bot.qq: if event.sender.id == self.bot_account_id:
pass pass
else: else:
if Image in event.message_chain: if Image in event.message_chain:
@ -277,8 +275,8 @@ class QQBotManager:
def process(text=None) -> str: def process(text=None) -> str:
replys = "" replys = ""
if At(self.bot.qq) in event.message_chain: if At(self.bot_account_id) in event.message_chain:
event.message_chain.remove(At(self.bot.qq)) event.message_chain.remove(At(self.bot_account_id))
# 超时则重试,重试超过次数则放弃 # 超时则重试,重试超过次数则放弃
failed = 0 failed = 0
@ -312,7 +310,7 @@ class QQBotManager:
if Image in event.message_chain: if Image in event.message_chain:
pass pass
else: else:
if At(self.bot.qq) in event.message_chain and response_at(): if At(self.bot_account_id) in event.message_chain and response_at():
# 直接调用 # 直接调用
reply = process() reply = process()
else: else:
@ -334,22 +332,33 @@ class QQBotManager:
if config.admin_qq != 0 and config.admin_qq != []: if config.admin_qq != 0 and config.admin_qq != []:
logging.info("通知管理员:{}".format(message)) logging.info("通知管理员:{}".format(message))
if type(config.admin_qq) == int: if type(config.admin_qq) == int:
send_task = self.bot.send_friend_message(config.admin_qq, "[bot]{}".format(message)) self.adapter.send_message(
threading.Thread(target=asyncio.run, args=(send_task,)).start() "person",
config.admin_qq,
MessageChain([Plain("[bot]{}".format(message))])
)
else: else:
for adm in config.admin_qq: for adm in config.admin_qq:
send_task = self.bot.send_friend_message(adm, "[bot]{}".format(message)) self.adapter.send_message(
threading.Thread(target=asyncio.run, args=(send_task,)).start() "person",
adm,
MessageChain([Plain("[bot]{}".format(message))])
)
def notify_admin_message_chain(self, message): def notify_admin_message_chain(self, message):
config = pkg.utils.context.get_config() config = pkg.utils.context.get_config()
if config.admin_qq != 0 and config.admin_qq != []: if config.admin_qq != 0 and config.admin_qq != []:
logging.info("通知管理员:{}".format(message)) logging.info("通知管理员:{}".format(message))
if type(config.admin_qq) == int: if type(config.admin_qq) == int:
send_task = self.bot.send_friend_message(config.admin_qq, message) self.adapter.send_message(
threading.Thread(target=asyncio.run, args=(send_task,)).start() "person",
config.admin_qq,
message
)
else: else:
for adm in config.admin_qq: for adm in config.admin_qq:
send_task = self.bot.send_friend_message(adm, message) self.adapter.send_message(
threading.Thread(target=asyncio.run, args=(send_task,)).start() "person",
adm,
message
)

View File

@ -66,11 +66,8 @@ def process_message(launcher_type: str, launcher_id: int, text_message: str, mes
# 检查是否被禁言 # 检查是否被禁言
if launcher_type == 'group': if launcher_type == 'group':
result = mgr.bot.member_info(target=launcher_id, member_id=mgr.bot.qq).get() if mgr.adapter.is_muted(launcher_id):
result = asyncio.run(result) logging.info("机器人被禁言,跳过消息处理(group_{})".format(launcher_id))
if result.mute_time_remaining > 0:
logging.info("机器人被禁言,跳过消息处理(group_{},剩余{}s)".format(launcher_id,
result.mute_time_remaining))
return reply return reply
import config import config

View File

@ -60,7 +60,8 @@ class YiriMiraiAdapter(MessageSourceAdapter):
def reply_message( def reply_message(
self, self,
message_source: mirai.MessageEvent, message_source: mirai.MessageEvent,
message: mirai.MessageChain message: mirai.MessageChain,
quote_origin: bool = False
): ):
"""回复消息 """回复消息
@ -68,7 +69,14 @@ class YiriMiraiAdapter(MessageSourceAdapter):
message_source (mirai.MessageEvent): YiriMirai消息源事件 message_source (mirai.MessageEvent): YiriMirai消息源事件
message (mirai.MessageChain): YiriMirai库的消息链 message (mirai.MessageChain): YiriMirai库的消息链
""" """
asyncio.run(self.bot.send(message_source, message)) asyncio.run(self.bot.send(message_source, message, quote_origin))
def is_muted(self, group_id: int) -> bool:
result = self.bot.member_info(target=group_id, member_id=self.bot.qq).get()
result = asyncio.run(result)
if result.mute_time_remaining > 0:
return True
return False
def register_listener( def register_listener(
self, self,