refactor: 请求处理控制流基础架构

This commit is contained in:
RockChinQ 2024-01-26 15:51:49 +08:00
parent a064c24f60
commit 8d084427d2
55 changed files with 1430 additions and 146 deletions

View File

@ -1,54 +0,0 @@
import logging
import os
import sys
import time
import colorlog
log_colors_config = {
'DEBUG': 'green', # cyan white
'INFO': 'white',
'WARNING': 'yellow',
'ERROR': 'red',
'CRITICAL': 'cyan',
}
async def init_logging() -> logging.Logger:
level = logging.INFO
if 'DEBUG' in os.environ and os.environ['DEBUG'] in ['true', '1']:
level = logging.DEBUG
log_file_name = "logs/qcg-%s.log" % time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime())
qcg_logger = logging.getLogger("qcg")
qcg_logger.setLevel(level)
log_handlers: logging.Handler = [
logging.StreamHandler(sys.stdout),
logging.FileHandler(log_file_name)
]
for handler in log_handlers:
handler.setLevel(level)
handler.setFormatter(
colorlog.ColoredFormatter(
fmt="[%(asctime)s.%(msecs)03d] %(pathname)s (%(lineno)d) - [%(levelname)s] :\n%(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
log_colors=log_colors_config
)
)
qcg_logger.addHandler(handler)
logging.basicConfig(level=level, # 设置日志输出格式
format="[DEPR][%(asctime)s.%(msecs)03d] %(pathname)s (%(lineno)d) - [%(levelname)s] :\n%(message)s",
# 日志输出的格式
# -8表示占位符让输出左对齐输出长度都为8位
datefmt="%Y-%m-%d %H:%M:%S" # 时间输出的格式
)
return qcg_logger

View File

@ -1,6 +1,7 @@
from __future__ import annotations
import logging
import asyncio
from ..qqbot import manager as qqbot_mgr
from ..openai import manager as openai_mgr
@ -8,6 +9,8 @@ from ..config import manager as config_mgr
from ..database import manager as database_mgr
from ..utils.center import v2 as center_mgr
from ..plugin import host as plugin_host
from . import pool, controller
from ..pipeline import stagemgr
class Application:
@ -23,16 +26,24 @@ class Application:
ctr_mgr: center_mgr.V2CenterAPI = None
query_pool: pool.QueryPool = None
ctrl: controller.Controller = None
stage_mgr: stagemgr.StageManager = None
logger: logging.Logger = None
def __init__(self):
pass
async def initialize(self):
await self.im_mgr.initialize()
async def run(self):
# TODO make it async
plugin_host.initialize_plugins()
await self.im_mgr.run()
tasks = [
asyncio.create_task(self.im_mgr.run()),
asyncio.create_task(self.ctrl.run())
]
await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)

View File

