QChatGPT/pkg/qqbot/manager.py
2023-01-02 11:33:26 +08:00

221 lines
7.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import asyncio
import json
import os
import threading
import mirai.models.bus
import openai.error
from mirai import At, GroupMessage, MessageEvent, Mirai, Plain, StrangerMessage, WebSocketAdapter, HTTPAdapter, \
FriendMessage, Image
from mirai.models.bus import ModelEventBus
from mirai.models.message import Quote
import config
import pkg.openai.session
import pkg.openai.manager
from func_timeout import FunctionTimedOut
import logging
import pkg.qqbot.filter
import pkg.qqbot.process as processor
import pkg.utils.context
# Run a callable in the background.
def go(func, args=()):
    """Start ``func(*args)`` on a fresh daemon thread and return without waiting.

    :param func: the callable to execute.
    :param args: positional arguments passed to ``func``.
    """
    threading.Thread(target=func, args=args, daemon=True).start()
# Check whether a message satisfies the generic response rules (reply without an @-mention).
def check_response_rule(text: str, rules=None) -> tuple:
    """Decide whether ``text`` should trigger a reply even though the bot was not mentioned.

    :param text: the message text to test.
    :param rules: mapping with optional ``'prefix'`` (iterable of str) and
        ``'regexp'`` (iterable of regex patterns) entries. When omitted, the
        current ``config.response_rules`` is used; a config without that
        attribute (or with an empty/None value) disables matching.
    :return: ``(True, text_without_prefix)`` for the first matching prefix,
        ``(True, text)`` for the first pattern that ``re.match``-es ``text``,
        otherwise ``(False, '')``.
    """
    import re

    if rules is None:
        rules = getattr(config, 'response_rules', None)
    if not rules:
        return False, ''

    # Prefix rules are checked first; only the leading occurrence is removed.
    for prefix in rules.get('prefix') or ():
        if text.startswith(prefix):
            return True, text[len(prefix):]

    # re.match anchors at the start of the text; on a hit the full text is kept.
    for pattern in rules.get('regexp') or ():
        if re.match(pattern, text):
            return True, text

    return False, ''