@ -3,12 +3,15 @@ from __future__ import print_function
import os
import sys
from . import files
from . import deps
from . import log
from . import config
from .bootutils import files
from .bootutils import deps
from .bootutils import log
from .bootutils import config
from . import app
from . import pool
from . import controller
from ..pipeline import stagemgr
from ..audit import identifier
from ..database import manager as db_mgr
from ..openai import manager as llm_mgr
@ -86,6 +89,8 @@ async def make_app() -> app.Application:
ap.cfg_mgr = cfg_mgr
ap.tips_mgr = tips_mgr
ap.query_pool = pool.QueryPool()
center_v2_api = center_v2.V2CenterAPI(
basic_info={
"host_id": identifier.identifier['host_id'],
@ -111,8 +116,16 @@ async def make_app() -> app.Application:
llm_session.load_sessions()
im_mgr_inst = im_mgr.QQBotManager(first_time_init=True, ap=ap)
await im_mgr_inst.initialize()
ap.im_mgr = im_mgr_inst
stage_mgr = stagemgr.StageManager(ap)
await stage_mgr.initialize()
ap.stage_mgr = stage_mgr
ctrl = controller.Controller(ap)
ap.ctrl = ctrl
# TODO make it async
plugin_host.load_plugins()
# plugin_host.initialize_plugins()
@ -122,5 +135,4 @@ async def make_app() -> app.Application:
async def main():
app_inst = await make_app()
await app_inst.initialize()
await app_inst.run()

View File

@ -1,7 +1,7 @@
import json
from ..config import manager as config_mgr
from ..config.impls import pymodule
from ...config import manager as config_mgr
from ...config.impls import pymodule
load_python_module_config = config_mgr.load_python_module_config

56
pkg/core/bootutils/log.py Normal file
View File

@ -0,0 +1,56 @@
import logging
import os
import sys
import time
import colorlog
log_colors_config = {
"DEBUG": "green", # cyan white
"INFO": "white",
"WARNING": "yellow",
"ERROR": "red",
"CRITICAL": "cyan",
}
async def init_logging() -> logging.Logger:
level = logging.INFO
if "DEBUG" in os.environ and os.environ["DEBUG"] in ["true", "1"]:
level = logging.DEBUG
log_file_name = "logs/qcg-%s.log" % time.strftime(
"%Y-%m-%d-%H-%M-%S", time.localtime()
)
qcg_logger = logging.getLogger("qcg")
qcg_logger.setLevel(level)
color_formatter = colorlog.ColoredFormatter(
fmt="%(log_color)s[%(asctime)s.%(msecs)03d] %(pathname)s (%(lineno)d) - [%(levelname)s] :\n %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
log_colors=log_colors_config,
)
stream_handler = logging.StreamHandler(sys.stdout)
log_handlers: logging.Handler = [stream_handler, logging.FileHandler(log_file_name)]
for handler in log_handlers:
handler.setLevel(level)
handler.setFormatter(color_formatter)
qcg_logger.addHandler(handler)
logging.basicConfig(
level=logging.INFO, # 设置日志输出格式
format="[DEPR][%(asctime)s.%(msecs)03d] %(pathname)s (%(lineno)d) - [%(levelname)s] :\n%(message)s",
# 日志输出的格式
# -8表示占位符让输出左对齐输出长度都为8位
datefmt="%Y-%m-%d %H:%M:%S", # 时间输出的格式
handlers=[logging.NullHandler()],
)
return qcg_logger

View File

84
pkg/core/controller.py Normal file
View File

@ -0,0 +1,84 @@
from __future__ import annotations
import asyncio
import traceback
from . import app, entities
from ..pipeline import entities as pipeline_entities
DEFAULT_QUERY_CONCURRENCY = 10
class Controller:
"""总控制器
"""
ap: app.Application
semaphore: asyncio.Semaphore = None
"""请求并发控制信号量"""
def __init__(self, ap: app.Application):
self.ap = ap
self.semaphore = asyncio.Semaphore(DEFAULT_QUERY_CONCURRENCY)
async def consumer(self):
"""事件处理循环
"""
while True:
selected_query: entities.Query = None
# 取请求
async with self.ap.query_pool:
queries: list[entities.Query] = self.ap.query_pool.queries
if queries:
selected_query = queries.pop(0) # FCFS
else:
await self.ap.query_pool.condition.wait()
continue
if selected_query:
async def _process_query(selected_query):
async with self.semaphore:
await self.process_query(selected_query)
asyncio.create_task(_process_query(selected_query))
async def process_query(self, query: entities.Query):
"""处理请求
"""
self.ap.logger.debug(f"Processing query {query}")
try:
for stage_container in self.ap.stage_mgr.stage_containers:
res = await stage_container.inst.process(query, stage_container.inst_name)
self.ap.logger.debug(f"Stage {stage_container.inst_name} res {res}")
if res.user_notice:
await self.ap.im_mgr.send(
query.message_event,
res.user_notice
)
if res.debug_notice:
self.ap.logger.debug(res.debug_notice)
if res.console_notice:
self.ap.logger.info(res.console_notice)
if res.result_type == pipeline_entities.ResultType.INTERRUPT:
self.ap.logger.debug(f"Stage {stage_container.inst_name} interrupted query {query}")
break
elif res.result_type == pipeline_entities.ResultType.CONTINUE:
query = res.new_query
continue
except Exception as e:
self.ap.logger.error(f"处理请求时出错 {query}: {e}")
traceback.print_exc()
finally:
self.ap.logger.debug(f"Query {query} processed")
async def run(self):
"""运行控制器
"""
await self.consumer()

41
pkg/core/entities.py Normal file
View File

@ -0,0 +1,41 @@
from __future__ import annotations
import enum
import typing
import pydantic
import mirai
class LauncherTypes(enum.Enum):
PERSON = 'person'
"""私聊"""
GROUP = 'group'
"""群聊"""
class Query(pydantic.BaseModel):
"""一次请求的信息封装"""
query_id: int
"""请求ID"""
launcher_type: LauncherTypes
"""会话类型"""
launcher_id: int
"""会话ID"""
sender_id: int
"""发送者ID"""
message_event: mirai.MessageEvent
"""事件"""
message_chain: mirai.MessageChain
"""消息链"""
resp_message_chain: typing.Optional[mirai.MessageChain] = None
"""回复消息链"""

52
pkg/core/pool.py Normal file
View File

@ -0,0 +1,52 @@
from __future__ import annotations
import asyncio
import mirai
from . import entities
class QueryPool:
query_id_counter: int = 0
pool_lock: asyncio.Lock
queries: list[entities.Query]
condition: asyncio.Condition
def __init__(self):
self.query_id_counter = 0
self.pool_lock = asyncio.Lock()
self.queries = []
self.condition = asyncio.Condition(self.pool_lock)
async def add_query(
self,
launcher_type: entities.LauncherTypes,
launcher_id: int,
sender_id: int,
message_event: mirai.MessageEvent,
message_chain: mirai.MessageChain
) -> entities.Query:
async with self.condition:
query = entities.Query(
query_id=self.query_id_counter,
launcher_type=launcher_type,
launcher_id=launcher_id,
sender_id=sender_id,
message_event=message_event,
message_chain=message_chain
)
self.queries.append(query)
self.query_id_counter += 1
self.condition.notify_all()
async def __aenter__(self):
await self.pool_lock.acquire()
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
self.pool_lock.release()

View File

@ -10,7 +10,7 @@ from ..utils import context
from ..audit import gatherer
from ..openai import modelmgr
from ..openai.api import model as api_model
from ..boot import app
from ..core import app
class OpenAIInteract:

0
pkg/pipeline/__init__.py Normal file
View File

View File

View File

@ -0,0 +1,76 @@
from __future__ import annotations
import re
from .. import stage, entities, stagemgr
from ...core import entities as core_entities
from ...config import manager as cfg_mgr
@stage.stage_class('BanSessionCheckStage')
class BanSessionCheckStage(stage.PipelineStage):
banlist_mgr: cfg_mgr.ConfigManager
async def initialize(self):
self.banlist_mgr = await cfg_mgr.load_python_module_config(
"banlist.py",
"res/templates/banlist-template.py"
)
async def process(
self,
query: core_entities.Query,
stage_inst_name: str
) -> entities.StageProcessResult:
if not self.banlist_mgr.data['enable']:
return entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE,
new_query=query
)
result = False
if query.launcher_type == 'group':
if not self.banlist_mgr.data['enable_group']: # 未启用群聊响应
result = True
# 检查是否显式声明发起人QQ要被person忽略
elif query.sender_id in self.banlist_mgr.data['person']:
result = True
else:
for group_rule in self.banlist_mgr.data['group']:
if type(group_rule) == int:
if group_rule == query.launcher_id:
result = True
elif type(group_rule) == str:
if group_rule.startswith('!'):
reg_str = group_rule[1:]
if re.match(reg_str, str(query.launcher_id)):
result = False
break
else:
if re.match(group_rule, str(query.launcher_id)):
result = True
elif query.launcher_type == 'person':
if not self.banlist_mgr.data['enable_private']:
result = True
else:
for person_rule in self.banlist_mgr.data['person']:
if type(person_rule) == int:
if person_rule == query.launcher_id:
result = True
elif type(person_rule) == str:
if person_rule.startswith('!'):
reg_str = person_rule[1:]
if re.match(reg_str, str(query.launcher_id)):
result = False
break
else:
if re.match(person_rule, str(query.launcher_id)):
result = True
return entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE if not result else entities.ResultType.INTERRUPT,
new_query=query,
debug_notice=f'根据禁用列表忽略消息: {query.launcher_type}_{query.launcher_id}' if result else ''
)

View File

View File

@ -0,0 +1,128 @@
from __future__ import annotations
import mirai
from ...core import app
from .. import stage, entities, stagemgr
from ...core import entities as core_entities
from ...config import manager as cfg_mgr
from . import filter, entities as filter_entities
from .filters import cntignore, banwords, baiduexamine
@stage.stage_class('PostContentFilterStage')
@stage.stage_class('PreContentFilterStage')
class ContentFilterStage(stage.PipelineStage):
filter_chain: list[filter.ContentFilter]
def __init__(self, ap: app.Application):
self.filter_chain = []
super().__init__(ap)
async def initialize(self):
self.filter_chain.append(cntignore.ContentIgnore(self.ap))
if self.ap.cfg_mgr.data['sensitive_word_filter']:
self.filter_chain.append(banwords.BanWordFilter(self.ap))
if self.ap.cfg_mgr.data['baidu_check']:
self.filter_chain.append(baiduexamine.BaiduCloudExamine(self.ap))
for filter in self.filter_chain:
await filter.initialize()
async def _pre_process(
self,
message: str,
query: core_entities.Query,
) -> entities.StageProcessResult:
"""请求llm前处理消息
只要有一个不通过就不放行只放行 PASS 的消息
"""
if not self.ap.cfg_mgr.data['income_msg_check']:
return entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE,
new_query=query
)
else:
for filter in self.filter_chain:
if filter_entities.EnableStage.PRE in filter.enable_stages:
result = await filter.process(message)
if result.level in [
filter_entities.ResultLevel.BLOCK,
filter_entities.ResultLevel.MASKED
]:
return entities.StageProcessResult(
result_type=entities.ResultType.INTERRUPT,
new_query=query,
user_notice=result.user_notice,
console_notice=result.console_notice
)
elif result.level == filter_entities.ResultLevel.PASS: # 传到下一个
message = result.replacement
query.message_chain = mirai.MessageChain(
mirai.Plain(message)
)
return entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE,
new_query=query
)
async def _post_process(
self,
message: str,
query: core_entities.Query,
) -> entities.StageProcessResult:
"""请求llm后处理响应
只要是 PASS 或者 MASKED 的就通过此 filter将其 replacement 设置为message进入下一个 filter
"""
for filter in self.filter_chain:
if filter_entities.EnableStage.POST in filter.enable_stages:
result = await filter.process(message)
if result.level == filter_entities.ResultLevel.BLOCK:
return entities.StageProcessResult(
result_type=entities.ResultType.INTERRUPT,
new_query=query,
user_notice=result.user_notice,
console_notice=result.console_notice
)
elif result.level in [
filter_entities.ResultLevel.PASS,
filter_entities.ResultLevel.MASKED
]:
message = result.replacement
query.message_chain = mirai.MessageChain(
mirai.Plain(message)
)
return entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE,
new_query=query
)
async def process(
self,
query: core_entities.Query,
stage_inst_name: str
) -> entities.StageProcessResult:
"""处理
"""
if stage_inst_name == 'PreContentFilterStage':
return await self._pre_process(
str(query.message_chain).strip(),
query
)
elif stage_inst_name == 'PostContentFilterStage':
return await self._post_process(
str(query.message_chain).strip(),
query
)
else:
raise ValueError(f'未知的 stage_inst_name: {stage_inst_name}')

View File

@ -0,0 +1,64 @@
import typing
import enum
import pydantic
class ResultLevel(enum.Enum):
"""结果等级"""
PASS = enum.auto()
"""通过"""
WARN = enum.auto()
"""警告"""
MASKED = enum.auto()
"""已掩去"""
BLOCK = enum.auto()
"""阻止"""
class EnableStage(enum.Enum):
"""启用阶段"""
PRE = enum.auto()
"""预处理"""
POST = enum.auto()
"""后处理"""
class FilterResult(pydantic.BaseModel):
level: ResultLevel
replacement: str
"""替换后的消息"""
user_notice: str
"""不通过时,用户提示消息"""
console_notice: str
"""不通过时,控制台提示消息"""
class ManagerResultLevel(enum.Enum):
"""处理器结果等级"""
CONTINUE = enum.auto()
"""继续"""
INTERRUPT = enum.auto()
"""中断"""
class FilterManagerResult(pydantic.BaseModel):
level: ManagerResultLevel
replacement: str
"""替换后的消息"""
user_notice: str
"""用户提示消息"""
console_notice: str
"""控制台提示消息"""