# Controls QQ message input/output: owns the Mirai bot, hands incoming messages to
# the processor and sends the replies back.
class QQBotManager:
    retry = 3
    bot = None
    reply_filter = None

    def __init__(self, mirai_http_api_config: dict, timeout: int = 60, retry: int = 3, first_time_init=True):
        """Set up the reply filter, the Mirai bot and the message handlers.

        :param mirai_http_api_config: Mirai HTTP API connection settings, used only
            when ``first_time_init`` is true (keys ``qq``, ``verifyKey``, ``host``,
            ``port`` and optional ``adapter``).
        :param timeout: stored on ``self.timeout``; not read by the code in this class.
        :param retry: maximum number of processing attempts per incoming message.
        :param first_time_init: when false (hot reload), reuse the bot of the manager
            currently registered in ``pkg.utils.context`` instead of connecting again.
        """
        self.timeout = timeout
        self.retry = retry

        # Load the sensitive-word list only when the file exists and the filter is
        # enabled. A config without ``sensitive_word_filter`` counts as disabled,
        # matching how the other optional config keys in this module are read.
        if os.path.exists("sensitive.json") and getattr(config, 'sensitive_word_filter', False):
            with open("sensitive.json", "r", encoding="utf-8") as f:
                self.reply_filter = pkg.qqbot.filter.ReplyFilter(json.load(f)['words'])
        else:
            self.reply_filter = pkg.qqbot.filter.ReplyFilter([])

        if first_time_init:
            self.first_time_init(mirai_http_api_config)
        else:
            # Hot reload: keep the already established bot connection.
            self.bot = pkg.utils.context.get_qqbot_manager().bot

        pkg.utils.context.set_qqbot_manager(self)

        # Caution: every event handler registered here must also be unsubscribed
        # in unsubscribe_all below.
        # Each handler hands the event to a daemon thread (via go) and returns at once.
        @self.bot.on(FriendMessage)
        async def on_friend_message(event: FriendMessage):
            go(self.on_person_message, (event,))

        @self.bot.on(StrangerMessage)
        async def on_stranger_message(event: StrangerMessage):
            go(self.on_person_message, (event,))

        @self.bot.on(GroupMessage)
        async def on_group_message(event: GroupMessage):
            go(self.on_group_message, (event,))

        def unsubscribe_all():
            """Remove the three handlers registered above from the bot's event bus."""
            assert isinstance(self.bot, Mirai)
            bus = self.bot.bus
            assert isinstance(bus, mirai.models.bus.ModelEventBus)
            bus.unsubscribe(FriendMessage, on_friend_message)
            bus.unsubscribe(StrangerMessage, on_stranger_message)
            bus.unsubscribe(GroupMessage, on_group_message)

        self.unsubscribe_all = unsubscribe_all

    def first_time_init(self, mirai_http_api_config: dict):
        """Create the Mirai bot; not run again after a hot reload.

        The ``adapter`` key selects ``WebSocketAdapter`` (the default when absent)
        or ``HTTPAdapter``; any other value raises ``Exception``.
        """
        adapter_name = mirai_http_api_config.get('adapter', "WebSocketAdapter")
        if adapter_name == "WebSocketAdapter":
            adapter_cls = WebSocketAdapter
        elif adapter_name == "HTTPAdapter":
            adapter_cls = HTTPAdapter
        else:
            raise Exception("未知的适配器类型")

        self.bot = Mirai(
            qq=mirai_http_api_config['qq'],
            adapter=adapter_cls(
                verify_key=mirai_http_api_config['verifyKey'],
                host=mirai_http_api_config['host'],
                port=mirai_http_api_config['port'],
            ),
        )

    def send(self, event, msg, check_quote=True):
        """Reply to ``event`` with ``msg``, blocking until the send coroutine finishes.

        The original message is quoted only when ``config.quote_origin`` exists and
        is truthy and ``check_quote`` is true.
        """
        quote = bool(getattr(config, "quote_origin", False) and check_quote)
        asyncio.run(self.bot.send(event, msg, quote=quote))

    def _process_with_retry(self, launcher_type: str, launcher_id: int, text: str, message_chain, sender_id: int):
        """Run ``processor.process_message`` with up to ``self.retry`` attempts.

        Each attempt that raises ``FunctionTimedOut`` releases the session's response
        lock before the next try. If no attempt succeeds, the lock is released once
        more, the admin is notified and a timeout error reply is returned.
        Other exceptions propagate to the caller.

        :return: whatever ``process_message`` returned, or ``["[bot]err:请求超时"]``.
        """
        session_name = '{}_{}'.format(launcher_type, launcher_id)
        for _ in range(self.retry):
            try:
                return processor.process_message(launcher_type, launcher_id, text, message_chain, sender_id)
            except FunctionTimedOut:
                pkg.openai.session.get_session(session_name).release_response_lock()

        pkg.openai.session.get_session(session_name).release_response_lock()
        self.notify_admin("{} 请求超时".format(session_name))
        return ["[bot]err:请求超时"]

    def on_person_message(self, event: MessageEvent):
        """Handle a private (friend or stranger) message and send the reply, if any.

        Messages sent by the bot itself and messages containing an image are ignored.
        """
        if event.sender.id == self.bot.qq or Image in event.message_chain:
            return
        reply = self._process_with_retry('person', event.sender.id, str(event.message_chain),
                                         event.message_chain, event.sender.id)
        if reply:
            self.send(event, reply, check_quote=False)

    def on_group_message(self, event: GroupMessage):
        """Handle a group message and send the reply, if any.

        Messages containing an image are ignored. A message that @-mentions the bot
        is processed with the mention removed (this mutates ``event.message_chain``);
        otherwise it is processed only if it satisfies ``check_response_rule``.
        """
        if Image in event.message_chain:
            return

        at_bot = At(self.bot.qq)
        if at_bot in event.message_chain:
            event.message_chain.remove(at_bot)
            text = str(event.message_chain).strip()
        else:
            matched, text = check_response_rule(str(event.message_chain).strip())
            if not matched:
                return
            text = text.strip()

        reply = self._process_with_retry('group', event.group.id, text, event.message_chain, event.sender.id)
        if reply:
            self.send(event, reply)

    def notify_admin(self, message: str):
        """Send ``message`` (prefixed with "[bot]") to ``config.admin_qq``.

        Does nothing when ``admin_qq`` is missing or 0. The send coroutine is run
        with ``asyncio.run`` on a separate thread, so this returns without waiting
        for delivery.
        """
        if getattr(config, "admin_qq", 0) != 0:
            logging.info("通知管理员:{}".format(message))
            send_task = self.bot.send_friend_message(config.admin_qq, "[bot]{}".format(message))
            threading.Thread(target=asyncio.run, args=(send_task,)).start()