View File

@ -0,0 +1,34 @@
# 内容过滤器的抽象类
from __future__ import annotations
import abc
from ...core import app
from . import entities
class ContentFilter(metaclass=abc.ABCMeta):
ap: app.Application
def __init__(self, ap: app.Application):
self.ap = ap
@property
def enable_stages(self):
"""启用的阶段
"""
return [
entities.EnableStage.PRE,
entities.EnableStage.POST
]
async def initialize(self):
"""初始化过滤器
"""
pass
@abc.abstractmethod
async def process(self, message: str) -> entities.FilterResult:
"""处理消息
"""
raise NotImplementedError

View File

@ -0,0 +1,61 @@
from __future__ import annotations
import aiohttp
from .. import entities
from .. import filter as filter_model
BAIDU_EXAMINE_URL = "https://aip.baidubce.com/rest/2.0/solution/v1/text_censor/v2/user_defined?access_token={}"
BAIDU_EXAMINE_TOKEN_URL = "https://aip.baidubce.com/oauth/2.0/token"
class BaiduCloudExamine(filter_model.ContentFilter):
"""百度云内容审核"""
async def _get_token(self) -> str:
async with aiohttp.ClientSession() as session:
async with session.post(
BAIDU_EXAMINE_TOKEN_URL,
params={
"grant_type": "client_credentials",
"client_id": self.ap.cfg_mgr.data['baidu_api_key'],
"client_secret": self.ap.cfg_mgr.data['baidu_secret_key']
}
) as resp:
return (await resp.json())['access_token']
async def process(self, message: str) -> entities.FilterResult:
async with aiohttp.ClientSession() as session:
async with session.post(
BAIDU_EXAMINE_URL.format(await self._get_token()),
headers={'Content-Type': 'application/x-www-form-urlencoded', 'Accept': 'application/json'},
data=f"text={message}".encode('utf-8')
) as resp:
result = await resp.json()
if "error_code" in result:
return entities.FilterResult(
level=entities.ResultLevel.BLOCK,
replacement=message,
user_notice='',
console_notice=f"百度云判定出错,错误信息:{result['error_msg']}"
)
else:
conclusion = result["conclusion"]
if conclusion in ("合规"):
return entities.FilterResult(
level=entities.ResultLevel.PASS,
replacement=message,
user_notice='',
console_notice=f"百度云判定结果:{conclusion}"
)
else:
return entities.FilterResult(
level=entities.ResultLevel.BLOCK,
replacement=message,
user_notice=self.ap.cfg_mgr.data['inappropriate_message_tips'],
console_notice=f"百度云判定结果:{conclusion}"
)

View File

@ -0,0 +1,44 @@
from __future__ import annotations
import re
from .. import filter as filter_model
from .. import entities
from ....config import manager as cfg_mgr
class BanWordFilter(filter_model.ContentFilter):
"""根据内容禁言"""
sensitive: cfg_mgr.ConfigManager
async def initialize(self):
self.sensitive = await cfg_mgr.load_json_config(
"sensitive.json",
"res/templates/sensitive-template.json"
)
async def process(self, message: str) -> entities.FilterResult:
found = False
for word in self.sensitive.data['words']:
match = re.findall(word, message)
if len(match) > 0:
found = True
for i in range(len(match)):
if self.sensitive.data['mask_word'] == "":
message = message.replace(
match[i], self.sensitive.data['mask'] * len(match[i])
)
else:
message = message.replace(
match[i], self.sensitive.data['mask_word']
)
return entities.FilterResult(
level=entities.ResultLevel.MASKED if found else entities.ResultLevel.PASS,
replacement=message,
user_notice='[bot] 消息中存在不合适的内容, 请更换措辞' if found else '',
console_notice=''
)

View File

@ -0,0 +1,43 @@
from __future__ import annotations
import re
from .. import entities
from .. import filter as filter_model
class ContentIgnore(filter_model.ContentFilter):
"""根据内容忽略消息"""
@property
def enable_stages(self):
return [
entities.EnableStage.PRE,
]
async def process(self, message: str) -> entities.FilterResult:
if 'prefix' in self.ap.cfg_mgr.data['ignore_rules']:
for rule in self.ap.cfg_mgr.data['ignore_rules']['prefix']:
if message.startswith(rule):
return entities.FilterResult(
level=entities.ResultLevel.BLOCK,
replacement='',
user_notice='',
console_notice='根据 ignore_rules 中的 prefix 规则,忽略消息'
)
if 'regexp' in self.ap.cfg_mgr.data['ignore_rules']:
for rule in self.ap.cfg_mgr.data['ignore_rules']['regexp']:
if re.search(rule, message):
return entities.FilterResult(
level=entities.ResultLevel.BLOCK,
replacement='',
user_notice='',
console_notice='根据 ignore_rules 中的 regexp 规则,忽略消息'
)
return entities.FilterResult(
level=entities.ResultLevel.PASS,
replacement=message,
user_notice='',
console_notice=''
)

38
pkg/pipeline/entities.py Normal file
View File

@ -0,0 +1,38 @@
from __future__ import annotations
import enum
import typing
import pydantic
import mirai
import mirai.models.message as mirai_message
from ..core import entities
class ResultType(enum.Enum):
CONTINUE = enum.auto()
"""继续流水线"""
INTERRUPT = enum.auto()
"""中断流水线"""
class StageProcessResult(pydantic.BaseModel):
result_type: ResultType
new_query: entities.Query
user_notice: typing.Optional[typing.Union[str, list[mirai_message.MessageComponent], mirai.MessageChain, None]] = []
"""只要设置了就会发送给用户"""
admin_notice: typing.Optional[typing.Union[str, list[mirai_message.MessageComponent], mirai.MessageChain, None]] = []
"""只要设置了就会发送给管理员"""
console_notice: typing.Optional[str] = ''
"""只要设置了就会输出到控制台"""
debug_notice: typing.Optional[str] = ''

View File

View File

@ -0,0 +1,57 @@
from __future__ import annotations
import os
import traceback
from PIL import Image, ImageDraw, ImageFont
from mirai.models.message import MessageComponent, Plain, MessageChain
from ...core import app
from . import strategy
from .strategies import image, forward
from .. import stage, entities, stagemgr
from ...core import entities as core_entities
from ...config import manager as cfg_mgr
@stage.stage_class("LongTextProcessStage")
class LongTextProcessStage(stage.PipelineStage):
strategy_impl: strategy.LongTextStrategy
async def initialize(self):
config = self.ap.cfg_mgr.data
if self.ap.cfg_mgr.data['blob_message_strategy'] == 'image':
use_font = config['font_path']
try:
# 检查是否存在
if not os.path.exists(use_font):
# 若是windows系统使用微软雅黑
if os.name == "nt":
use_font = "C:/Windows/Fonts/msyh.ttc"
if not os.path.exists(use_font):
self.ap.logger.warn("未找到字体文件且无法使用Windows自带字体更换为转发消息组件以发送长消息您可以在config.py中调整相关设置。")
config['blob_message_strategy'] = "forward"
else:
self.ap.logger.info("使用Windows自带字体" + use_font)
self.ap.cfg_mgr.data['font_path'] = use_font
else:
self.ap.logger.warn("未找到字体文件且无法使用系统自带字体更换为转发消息组件以发送长消息您可以在config.py中调整相关设置。")
self.ap.cfg_mgr.data['blob_message_strategy'] = "forward"
except:
traceback.print_exc()
self.ap.logger.error("加载字体文件失败({})更换为转发消息组件以发送长消息您可以在config.py中调整相关设置。".format(use_font))
self.ap.cfg_mgr.data['blob_message_strategy'] = "forward"
if self.ap.cfg_mgr.data['blob_message_strategy'] == 'image':
self.strategy_impl = image.Text2ImageStrategy(self.ap)
elif self.ap.cfg_mgr.data['blob_message_strategy'] == 'forward':
self.strategy_impl = forward.ForwardComponentStrategy(self.ap)
await self.strategy_impl.initialize()
async def process(self, query: core_entities.Query, stage_inst_name: str) -> entities.StageProcessResult:
if len(str(query.resp_message_chain)) > self.ap.cfg_mgr.data['blob_message_threshold']:
query.message_chain = MessageChain(await self.strategy_impl.process(str(query.resp_message_chain)))
return entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE,
new_query=query
)

View File

@ -0,0 +1,62 @@
# 转发消息组件
from __future__ import annotations
import typing
from mirai.models import MessageChain
from mirai.models.message import MessageComponent, ForwardMessageNode
from mirai.models.base import MiraiBaseModel
from .. import strategy as strategy_model
class ForwardMessageDiaplay(MiraiBaseModel):
title: str = "群聊的聊天记录"
brief: str = "[聊天记录]"
source: str = "聊天记录"
preview: typing.List[str] = []
summary: str = "查看x条转发消息"
class Forward(MessageComponent):
"""合并转发。"""
type: str = "Forward"
"""消息组件类型。"""
display: ForwardMessageDiaplay
"""显示信息"""
node_list: typing.List[ForwardMessageNode]
"""转发消息节点列表。"""
def __init__(self, *args, **kwargs):
if len(args) == 1:
self.node_list = args[0]
super().__init__(**kwargs)
super().__init__(*args, **kwargs)
def __str__(self):
return '[聊天记录]'
class ForwardComponentStrategy(strategy_model.LongTextStrategy):
async def process(self, message: str) -> list[MessageComponent]:
display = ForwardMessageDiaplay(
title="群聊的聊天记录",
brief="[聊天记录]",
source="聊天记录",
preview=["QQ用户: "+message],
summary="查看1条转发消息"
)
node_list = [
ForwardMessageNode(
sender_id=self.ap.im_mgr.bot_account_id,
sender_name='QQ用户',
message_chain=MessageChain([message])
)
]
forward = Forward(
display=display,
node_list=node_list
)
return [forward]

View File

@ -0,0 +1,197 @@
from __future__ import annotations
import typing
import os
import base64
import time
import re
from PIL import Image, ImageDraw, ImageFont
from mirai.models import MessageChain, Image as ImageComponent
from mirai.models.message import MessageComponent
from .. import strategy as strategy_model
class Text2ImageStrategy(strategy_model.LongTextStrategy):
text_render_font: ImageFont.FreeTypeFont
async def initialize(self):
self.text_render_font = ImageFont.truetype(self.ap.cfg_mgr.data['font_path'], 32, encoding="utf-8")
async def process(self, message: str) -> list[MessageComponent]:
img_path = self.text_to_image(
text_str=message,
save_as='temp/{}.png'.format(int(time.time()))
)
compressed_path, size = self.compress_image(
img_path,
outfile="temp/{}_compressed.png".format(int(time.time()))
)
with open(compressed_path, 'rb') as f:
img = f.read()
b64 = base64.b64encode(img)
# 删除图片
os.remove(img_path)
if os.path.exists(compressed_path):
os.remove(compressed_path)
return [
ImageComponent(
base64=b64.decode('utf-8'),
)
]
def indexNumber(self, path=''):
"""
查找字符串中数字所在串中的位置
:param path:目标字符串
:return:<class 'list'>: <class 'list'>: [['1', 16], ['2', 35], ['1', 51]]
"""
kv = []
nums = []
beforeDatas = re.findall('[\d]+', path)
for num in beforeDatas:
indexV = []
times = path.count(num)
if times > 1:
if num not in nums:
indexs = re.finditer(num, path)
for index in indexs:
iV = []
i = index.span()[0]
iV.append(num)
iV.append(i)
kv.append(iV)
nums.append(num)
else:
index = path.find(num)
indexV.append(num)
indexV.append(index)
kv.append(indexV)
# 根据数字位置排序
indexSort = []
resultIndex = []
for vi in kv:
indexSort.append(vi[1])
indexSort.sort()
for i in indexSort:
for v in kv:
if i == v[1]:
resultIndex.append(v)
return resultIndex
def get_size(self, file):
# 获取文件大小:KB
size = os.path.getsize(file)
return size / 1024
def get_outfile(self, infile, outfile):
if outfile:
return outfile
dir, suffix = os.path.splitext(infile)
outfile = '{}-out{}'.format(dir, suffix)
return outfile
def compress_image(self, infile, outfile='', kb=100, step=20, quality=90):
"""不改变图片尺寸压缩到指定大小
:param infile: 压缩源文件
:param outfile: 压缩文件保存地址
:param mb: 压缩目标,KB
:param step: 每次调整的压缩比率
:param quality: 初始压缩比率
:return: 压缩文件地址压缩文件大小
"""
o_size = self.get_size(infile)
if o_size <= kb:
return infile, o_size
outfile = self.get_outfile(infile, outfile)
while o_size > kb:
im = Image.open(infile)
im.save(outfile, quality=quality)
if quality - step < 0:
break
quality -= step
o_size = self.get_size(outfile)
return outfile, self.get_size(outfile)
def text_to_image(self, text_str: str, save_as="temp.png", width=800):
text_str = text_str.replace("\t", " ")
# 分行
lines = text_str.split('\n')
# 计算并分割
final_lines = []
text_width = width-80
self.ap.logger.debug("lines: {}, text_width: {}".format(lines, text_width))
for line in lines:
# 如果长了就分割
line_width = self.text_render_font.getlength(line)
self.ap.logger.debug("line_width: {}".format(line_width))
if line_width < text_width:
final_lines.append(line)
continue
else:
rest_text = line
while True:
# 分割最前面的一行
point = int(len(rest_text) * (text_width / line_width))
# 检查断点是否在数字中间
numbers = self.indexNumber(rest_text)
for number in numbers:
if number[1] < point < number[1] + len(number[0]) and number[1] != 0:
point = number[1]
break
final_lines.append(rest_text[:point])
rest_text = rest_text[point:]
line_width = self.text_render_font.getlength(rest_text)
if line_width < text_width:
final_lines.append(rest_text)
break
else:
continue
# 准备画布
img = Image.new('RGBA', (width, max(280, len(final_lines) * 35 + 65)), (255, 255, 255, 255))
draw = ImageDraw.Draw(img, mode='RGBA')
self.ap.logger.debug("正在绘制图片...")
# 绘制正文
line_number = 0
offset_x = 20
offset_y = 30
for final_line in final_lines:
draw.text((offset_x, offset_y + 35 * line_number), final_line, fill=(0, 0, 0), font=self.text_render_font)
# 遍历此行,检查是否有emoji
idx_in_line = 0
for ch in final_line:
# 检查字符占位宽
char_code = ord(ch)
if char_code >= 127:
idx_in_line += 1
else:
idx_in_line += 0.5
line_number += 1
self.ap.logger.debug("正在保存图片...")
img.save(save_as)
return save_as

View File

@ -0,0 +1,22 @@
from __future__ import annotations
import abc
import typing
import mirai
from mirai.models.message import MessageComponent
from ...core import app
class LongTextStrategy(metaclass=abc.ABCMeta):
ap: app.Application
def __init__(self, ap: app.Application):
self.ap = ap
async def initialize(self):
pass
@abc.abstractmethod
async def process(self, message: str) -> list[MessageComponent]:
return []

View File

View File

@ -0,0 +1,9 @@
import pydantic
import mirai
class RuleJudgeResult(pydantic.BaseModel):
matching: bool = False
replacement: mirai.MessageChain = None

View File

@ -0,0 +1,62 @@
from __future__ import annotations
import mirai
from ...core import app
from . import entities as rule_entities, rule
from .rules import atbot, prefix, regexp, random
from .. import stage, entities, stagemgr
from ...core import entities as core_entities
from ...config import manager as cfg_mgr
@stage.stage_class("GroupRespondRuleCheckStage")
class GroupRespondRuleCheckStage(stage.PipelineStage):
"""群组响应规则检查器
"""
rule_matchers: list[rule.GroupRespondRule]
async def initialize(self):
"""初始化检查器
"""
self.rule_matchers = [
atbot.AtBotRule(self.ap),
prefix.PrefixRule(self.ap),
regexp.RegExpRule(self.ap),
random.RandomRespRule(self.ap),
]
for rule_matcher in self.rule_matchers:
await rule_matcher.initialize()
async def process(self, query: core_entities.Query, stage_inst_name: str) -> entities.StageProcessResult:
if query.launcher_type != 'group':
return entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE,
new_query=query
)
rules = self.ap.cfg_mgr.data['response_rules']
use_rule = rules['default']
if str(query.launcher_id) in use_rule:
use_rule = use_rule[str(query.launcher_id)]
for rule_matcher in self.rule_matchers: # 任意一个匹配就放行
res = await rule_matcher.match(str(query.message_chain), query.message_chain, use_rule)
if res.matching:
query.message_chain = res.replacement
return entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE,
new_query=query,
)
return entities.StageProcessResult(
result_type=entities.ResultType.INTERRUPT,
new_query=query
)

View File

@ -0,0 +1,31 @@
from __future__ import annotations
import abc
import mirai
from ...core import app
from . import entities
class GroupRespondRule(metaclass=abc.ABCMeta):
"""群组响应规则的抽象类
"""
ap: app.Application
def __init__(self, ap: app.Application):
self.ap = ap
async def initialize(self):
pass
@abc.abstractmethod
async def match(
self,
message_text: str,
message_chain: mirai.MessageChain,
rule_dict: dict
) -> entities.RuleJudgeResult:
"""判断消息是否匹配规则
"""
raise NotImplementedError

View File

View File

@ -0,0 +1,28 @@
from __future__ import annotations
import mirai
from .. import rule as rule_model
from .. import entities
class AtBotRule(rule_model.GroupRespondRule):
async def match(
self,
message_text: str,
message_chain: mirai.MessageChain,
rule_dict: dict
) -> entities.RuleJudgeResult:
if message_chain.has(mirai.At(self.ap.im_mgr.bot_account_id)) and rule_dict['at']:
message_chain.remove(mirai.At(self.ap.im_mgr.bot_account_id))
return entities.RuleJudgeResult(
matching=True,
replacement=message_chain,
)
return entities.RuleJudgeResult(
matching=False,
replacement = message_chain
)

View File

@ -0,0 +1,29 @@
import mirai
from .. import rule as rule_model
from .. import entities
class PrefixRule(rule_model.GroupRespondRule):
async def match(
self,
message_text: str,
message_chain: mirai.MessageChain,
rule_dict: dict
) -> entities.RuleJudgeResult:
prefixes = rule_dict['prefix']
for prefix in prefixes:
if message_text.startswith(prefix):
return entities.RuleJudgeResult(
matching=True,
replacement=mirai.MessageChain([
mirai.Plain(message_text[len(prefix):])
]),
)
return entities.RuleJudgeResult(
matching=False,
replacement=message_chain
)

View File

@ -0,0 +1,22 @@
import random
import mirai
from .. import rule as rule_model
from .. import entities
class RandomRespRule(rule_model.GroupRespondRule):
async def match(
self,
message_text: str,
message_chain: mirai.MessageChain,
rule_dict: dict
) -> entities.RuleJudgeResult:
random_rate = rule_dict['random_rate']
return entities.RuleJudgeResult(
matching=random.random() < random_rate,
replacement=message_chain
)

View File

@ -0,0 +1,31 @@
import re
import mirai
from .. import rule as rule_model
from .. import entities
class RegExpRule(rule_model.GroupRespondRule):
async def match(
self,
message_text: str,
message_chain: mirai.MessageChain,
rule_dict: dict
) -> entities.RuleJudgeResult:
regexps = rule_dict['regexp']
for regexp in regexps:
match = re.match(regexp, message_text)
if match:
return entities.RuleJudgeResult(
matching=True,
replacement=message_chain,
)
return entities.RuleJudgeResult(
matching=False,
replacement=message_chain
)

43
pkg/pipeline/stage.py Normal file
View File

@ -0,0 +1,43 @@
from __future__ import annotations
import abc
from ..core import app, entities as core_entities
from . import entities
_stage_classes: dict[str, PipelineStage] = {}
def stage_class(name: str):
def decorator(cls):
_stage_classes[name] = cls
return cls
return decorator
class PipelineStage(metaclass=abc.ABCMeta):
"""流水线阶段
"""
ap: app.Application
def __init__(self, ap: app.Application):
self.ap = ap
async def initialize(self):
"""初始化
"""
pass
@abc.abstractmethod
async def process(
self,
query: core_entities.Query,
stage_inst_name: str,
) -> entities.StageProcessResult:
"""处理
"""
raise NotImplementedError

47
pkg/pipeline/stagemgr.py Normal file
View File

@ -0,0 +1,47 @@
from __future__ import annotations
import pydantic
from ..core import app
from . import stage
from .resprule import resprule
from .bansess import bansess
from .cntfilter import cntfilter
from .longtext import longtext
class StageInstContainer():
"""阶段实例容器
"""
inst_name: str
inst: stage.PipelineStage
def __init__(self, inst_name: str, inst: stage.PipelineStage):
self.inst_name = inst_name
self.inst = inst
class StageManager:
ap: app.Application
stage_containers: list[StageInstContainer]
def __init__(self, ap: app.Application):
self.ap = ap
self.stage_containers = []
async def initialize(self):
"""初始化
"""
for name, cls in stage._stage_classes.items():
self.stage_containers.append(StageInstContainer(
inst_name=name,
inst=cls(self.ap)
))
for stage_containers in self.stage_containers:
await stage_containers.inst.initialize()

View File

@ -3,7 +3,7 @@
from __future__ import annotations
import re
from ...boot import app
from ...core import app
from ...config import manager as cfg_mgr

View File

@ -1,6 +1,6 @@
from __future__ import annotations
from ...boot import app
from ...core import app
from . import entities
from . import filter
from .filters import cntignore, banwords, baiduexamine

View File

@ -2,7 +2,7 @@
from __future__ import annotations
import abc
from ...boot import app
from ...core import app
from . import entities

View File

@ -5,7 +5,7 @@ import traceback
from PIL import Image, ImageDraw, ImageFont
from mirai.models.message import MessageComponent, Plain
from ...boot import app
from ...core import app
from . import strategy
from .strategies import image, forward

View File

@ -5,7 +5,7 @@ import typing
import mirai
from mirai.models.message import MessageComponent
from ...boot import app
from ...core import app
class LongTextStrategy(metaclass=abc.ABCMeta):

View File

@ -24,7 +24,7 @@ from .cntfilter import cntfilter
from .longtext import longtext
from .ratelim import ratelim
from ..boot import app
from ..core import app, entities as core_entities
# 控制QQ消息输入输出的类
@ -91,45 +91,29 @@ class QQBotManager:
# Caution: 注册新的事件处理器之后请务必在unsubscribe_all中编写相应的取消订阅代码
async def on_friend_message(event: FriendMessage):
async def friend_message_handler():
# 触发事件
args = {
"launcher_type": "person",
"launcher_id": event.sender.id,
"sender_id": event.sender.id,
"message_chain": event.message_chain,
}
plugin_event = plugin_host.emit(plugin_models.PersonMessageReceived, **args)
await self.ap.query_pool.add_query(
launcher_type=core_entities.LauncherTypes.PERSON,
launcher_id=event.sender.id,
sender_id=event.sender.id,
message_event=event,
message_chain=event.message_chain
)
if plugin_event.is_prevented_default():
return
await self.on_person_message(event)
asyncio.create_task(friend_message_handler())
self.adapter.register_listener(
FriendMessage,
on_friend_message
)
async def on_stranger_message(event: StrangerMessage):
await self.ap.query_pool.add_query(
launcher_type=core_entities.LauncherTypes.PERSON,
launcher_id=event.sender.id,
sender_id=event.sender.id,
message_event=event,
message_chain=event.message_chain
)
async def stranger_message_handler():
# 触发事件
args = {
"launcher_type": "person",
"launcher_id": event.sender.id,
"sender_id": event.sender.id,
"message_chain": event.message_chain,
}
plugin_event = plugin_host.emit(plugin_models.PersonMessageReceived, **args)
if plugin_event.is_prevented_default():
return
await self.on_person_message(event)
asyncio.create_task(stranger_message_handler())
# nakuru不区分好友和陌生人故仅为yirimirai注册陌生人事件
if config['msg_source_adapter'] == 'yirimirai':
self.adapter.register_listener(
@ -139,49 +123,19 @@ class QQBotManager:
async def on_group_message(event: GroupMessage):
async def group_message_handler(event: GroupMessage):
# 触发事件
args = {
"launcher_type": "group",
"launcher_id": event.group.id,
"sender_id": event.sender.id,
"message_chain": event.message_chain,
}
plugin_event = plugin_host.emit(plugin_models.GroupMessageReceived, **args)
if plugin_event.is_prevented_default():
return
await self.on_group_message(event)
asyncio.create_task(group_message_handler(event))
await self.ap.query_pool.add_query(
launcher_type=core_entities.LauncherTypes.GROUP,
launcher_id=event.group.id,
sender_id=event.sender.id,
message_event=event,
message_chain=event.message_chain
)
self.adapter.register_listener(
GroupMessage,
on_group_message
)
def unsubscribe_all():
"""取消所有订阅
用于在热重载流程中卸载所有事件处理器
"""
self.adapter.unregister_listener(
FriendMessage,
on_friend_message
)
if config['msg_source_adapter'] == 'yirimirai':
self.adapter.unregister_listener(
StrangerMessage,
on_stranger_message
)
self.adapter.unregister_listener(
GroupMessage,
on_group_message
)
self.unsubscribe_all = unsubscribe_all
async def send(self, event, msg, check_quote=True, check_at_sender=True):
config = context.get_config_manager().data

View File

@ -14,7 +14,7 @@ from ..utils import context
from ..plugin import host as plugin_host
from ..plugin import models as plugin_models
import tips as tips_custom
from ..boot import app
from ..core import app
from .cntfilter import entities
processing = []

View File

@ -1,7 +1,7 @@
from __future__ import annotations
import abc
from ...boot import app
from ...core import app
class ReteLimitAlgo(metaclass=abc.ABCMeta):

View File

@ -2,7 +2,7 @@ from __future__ import annotations
from . import algo
from .algos import fixedwin
from ...boot import app
from ...core import app
class RateLimiter:

View File

@ -2,7 +2,7 @@ from __future__ import annotations
import mirai
from ...boot import app
from ...core import app
from . import entities, rule
from .rules import atbot, prefix, regexp, random

View File

@ -3,7 +3,7 @@ import abc
import mirai
from ...boot import app
from ...core import app
from . import entities

View File

@ -1,6 +1,6 @@
import asyncio
from pkg.boot import boot
from pkg.core import boot
if __name__ == '__main__':