mirror of
https://github.com/RockChinQ/QChatGPT.git
synced 2024-11-16 03:32:33 +08:00
refactor: 请求处理控制流基础架构
This commit is contained in:
parent
a064c24f60
commit
8d084427d2
|
@ -1,54 +0,0 @@
|
|||
import logging
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
|
||||
import colorlog
|
||||
|
||||
|
||||
log_colors_config = {
|
||||
'DEBUG': 'green', # cyan white
|
||||
'INFO': 'white',
|
||||
'WARNING': 'yellow',
|
||||
'ERROR': 'red',
|
||||
'CRITICAL': 'cyan',
|
||||
}
|
||||
|
||||
|
||||
async def init_logging() -> logging.Logger:
|
||||
|
||||
level = logging.INFO
|
||||
|
||||
if 'DEBUG' in os.environ and os.environ['DEBUG'] in ['true', '1']:
|
||||
level = logging.DEBUG
|
||||
|
||||
log_file_name = "logs/qcg-%s.log" % time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime())
|
||||
|
||||
qcg_logger = logging.getLogger("qcg")
|
||||
|
||||
qcg_logger.setLevel(level)
|
||||
|
||||
log_handlers: logging.Handler = [
|
||||
logging.StreamHandler(sys.stdout),
|
||||
logging.FileHandler(log_file_name)
|
||||
]
|
||||
|
||||
for handler in log_handlers:
|
||||
handler.setLevel(level)
|
||||
handler.setFormatter(
|
||||
colorlog.ColoredFormatter(
|
||||
fmt="[%(asctime)s.%(msecs)03d] %(pathname)s (%(lineno)d) - [%(levelname)s] :\n%(message)s",
|
||||
datefmt="%Y-%m-%d %H:%M:%S",
|
||||
log_colors=log_colors_config
|
||||
)
|
||||
)
|
||||
qcg_logger.addHandler(handler)
|
||||
|
||||
logging.basicConfig(level=level, # 设置日志输出格式
|
||||
format="[DEPR][%(asctime)s.%(msecs)03d] %(pathname)s (%(lineno)d) - [%(levelname)s] :\n%(message)s",
|
||||
# 日志输出的格式
|
||||
# -8表示占位符,让输出左对齐,输出长度都为8位
|
||||
datefmt="%Y-%m-%d %H:%M:%S" # 时间输出的格式
|
||||
)
|
||||
|
||||
return qcg_logger
|
|
@ -1,6 +1,7 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import asyncio
|
||||
|
||||
from ..qqbot import manager as qqbot_mgr
|
||||
from ..openai import manager as openai_mgr
|
||||
|
@ -8,6 +9,8 @@ from ..config import manager as config_mgr
|
|||
from ..database import manager as database_mgr
|
||||
from ..utils.center import v2 as center_mgr
|
||||
from ..plugin import host as plugin_host
|
||||
from . import pool, controller
|
||||
from ..pipeline import stagemgr
|
||||
|
||||
|
||||
class Application:
|
||||
|
@ -23,16 +26,24 @@ class Application:
|
|||
|
||||
ctr_mgr: center_mgr.V2CenterAPI = None
|
||||
|
||||
query_pool: pool.QueryPool = None
|
||||
|
||||
ctrl: controller.Controller = None
|
||||
|
||||
stage_mgr: stagemgr.StageManager = None
|
||||
|
||||
logger: logging.Logger = None
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
async def initialize(self):
|
||||
await self.im_mgr.initialize()
|
||||
|
||||
async def run(self):
|
||||
# TODO make it async
|
||||
plugin_host.initialize_plugins()
|
||||
|
||||
await self.im_mgr.run()
|
||||
tasks = [
|
||||
asyncio.create_task(self.im_mgr.run()),
|
||||
asyncio.create_task(self.ctrl.run())
|
||||
]
|
||||
|
||||
await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
|
|
@ -3,12 +3,15 @@ from __future__ import print_function
|
|||
import os
|
||||
import sys
|
||||
|
||||
from . import files
|
||||
from . import deps
|
||||
from . import log
|
||||
from . import config
|
||||
from .bootutils import files
|
||||
from .bootutils import deps
|
||||
from .bootutils import log
|
||||
from .bootutils import config
|
||||
|
||||
from . import app
|
||||
from . import pool
|
||||
from . import controller
|
||||
from ..pipeline import stagemgr
|
||||
from ..audit import identifier
|
||||
from ..database import manager as db_mgr
|
||||
from ..openai import manager as llm_mgr
|
||||
|
@ -86,6 +89,8 @@ async def make_app() -> app.Application:
|
|||
ap.cfg_mgr = cfg_mgr
|
||||
ap.tips_mgr = tips_mgr
|
||||
|
||||
ap.query_pool = pool.QueryPool()
|
||||
|
||||
center_v2_api = center_v2.V2CenterAPI(
|
||||
basic_info={
|
||||
"host_id": identifier.identifier['host_id'],
|
||||
|
@ -111,8 +116,16 @@ async def make_app() -> app.Application:
|
|||
llm_session.load_sessions()
|
||||
|
||||
im_mgr_inst = im_mgr.QQBotManager(first_time_init=True, ap=ap)
|
||||
await im_mgr_inst.initialize()
|
||||
ap.im_mgr = im_mgr_inst
|
||||
|
||||
stage_mgr = stagemgr.StageManager(ap)
|
||||
await stage_mgr.initialize()
|
||||
ap.stage_mgr = stage_mgr
|
||||
|
||||
ctrl = controller.Controller(ap)
|
||||
ap.ctrl = ctrl
|
||||
|
||||
# TODO make it async
|
||||
plugin_host.load_plugins()
|
||||
# plugin_host.initialize_plugins()
|
||||
|
@ -122,5 +135,4 @@ async def make_app() -> app.Application:
|
|||
|
||||
async def main():
|
||||
app_inst = await make_app()
|
||||
await app_inst.initialize()
|
||||
await app_inst.run()
|
|
@ -1,7 +1,7 @@
|
|||
import json
|
||||
|
||||
from ..config import manager as config_mgr
|
||||
from ..config.impls import pymodule
|
||||
from ...config import manager as config_mgr
|
||||
from ...config.impls import pymodule
|
||||
|
||||
|
||||
load_python_module_config = config_mgr.load_python_module_config
|
56
pkg/core/bootutils/log.py
Normal file
56
pkg/core/bootutils/log.py
Normal file
|
@ -0,0 +1,56 @@
|
|||
import logging
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
|
||||
import colorlog
|
||||
|
||||
|
||||
log_colors_config = {
|
||||
"DEBUG": "green", # cyan white
|
||||
"INFO": "white",
|
||||
"WARNING": "yellow",
|
||||
"ERROR": "red",
|
||||
"CRITICAL": "cyan",
|
||||
}
|
||||
|
||||
|
||||
async def init_logging() -> logging.Logger:
|
||||
level = logging.INFO
|
||||
|
||||
if "DEBUG" in os.environ and os.environ["DEBUG"] in ["true", "1"]:
|
||||
level = logging.DEBUG
|
||||
|
||||
log_file_name = "logs/qcg-%s.log" % time.strftime(
|
||||
"%Y-%m-%d-%H-%M-%S", time.localtime()
|
||||
)
|
||||
|
||||
qcg_logger = logging.getLogger("qcg")
|
||||
|
||||
qcg_logger.setLevel(level)
|
||||
|
||||
color_formatter = colorlog.ColoredFormatter(
|
||||
fmt="%(log_color)s[%(asctime)s.%(msecs)03d] %(pathname)s (%(lineno)d) - [%(levelname)s] :\n %(message)s",
|
||||
datefmt="%Y-%m-%d %H:%M:%S",
|
||||
log_colors=log_colors_config,
|
||||
)
|
||||
|
||||
stream_handler = logging.StreamHandler(sys.stdout)
|
||||
|
||||
log_handlers: logging.Handler = [stream_handler, logging.FileHandler(log_file_name)]
|
||||
|
||||
for handler in log_handlers:
|
||||
handler.setLevel(level)
|
||||
handler.setFormatter(color_formatter)
|
||||
qcg_logger.addHandler(handler)
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.INFO, # 设置日志输出格式
|
||||
format="[DEPR][%(asctime)s.%(msecs)03d] %(pathname)s (%(lineno)d) - [%(levelname)s] :\n%(message)s",
|
||||
# 日志输出的格式
|
||||
# -8表示占位符,让输出左对齐,输出长度都为8位
|
||||
datefmt="%Y-%m-%d %H:%M:%S", # 时间输出的格式
|
||||
handlers=[logging.NullHandler()],
|
||||
)
|
||||
|
||||
return qcg_logger
|
0
pkg/core/bootutils/misc.py
Normal file
0
pkg/core/bootutils/misc.py
Normal file
84
pkg/core/controller.py
Normal file
84
pkg/core/controller.py
Normal file
|
@ -0,0 +1,84 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import traceback
|
||||
|
||||
from . import app, entities
|
||||
from ..pipeline import entities as pipeline_entities
|
||||
|
||||
DEFAULT_QUERY_CONCURRENCY = 10
|
||||
|
||||
|
||||
class Controller:
|
||||
"""总控制器
|
||||
"""
|
||||
ap: app.Application
|
||||
|
||||
semaphore: asyncio.Semaphore = None
|
||||
"""请求并发控制信号量"""
|
||||
|
||||
def __init__(self, ap: app.Application):
|
||||
self.ap = ap
|
||||
self.semaphore = asyncio.Semaphore(DEFAULT_QUERY_CONCURRENCY)
|
||||
|
||||
async def consumer(self):
|
||||
"""事件处理循环
|
||||
"""
|
||||
while True:
|
||||
selected_query: entities.Query = None
|
||||
|
||||
# 取请求
|
||||
async with self.ap.query_pool:
|
||||
queries: list[entities.Query] = self.ap.query_pool.queries
|
||||
|
||||
if queries:
|
||||
selected_query = queries.pop(0) # FCFS
|
||||
else:
|
||||
await self.ap.query_pool.condition.wait()
|
||||
continue
|
||||
|
||||
if selected_query:
|
||||
async def _process_query(selected_query):
|
||||
async with self.semaphore:
|
||||
await self.process_query(selected_query)
|
||||
|
||||
asyncio.create_task(_process_query(selected_query))
|
||||
|
||||
async def process_query(self, query: entities.Query):
|
||||
"""处理请求
|
||||
"""
|
||||
self.ap.logger.debug(f"Processing query {query}")
|
||||
|
||||
try:
|
||||
for stage_container in self.ap.stage_mgr.stage_containers:
|
||||
res = await stage_container.inst.process(query, stage_container.inst_name)
|
||||
|
||||
self.ap.logger.debug(f"Stage {stage_container.inst_name} res {res}")
|
||||
|
||||
if res.user_notice:
|
||||
await self.ap.im_mgr.send(
|
||||
query.message_event,
|
||||
res.user_notice
|
||||
)
|
||||
if res.debug_notice:
|
||||
self.ap.logger.debug(res.debug_notice)
|
||||
if res.console_notice:
|
||||
self.ap.logger.info(res.console_notice)
|
||||
|
||||
if res.result_type == pipeline_entities.ResultType.INTERRUPT:
|
||||
self.ap.logger.debug(f"Stage {stage_container.inst_name} interrupted query {query}")
|
||||
break
|
||||
elif res.result_type == pipeline_entities.ResultType.CONTINUE:
|
||||
query = res.new_query
|
||||
continue
|
||||
|
||||
except Exception as e:
|
||||
self.ap.logger.error(f"处理请求时出错 {query}: {e}")
|
||||
traceback.print_exc()
|
||||
finally:
|
||||
self.ap.logger.debug(f"Query {query} processed")
|
||||
|
||||
async def run(self):
|
||||
"""运行控制器
|
||||
"""
|
||||
await self.consumer()
|
41
pkg/core/entities.py
Normal file
41
pkg/core/entities.py
Normal file
|
@ -0,0 +1,41 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import enum
|
||||
import typing
|
||||
|
||||
import pydantic
|
||||
import mirai
|
||||
|
||||
|
||||
class LauncherTypes(enum.Enum):
|
||||
|
||||
PERSON = 'person'
|
||||
"""私聊"""
|
||||
|
||||
GROUP = 'group'
|
||||
"""群聊"""
|
||||
|
||||
|
||||
class Query(pydantic.BaseModel):
|
||||
"""一次请求的信息封装"""
|
||||
|
||||
query_id: int
|
||||
"""请求ID"""
|
||||
|
||||
launcher_type: LauncherTypes
|
||||
"""会话类型"""
|
||||
|
||||
launcher_id: int
|
||||
"""会话ID"""
|
||||
|
||||
sender_id: int
|
||||
"""发送者ID"""
|
||||
|
||||
message_event: mirai.MessageEvent
|
||||
"""事件"""
|
||||
|
||||
message_chain: mirai.MessageChain
|
||||
"""消息链"""
|
||||
|
||||
resp_message_chain: typing.Optional[mirai.MessageChain] = None
|
||||
"""回复消息链"""
|
52
pkg/core/pool.py
Normal file
52
pkg/core/pool.py
Normal file
|
@ -0,0 +1,52 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
|
||||
import mirai
|
||||
|
||||
from . import entities
|
||||
|
||||
|
||||
class QueryPool:
|
||||
|
||||
query_id_counter: int = 0
|
||||
|
||||
pool_lock: asyncio.Lock
|
||||
|
||||
queries: list[entities.Query]
|
||||
|
||||
condition: asyncio.Condition
|
||||
|
||||
def __init__(self):
|
||||
self.query_id_counter = 0
|
||||
self.pool_lock = asyncio.Lock()
|
||||
self.queries = []
|
||||
self.condition = asyncio.Condition(self.pool_lock)
|
||||
|
||||
async def add_query(
|
||||
self,
|
||||
launcher_type: entities.LauncherTypes,
|
||||
launcher_id: int,
|
||||
sender_id: int,
|
||||
message_event: mirai.MessageEvent,
|
||||
message_chain: mirai.MessageChain
|
||||
) -> entities.Query:
|
||||
async with self.condition:
|
||||
query = entities.Query(
|
||||
query_id=self.query_id_counter,
|
||||
launcher_type=launcher_type,
|
||||
launcher_id=launcher_id,
|
||||
sender_id=sender_id,
|
||||
message_event=message_event,
|
||||
message_chain=message_chain
|
||||
)
|
||||
self.queries.append(query)
|
||||
self.query_id_counter += 1
|
||||
self.condition.notify_all()
|
||||
|
||||
async def __aenter__(self):
|
||||
await self.pool_lock.acquire()
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
self.pool_lock.release()
|
|
@ -10,7 +10,7 @@ from ..utils import context
|
|||
from ..audit import gatherer
|
||||
from ..openai import modelmgr
|
||||
from ..openai.api import model as api_model
|
||||
from ..boot import app
|
||||
from ..core import app
|
||||
|
||||
|
||||
class OpenAIInteract:
|
||||
|
|
0
pkg/pipeline/__init__.py
Normal file
0
pkg/pipeline/__init__.py
Normal file
0
pkg/pipeline/bansess/__init__.py
Normal file
0
pkg/pipeline/bansess/__init__.py
Normal file
76
pkg/pipeline/bansess/bansess.py
Normal file
76
pkg/pipeline/bansess/bansess.py
Normal file
|
@ -0,0 +1,76 @@
|
|||
from __future__ import annotations
|
||||
import re
|
||||
|
||||
from .. import stage, entities, stagemgr
|
||||
from ...core import entities as core_entities
|
||||
from ...config import manager as cfg_mgr
|
||||
|
||||
|
||||
@stage.stage_class('BanSessionCheckStage')
|
||||
class BanSessionCheckStage(stage.PipelineStage):
|
||||
|
||||
banlist_mgr: cfg_mgr.ConfigManager
|
||||
|
||||
async def initialize(self):
|
||||
self.banlist_mgr = await cfg_mgr.load_python_module_config(
|
||||
"banlist.py",
|
||||
"res/templates/banlist-template.py"
|
||||
)
|
||||
|
||||
async def process(
|
||||
self,
|
||||
query: core_entities.Query,
|
||||
stage_inst_name: str
|
||||
) -> entities.StageProcessResult:
|
||||
|
||||
if not self.banlist_mgr.data['enable']:
|
||||
return entities.StageProcessResult(
|
||||
result_type=entities.ResultType.CONTINUE,
|
||||
new_query=query
|
||||
)
|
||||
|
||||
result = False
|
||||
|
||||
if query.launcher_type == 'group':
|
||||
if not self.banlist_mgr.data['enable_group']: # 未启用群聊响应
|
||||
result = True
|
||||
# 检查是否显式声明发起人QQ要被person忽略
|
||||
elif query.sender_id in self.banlist_mgr.data['person']:
|
||||
result = True
|
||||
else:
|
||||
for group_rule in self.banlist_mgr.data['group']:
|
||||
if type(group_rule) == int:
|
||||
if group_rule == query.launcher_id:
|
||||
result = True
|
||||
elif type(group_rule) == str:
|
||||
if group_rule.startswith('!'):
|
||||
reg_str = group_rule[1:]
|
||||
if re.match(reg_str, str(query.launcher_id)):
|
||||
result = False
|
||||
break
|
||||
else:
|
||||
if re.match(group_rule, str(query.launcher_id)):
|
||||
result = True
|
||||
elif query.launcher_type == 'person':
|
||||
if not self.banlist_mgr.data['enable_private']:
|
||||
result = True
|
||||
else:
|
||||
for person_rule in self.banlist_mgr.data['person']:
|
||||
if type(person_rule) == int:
|
||||
if person_rule == query.launcher_id:
|
||||
result = True
|
||||
elif type(person_rule) == str:
|
||||
if person_rule.startswith('!'):
|
||||
reg_str = person_rule[1:]
|
||||
if re.match(reg_str, str(query.launcher_id)):
|
||||
result = False
|
||||
break
|
||||
else:
|
||||
if re.match(person_rule, str(query.launcher_id)):
|
||||
result = True
|
||||
|
||||
return entities.StageProcessResult(
|
||||
result_type=entities.ResultType.CONTINUE if not result else entities.ResultType.INTERRUPT,
|
||||
new_query=query,
|
||||
debug_notice=f'根据禁用列表忽略消息: {query.launcher_type}_{query.launcher_id}' if result else ''
|
||||
)
|
0
pkg/pipeline/cntfilter/__init__.py
Normal file
0
pkg/pipeline/cntfilter/__init__.py
Normal file
128
pkg/pipeline/cntfilter/cntfilter.py
Normal file
128
pkg/pipeline/cntfilter/cntfilter.py
Normal file
|
@ -0,0 +1,128 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import mirai
|
||||
|
||||
from ...core import app
|
||||
|
||||
from .. import stage, entities, stagemgr
|
||||
from ...core import entities as core_entities
|
||||
from ...config import manager as cfg_mgr
|
||||
from . import filter, entities as filter_entities
|
||||
from .filters import cntignore, banwords, baiduexamine
|
||||
|
||||
|
||||
@stage.stage_class('PostContentFilterStage')
|
||||
@stage.stage_class('PreContentFilterStage')
|
||||
class ContentFilterStage(stage.PipelineStage):
|
||||
|
||||
filter_chain: list[filter.ContentFilter]
|
||||
|
||||
def __init__(self, ap: app.Application):
|
||||
self.filter_chain = []
|
||||
super().__init__(ap)
|
||||
|
||||
async def initialize(self):
|
||||
self.filter_chain.append(cntignore.ContentIgnore(self.ap))
|
||||
|
||||
if self.ap.cfg_mgr.data['sensitive_word_filter']:
|
||||
self.filter_chain.append(banwords.BanWordFilter(self.ap))
|
||||
|
||||
if self.ap.cfg_mgr.data['baidu_check']:
|
||||
self.filter_chain.append(baiduexamine.BaiduCloudExamine(self.ap))
|
||||
|
||||
for filter in self.filter_chain:
|
||||
await filter.initialize()
|
||||
|
||||
async def _pre_process(
|
||||
self,
|
||||
message: str,
|
||||
query: core_entities.Query,
|
||||
) -> entities.StageProcessResult:
|
||||
"""请求llm前处理消息
|
||||
只要有一个不通过就不放行,只放行 PASS 的消息
|
||||
"""
|
||||
if not self.ap.cfg_mgr.data['income_msg_check']:
|
||||
return entities.StageProcessResult(
|
||||
result_type=entities.ResultType.CONTINUE,
|
||||
new_query=query
|
||||
)
|
||||
else:
|
||||
for filter in self.filter_chain:
|
||||
if filter_entities.EnableStage.PRE in filter.enable_stages:
|
||||
result = await filter.process(message)
|
||||
|
||||
if result.level in [
|
||||
filter_entities.ResultLevel.BLOCK,
|
||||
filter_entities.ResultLevel.MASKED
|
||||
]:
|
||||
return entities.StageProcessResult(
|
||||
result_type=entities.ResultType.INTERRUPT,
|
||||
new_query=query,
|
||||
user_notice=result.user_notice,
|
||||
console_notice=result.console_notice
|
||||
)
|
||||
elif result.level == filter_entities.ResultLevel.PASS: # 传到下一个
|
||||
message = result.replacement
|
||||
|
||||
query.message_chain = mirai.MessageChain(
|
||||
mirai.Plain(message)
|
||||
)
|
||||
|
||||
return entities.StageProcessResult(
|
||||
result_type=entities.ResultType.CONTINUE,
|
||||
new_query=query
|
||||
)
|
||||
|
||||
async def _post_process(
|
||||
self,
|
||||
message: str,
|
||||
query: core_entities.Query,
|
||||
) -> entities.StageProcessResult:
|
||||
"""请求llm后处理响应
|
||||
只要是 PASS 或者 MASKED 的就通过此 filter,将其 replacement 设置为message,进入下一个 filter
|
||||
"""
|
||||
for filter in self.filter_chain:
|
||||
if filter_entities.EnableStage.POST in filter.enable_stages:
|
||||
result = await filter.process(message)
|
||||
|
||||
if result.level == filter_entities.ResultLevel.BLOCK:
|
||||
return entities.StageProcessResult(
|
||||
result_type=entities.ResultType.INTERRUPT,
|
||||
new_query=query,
|
||||
user_notice=result.user_notice,
|
||||
console_notice=result.console_notice
|
||||
)
|
||||
elif result.level in [
|
||||
filter_entities.ResultLevel.PASS,
|
||||
filter_entities.ResultLevel.MASKED
|
||||
]:
|
||||
message = result.replacement
|
||||
|
||||
query.message_chain = mirai.MessageChain(
|
||||
mirai.Plain(message)
|
||||
)
|
||||
|
||||
return entities.StageProcessResult(
|
||||
result_type=entities.ResultType.CONTINUE,
|
||||
new_query=query
|
||||
)
|
||||
|
||||
async def process(
|
||||
self,
|
||||
query: core_entities.Query,
|
||||
stage_inst_name: str
|
||||
) -> entities.StageProcessResult:
|
||||
"""处理
|
||||
"""
|
||||
if stage_inst_name == 'PreContentFilterStage':
|
||||
return await self._pre_process(
|
||||
str(query.message_chain).strip(),
|
||||
query
|
||||
)
|
||||
elif stage_inst_name == 'PostContentFilterStage':
|
||||
return await self._post_process(
|
||||
str(query.message_chain).strip(),
|
||||
query
|
||||
)
|
||||
else:
|
||||
raise ValueError(f'未知的 stage_inst_name: {stage_inst_name}')
|
64
pkg/pipeline/cntfilter/entities.py
Normal file
64
pkg/pipeline/cntfilter/entities.py
Normal file
|
@ -0,0 +1,64 @@
|
|||
|
||||
import typing
|
||||
import enum
|
||||
|
||||
import pydantic
|
||||
|
||||
|
||||
class ResultLevel(enum.Enum):
|
||||
"""结果等级"""
|
||||
PASS = enum.auto()
|
||||
"""通过"""
|
||||
|
||||
WARN = enum.auto()
|
||||
"""警告"""
|
||||
|
||||
MASKED = enum.auto()
|
||||
"""已掩去"""
|
||||
|
||||
BLOCK = enum.auto()
|
||||
"""阻止"""
|
||||
|
||||
|
||||
class EnableStage(enum.Enum):
|
||||
"""启用阶段"""
|
||||
PRE = enum.auto()
|
||||
"""预处理"""
|
||||
|
||||
POST = enum.auto()
|
||||
"""后处理"""
|
||||
|
||||
|
||||
class FilterResult(pydantic.BaseModel):
|
||||
level: ResultLevel
|
||||
|
||||
replacement: str
|
||||
"""替换后的消息"""
|
||||
|
||||
user_notice: str
|
||||
"""不通过时,用户提示消息"""
|
||||
|
||||
console_notice: str
|
||||
"""不通过时,控制台提示消息"""
|
||||
|
||||
|
||||
class ManagerResultLevel(enum.Enum):
|
||||
"""处理器结果等级"""
|
||||
CONTINUE = enum.auto()
|
||||
"""继续"""
|
||||
|
||||
INTERRUPT = enum.auto()
|
||||
"""中断"""
|
||||
|
||||
class FilterManagerResult(pydantic.BaseModel):
|
||||
|
||||
level: ManagerResultLevel
|
||||
|
||||
replacement: str
|
||||
"""替换后的消息"""
|
||||
|
||||
user_notice: str
|
||||
"""用户提示消息"""
|
||||
|
||||
console_notice: str
|
||||
"""控制台提示消息"""
|
34
pkg/pipeline/cntfilter/filter.py
Normal file
34
pkg/pipeline/cntfilter/filter.py
Normal file
|
@ -0,0 +1,34 @@
|
|||
# 内容过滤器的抽象类
|
||||
from __future__ import annotations
|
||||
import abc
|
||||
|
||||
from ...core import app
|
||||
from . import entities
|
||||
|
||||
|
||||
class ContentFilter(metaclass=abc.ABCMeta):
|
||||
|
||||
ap: app.Application
|
||||
|
||||
def __init__(self, ap: app.Application):
|
||||
self.ap = ap
|
||||
|
||||
@property
|
||||
def enable_stages(self):
|
||||
"""启用的阶段
|
||||
"""
|
||||
return [
|
||||
entities.EnableStage.PRE,
|
||||
entities.EnableStage.POST
|
||||
]
|
||||
|
||||
async def initialize(self):
|
||||
"""初始化过滤器
|
||||
"""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
async def process(self, message: str) -> entities.FilterResult:
|
||||
"""处理消息
|
||||
"""
|
||||
raise NotImplementedError
|
0
pkg/pipeline/cntfilter/filters/__init__.py
Normal file
0
pkg/pipeline/cntfilter/filters/__init__.py
Normal file
61
pkg/pipeline/cntfilter/filters/baiduexamine.py
Normal file
61
pkg/pipeline/cntfilter/filters/baiduexamine.py
Normal file
|
@ -0,0 +1,61 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import aiohttp
|
||||
|
||||
from .. import entities
|
||||
from .. import filter as filter_model
|
||||
|
||||
|
||||
BAIDU_EXAMINE_URL = "https://aip.baidubce.com/rest/2.0/solution/v1/text_censor/v2/user_defined?access_token={}"
|
||||
BAIDU_EXAMINE_TOKEN_URL = "https://aip.baidubce.com/oauth/2.0/token"
|
||||
|
||||
|
||||
class BaiduCloudExamine(filter_model.ContentFilter):
|
||||
"""百度云内容审核"""
|
||||
|
||||
async def _get_token(self) -> str:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
BAIDU_EXAMINE_TOKEN_URL,
|
||||
params={
|
||||
"grant_type": "client_credentials",
|
||||
"client_id": self.ap.cfg_mgr.data['baidu_api_key'],
|
||||
"client_secret": self.ap.cfg_mgr.data['baidu_secret_key']
|
||||
}
|
||||
) as resp:
|
||||
return (await resp.json())['access_token']
|
||||
|
||||
async def process(self, message: str) -> entities.FilterResult:
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
BAIDU_EXAMINE_URL.format(await self._get_token()),
|
||||
headers={'Content-Type': 'application/x-www-form-urlencoded', 'Accept': 'application/json'},
|
||||
data=f"text={message}".encode('utf-8')
|
||||
) as resp:
|
||||
result = await resp.json()
|
||||
|
||||
if "error_code" in result:
|
||||
return entities.FilterResult(
|
||||
level=entities.ResultLevel.BLOCK,
|
||||
replacement=message,
|
||||
user_notice='',
|
||||
console_notice=f"百度云判定出错,错误信息:{result['error_msg']}"
|
||||
)
|
||||
else:
|
||||
conclusion = result["conclusion"]
|
||||
|
||||
if conclusion in ("合规"):
|
||||
return entities.FilterResult(
|
||||
level=entities.ResultLevel.PASS,
|
||||
replacement=message,
|
||||
user_notice='',
|
||||
console_notice=f"百度云判定结果:{conclusion}"
|
||||
)
|
||||
else:
|
||||
return entities.FilterResult(
|
||||
level=entities.ResultLevel.BLOCK,
|
||||
replacement=message,
|
||||
user_notice=self.ap.cfg_mgr.data['inappropriate_message_tips'],
|
||||
console_notice=f"百度云判定结果:{conclusion}"
|
||||
)
|
44
pkg/pipeline/cntfilter/filters/banwords.py
Normal file
44
pkg/pipeline/cntfilter/filters/banwords.py
Normal file
|
@ -0,0 +1,44 @@
|
|||
from __future__ import annotations
|
||||
import re
|
||||
|
||||
from .. import filter as filter_model
|
||||
from .. import entities
|
||||
from ....config import manager as cfg_mgr
|
||||
|
||||
|
||||
class BanWordFilter(filter_model.ContentFilter):
|
||||
"""根据内容禁言"""
|
||||
|
||||
sensitive: cfg_mgr.ConfigManager
|
||||
|
||||
async def initialize(self):
|
||||
self.sensitive = await cfg_mgr.load_json_config(
|
||||
"sensitive.json",
|
||||
"res/templates/sensitive-template.json"
|
||||
)
|
||||
|
||||
async def process(self, message: str) -> entities.FilterResult:
|
||||
found = False
|
||||
|
||||
for word in self.sensitive.data['words']:
|
||||
match = re.findall(word, message)
|
||||
|
||||
if len(match) > 0:
|
||||
found = True
|
||||
|
||||
for i in range(len(match)):
|
||||
if self.sensitive.data['mask_word'] == "":
|
||||
message = message.replace(
|
||||
match[i], self.sensitive.data['mask'] * len(match[i])
|
||||
)
|
||||
else:
|
||||
message = message.replace(
|
||||
match[i], self.sensitive.data['mask_word']
|
||||
)
|
||||
|
||||
return entities.FilterResult(
|
||||
level=entities.ResultLevel.MASKED if found else entities.ResultLevel.PASS,
|
||||
replacement=message,
|
||||
user_notice='[bot] 消息中存在不合适的内容, 请更换措辞' if found else '',
|
||||
console_notice=''
|
||||
)
|
43
pkg/pipeline/cntfilter/filters/cntignore.py
Normal file
43
pkg/pipeline/cntfilter/filters/cntignore.py
Normal file
|
@ -0,0 +1,43 @@
|
|||
from __future__ import annotations
|
||||
import re
|
||||
|
||||
from .. import entities
|
||||
from .. import filter as filter_model
|
||||
|
||||
|
||||
class ContentIgnore(filter_model.ContentFilter):
|
||||
"""根据内容忽略消息"""
|
||||
|
||||
@property
|
||||
def enable_stages(self):
|
||||
return [
|
||||
entities.EnableStage.PRE,
|
||||
]
|
||||
|
||||
async def process(self, message: str) -> entities.FilterResult:
|
||||
if 'prefix' in self.ap.cfg_mgr.data['ignore_rules']:
|
||||
for rule in self.ap.cfg_mgr.data['ignore_rules']['prefix']:
|
||||
if message.startswith(rule):
|
||||
return entities.FilterResult(
|
||||
level=entities.ResultLevel.BLOCK,
|
||||
replacement='',
|
||||
user_notice='',
|
||||
console_notice='根据 ignore_rules 中的 prefix 规则,忽略消息'
|
||||
)
|
||||
|
||||
if 'regexp' in self.ap.cfg_mgr.data['ignore_rules']:
|
||||
for rule in self.ap.cfg_mgr.data['ignore_rules']['regexp']:
|
||||
if re.search(rule, message):
|
||||
return entities.FilterResult(
|
||||
level=entities.ResultLevel.BLOCK,
|
||||
replacement='',
|
||||
user_notice='',
|
||||
console_notice='根据 ignore_rules 中的 regexp 规则,忽略消息'
|
||||
)
|
||||
|
||||
return entities.FilterResult(
|
||||
level=entities.ResultLevel.PASS,
|
||||
replacement=message,
|
||||
user_notice='',
|
||||
console_notice=''
|
||||
)
|
38
pkg/pipeline/entities.py
Normal file
38
pkg/pipeline/entities.py
Normal file
|
@ -0,0 +1,38 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import enum
|
||||
import typing
|
||||
|
||||
import pydantic
|
||||
import mirai
|
||||
import mirai.models.message as mirai_message
|
||||
|
||||
from ..core import entities
|
||||
|
||||
|
||||
class ResultType(enum.Enum):
|
||||
|
||||
CONTINUE = enum.auto()
|
||||
"""继续流水线"""
|
||||
|
||||
INTERRUPT = enum.auto()
|
||||
"""中断流水线"""
|
||||
|
||||
|
||||
class StageProcessResult(pydantic.BaseModel):
|
||||
|
||||
result_type: ResultType
|
||||
|
||||
new_query: entities.Query
|
||||
|
||||
user_notice: typing.Optional[typing.Union[str, list[mirai_message.MessageComponent], mirai.MessageChain, None]] = []
|
||||
"""只要设置了就会发送给用户"""
|
||||
|
||||
admin_notice: typing.Optional[typing.Union[str, list[mirai_message.MessageComponent], mirai.MessageChain, None]] = []
|
||||
"""只要设置了就会发送给管理员"""
|
||||
|
||||
console_notice: typing.Optional[str] = ''
|
||||
"""只要设置了就会输出到控制台"""
|
||||
|
||||
debug_notice: typing.Optional[str] = ''
|
||||
|
0
pkg/pipeline/longtext/__init__.py
Normal file
0
pkg/pipeline/longtext/__init__.py
Normal file
57
pkg/pipeline/longtext/longtext.py
Normal file
57
pkg/pipeline/longtext/longtext.py
Normal file
|
@ -0,0 +1,57 @@
|
|||
from __future__ import annotations
|
||||
import os
|
||||
import traceback
|
||||
|
||||
from PIL import Image, ImageDraw, ImageFont
|
||||
from mirai.models.message import MessageComponent, Plain, MessageChain
|
||||
|
||||
from ...core import app
|
||||
from . import strategy
|
||||
from .strategies import image, forward
|
||||
from .. import stage, entities, stagemgr
|
||||
from ...core import entities as core_entities
|
||||
from ...config import manager as cfg_mgr
|
||||
|
||||
|
||||
@stage.stage_class("LongTextProcessStage")
|
||||
class LongTextProcessStage(stage.PipelineStage):
|
||||
|
||||
strategy_impl: strategy.LongTextStrategy
|
||||
|
||||
async def initialize(self):
|
||||
config = self.ap.cfg_mgr.data
|
||||
if self.ap.cfg_mgr.data['blob_message_strategy'] == 'image':
|
||||
use_font = config['font_path']
|
||||
try:
|
||||
# 检查是否存在
|
||||
if not os.path.exists(use_font):
|
||||
# 若是windows系统,使用微软雅黑
|
||||
if os.name == "nt":
|
||||
use_font = "C:/Windows/Fonts/msyh.ttc"
|
||||
if not os.path.exists(use_font):
|
||||
self.ap.logger.warn("未找到字体文件,且无法使用Windows自带字体,更换为转发消息组件以发送长消息,您可以在config.py中调整相关设置。")
|
||||
config['blob_message_strategy'] = "forward"
|
||||
else:
|
||||
self.ap.logger.info("使用Windows自带字体:" + use_font)
|
||||
self.ap.cfg_mgr.data['font_path'] = use_font
|
||||
else:
|
||||
self.ap.logger.warn("未找到字体文件,且无法使用系统自带字体,更换为转发消息组件以发送长消息,您可以在config.py中调整相关设置。")
|
||||
self.ap.cfg_mgr.data['blob_message_strategy'] = "forward"
|
||||
except:
|
||||
traceback.print_exc()
|
||||
self.ap.logger.error("加载字体文件失败({}),更换为转发消息组件以发送长消息,您可以在config.py中调整相关设置。".format(use_font))
|
||||
self.ap.cfg_mgr.data['blob_message_strategy'] = "forward"
|
||||
|
||||
if self.ap.cfg_mgr.data['blob_message_strategy'] == 'image':
|
||||
self.strategy_impl = image.Text2ImageStrategy(self.ap)
|
||||
elif self.ap.cfg_mgr.data['blob_message_strategy'] == 'forward':
|
||||
self.strategy_impl = forward.ForwardComponentStrategy(self.ap)
|
||||
await self.strategy_impl.initialize()
|
||||
|
||||
async def process(self, query: core_entities.Query, stage_inst_name: str) -> entities.StageProcessResult:
|
||||
if len(str(query.resp_message_chain)) > self.ap.cfg_mgr.data['blob_message_threshold']:
|
||||
query.message_chain = MessageChain(await self.strategy_impl.process(str(query.resp_message_chain)))
|
||||
return entities.StageProcessResult(
|
||||
result_type=entities.ResultType.CONTINUE,
|
||||
new_query=query
|
||||
)
|
0
pkg/pipeline/longtext/strategies/__init__.py
Normal file
0
pkg/pipeline/longtext/strategies/__init__.py
Normal file
62
pkg/pipeline/longtext/strategies/forward.py
Normal file
62
pkg/pipeline/longtext/strategies/forward.py
Normal file
|
@ -0,0 +1,62 @@
|
|||
# 转发消息组件
|
||||
from __future__ import annotations
|
||||
import typing
|
||||
|
||||
from mirai.models import MessageChain
|
||||
from mirai.models.message import MessageComponent, ForwardMessageNode
|
||||
from mirai.models.base import MiraiBaseModel
|
||||
|
||||
from .. import strategy as strategy_model
|
||||
|
||||
|
||||
class ForwardMessageDiaplay(MiraiBaseModel):
|
||||
title: str = "群聊的聊天记录"
|
||||
brief: str = "[聊天记录]"
|
||||
source: str = "聊天记录"
|
||||
preview: typing.List[str] = []
|
||||
summary: str = "查看x条转发消息"
|
||||
|
||||
|
||||
class Forward(MessageComponent):
|
||||
"""合并转发。"""
|
||||
type: str = "Forward"
|
||||
"""消息组件类型。"""
|
||||
display: ForwardMessageDiaplay
|
||||
"""显示信息"""
|
||||
node_list: typing.List[ForwardMessageNode]
|
||||
"""转发消息节点列表。"""
|
||||
def __init__(self, *args, **kwargs):
|
||||
if len(args) == 1:
|
||||
self.node_list = args[0]
|
||||
super().__init__(**kwargs)
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def __str__(self):
|
||||
return '[聊天记录]'
|
||||
|
||||
|
||||
class ForwardComponentStrategy(strategy_model.LongTextStrategy):
|
||||
|
||||
async def process(self, message: str) -> list[MessageComponent]:
|
||||
display = ForwardMessageDiaplay(
|
||||
title="群聊的聊天记录",
|
||||
brief="[聊天记录]",
|
||||
source="聊天记录",
|
||||
preview=["QQ用户: "+message],
|
||||
summary="查看1条转发消息"
|
||||
)
|
||||
|
||||
node_list = [
|
||||
ForwardMessageNode(
|
||||
sender_id=self.ap.im_mgr.bot_account_id,
|
||||
sender_name='QQ用户',
|
||||
message_chain=MessageChain([message])
|
||||
)
|
||||
]
|
||||
|
||||
forward = Forward(
|
||||
display=display,
|
||||
node_list=node_list
|
||||
)
|
||||
|
||||
return [forward]
|
197
pkg/pipeline/longtext/strategies/image.py
Normal file
197
pkg/pipeline/longtext/strategies/image.py
Normal file
|
@ -0,0 +1,197 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import typing
|
||||
import os
|
||||
import base64
|
||||
import time
|
||||
import re
|
||||
|
||||
from PIL import Image, ImageDraw, ImageFont
|
||||
|
||||
from mirai.models import MessageChain, Image as ImageComponent
|
||||
from mirai.models.message import MessageComponent
|
||||
|
||||
from .. import strategy as strategy_model
|
||||
|
||||
|
||||
class Text2ImageStrategy(strategy_model.LongTextStrategy):
|
||||
|
||||
text_render_font: ImageFont.FreeTypeFont
|
||||
|
||||
async def initialize(self):
|
||||
self.text_render_font = ImageFont.truetype(self.ap.cfg_mgr.data['font_path'], 32, encoding="utf-8")
|
||||
|
||||
async def process(self, message: str) -> list[MessageComponent]:
|
||||
img_path = self.text_to_image(
|
||||
text_str=message,
|
||||
save_as='temp/{}.png'.format(int(time.time()))
|
||||
)
|
||||
|
||||
compressed_path, size = self.compress_image(
|
||||
img_path,
|
||||
outfile="temp/{}_compressed.png".format(int(time.time()))
|
||||
)
|
||||
|
||||
with open(compressed_path, 'rb') as f:
|
||||
img = f.read()
|
||||
|
||||
b64 = base64.b64encode(img)
|
||||
|
||||
# 删除图片
|
||||
os.remove(img_path)
|
||||
|
||||
if os.path.exists(compressed_path):
|
||||
os.remove(compressed_path)
|
||||
|
||||
return [
|
||||
ImageComponent(
|
||||
base64=b64.decode('utf-8'),
|
||||
)
|
||||
]
|
||||
|
||||
def indexNumber(self, path=''):
|
||||
"""
|
||||
查找字符串中数字所在串中的位置
|
||||
:param path:目标字符串
|
||||
:return:<class 'list'>: <class 'list'>: [['1', 16], ['2', 35], ['1', 51]]
|
||||
"""
|
||||
kv = []
|
||||
nums = []
|
||||
beforeDatas = re.findall('[\d]+', path)
|
||||
for num in beforeDatas:
|
||||
indexV = []
|
||||
times = path.count(num)
|
||||
if times > 1:
|
||||
if num not in nums:
|
||||
indexs = re.finditer(num, path)
|
||||
for index in indexs:
|
||||
iV = []
|
||||
i = index.span()[0]
|
||||
iV.append(num)
|
||||
iV.append(i)
|
||||
kv.append(iV)
|
||||
nums.append(num)
|
||||
else:
|
||||
index = path.find(num)
|
||||
indexV.append(num)
|
||||
indexV.append(index)
|
||||
kv.append(indexV)
|
||||
# 根据数字位置排序
|
||||
indexSort = []
|
||||
resultIndex = []
|
||||
for vi in kv:
|
||||
indexSort.append(vi[1])
|
||||
indexSort.sort()
|
||||
for i in indexSort:
|
||||
for v in kv:
|
||||
if i == v[1]:
|
||||
resultIndex.append(v)
|
||||
return resultIndex
|
||||
|
||||
|
||||
def get_size(self, file):
|
||||
# 获取文件大小:KB
|
||||
size = os.path.getsize(file)
|
||||
return size / 1024
|
||||
|
||||
|
||||
def get_outfile(self, infile, outfile):
|
||||
if outfile:
|
||||
return outfile
|
||||
dir, suffix = os.path.splitext(infile)
|
||||
outfile = '{}-out{}'.format(dir, suffix)
|
||||
return outfile
|
||||
|
||||
|
||||
def compress_image(self, infile, outfile='', kb=100, step=20, quality=90):
|
||||
"""不改变图片尺寸压缩到指定大小
|
||||
:param infile: 压缩源文件
|
||||
:param outfile: 压缩文件保存地址
|
||||
:param mb: 压缩目标,KB
|
||||
:param step: 每次调整的压缩比率
|
||||
:param quality: 初始压缩比率
|
||||
:return: 压缩文件地址,压缩文件大小
|
||||
"""
|
||||
o_size = self.get_size(infile)
|
||||
if o_size <= kb:
|
||||
return infile, o_size
|
||||
outfile = self.get_outfile(infile, outfile)
|
||||
while o_size > kb:
|
||||
im = Image.open(infile)
|
||||
im.save(outfile, quality=quality)
|
||||
if quality - step < 0:
|
||||
break
|
||||
quality -= step
|
||||
o_size = self.get_size(outfile)
|
||||
return outfile, self.get_size(outfile)
|
||||
|
||||
|
||||
def text_to_image(self, text_str: str, save_as="temp.png", width=800):
|
||||
|
||||
text_str = text_str.replace("\t", " ")
|
||||
|
||||
# 分行
|
||||
lines = text_str.split('\n')
|
||||
|
||||
# 计算并分割
|
||||
final_lines = []
|
||||
|
||||
text_width = width-80
|
||||
|
||||
self.ap.logger.debug("lines: {}, text_width: {}".format(lines, text_width))
|
||||
for line in lines:
|
||||
# 如果长了就分割
|
||||
line_width = self.text_render_font.getlength(line)
|
||||
self.ap.logger.debug("line_width: {}".format(line_width))
|
||||
if line_width < text_width:
|
||||
final_lines.append(line)
|
||||
continue
|
||||
else:
|
||||
rest_text = line
|
||||
while True:
|
||||
# 分割最前面的一行
|
||||
point = int(len(rest_text) * (text_width / line_width))
|
||||
|
||||
# 检查断点是否在数字中间
|
||||
numbers = self.indexNumber(rest_text)
|
||||
|
||||
for number in numbers:
|
||||
if number[1] < point < number[1] + len(number[0]) and number[1] != 0:
|
||||
point = number[1]
|
||||
break
|
||||
|
||||
final_lines.append(rest_text[:point])
|
||||
rest_text = rest_text[point:]
|
||||
line_width = self.text_render_font.getlength(rest_text)
|
||||
if line_width < text_width:
|
||||
final_lines.append(rest_text)
|
||||
break
|
||||
else:
|
||||
continue
|
||||
# 准备画布
|
||||
img = Image.new('RGBA', (width, max(280, len(final_lines) * 35 + 65)), (255, 255, 255, 255))
|
||||
draw = ImageDraw.Draw(img, mode='RGBA')
|
||||
|
||||
self.ap.logger.debug("正在绘制图片...")
|
||||
# 绘制正文
|
||||
line_number = 0
|
||||
offset_x = 20
|
||||
offset_y = 30
|
||||
for final_line in final_lines:
|
||||
draw.text((offset_x, offset_y + 35 * line_number), final_line, fill=(0, 0, 0), font=self.text_render_font)
|
||||
# 遍历此行,检查是否有emoji
|
||||
idx_in_line = 0
|
||||
for ch in final_line:
|
||||
# 检查字符占位宽
|
||||
char_code = ord(ch)
|
||||
if char_code >= 127:
|
||||
idx_in_line += 1
|
||||
else:
|
||||
idx_in_line += 0.5
|
||||
|
||||
line_number += 1
|
||||
|
||||
self.ap.logger.debug("正在保存图片...")
|
||||
img.save(save_as)
|
||||
|
||||
return save_as
|
22
pkg/pipeline/longtext/strategy.py
Normal file
22
pkg/pipeline/longtext/strategy.py
Normal file
|
@ -0,0 +1,22 @@
|
|||
from __future__ import annotations
|
||||
import abc
|
||||
import typing
|
||||
|
||||
import mirai
|
||||
from mirai.models.message import MessageComponent
|
||||
|
||||
from ...core import app
|
||||
|
||||
|
||||
class LongTextStrategy(metaclass=abc.ABCMeta):
|
||||
ap: app.Application
|
||||
|
||||
def __init__(self, ap: app.Application):
|
||||
self.ap = ap
|
||||
|
||||
async def initialize(self):
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
async def process(self, message: str) -> list[MessageComponent]:
|
||||
return []
|
0
pkg/pipeline/resprule/__init__.py
Normal file
0
pkg/pipeline/resprule/__init__.py
Normal file
9
pkg/pipeline/resprule/entities.py
Normal file
9
pkg/pipeline/resprule/entities.py
Normal file
|
@ -0,0 +1,9 @@
|
|||
import pydantic
|
||||
import mirai
|
||||
|
||||
|
||||
class RuleJudgeResult(pydantic.BaseModel):
|
||||
|
||||
matching: bool = False
|
||||
|
||||
replacement: mirai.MessageChain = None
|
62
pkg/pipeline/resprule/resprule.py
Normal file
62
pkg/pipeline/resprule/resprule.py
Normal file
|
@ -0,0 +1,62 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import mirai
|
||||
|
||||
from ...core import app
|
||||
from . import entities as rule_entities, rule
|
||||
from .rules import atbot, prefix, regexp, random
|
||||
|
||||
from .. import stage, entities, stagemgr
|
||||
from ...core import entities as core_entities
|
||||
from ...config import manager as cfg_mgr
|
||||
|
||||
|
||||
@stage.stage_class("GroupRespondRuleCheckStage")
|
||||
class GroupRespondRuleCheckStage(stage.PipelineStage):
|
||||
"""群组响应规则检查器
|
||||
"""
|
||||
|
||||
rule_matchers: list[rule.GroupRespondRule]
|
||||
|
||||
async def initialize(self):
|
||||
"""初始化检查器
|
||||
"""
|
||||
self.rule_matchers = [
|
||||
atbot.AtBotRule(self.ap),
|
||||
prefix.PrefixRule(self.ap),
|
||||
regexp.RegExpRule(self.ap),
|
||||
random.RandomRespRule(self.ap),
|
||||
]
|
||||
|
||||
for rule_matcher in self.rule_matchers:
|
||||
await rule_matcher.initialize()
|
||||
|
||||
async def process(self, query: core_entities.Query, stage_inst_name: str) -> entities.StageProcessResult:
|
||||
|
||||
if query.launcher_type != 'group':
|
||||
return entities.StageProcessResult(
|
||||
result_type=entities.ResultType.CONTINUE,
|
||||
new_query=query
|
||||
)
|
||||
|
||||
rules = self.ap.cfg_mgr.data['response_rules']
|
||||
|
||||
use_rule = rules['default']
|
||||
|
||||
if str(query.launcher_id) in use_rule:
|
||||
use_rule = use_rule[str(query.launcher_id)]
|
||||
|
||||
for rule_matcher in self.rule_matchers: # 任意一个匹配就放行
|
||||
res = await rule_matcher.match(str(query.message_chain), query.message_chain, use_rule)
|
||||
if res.matching:
|
||||
query.message_chain = res.replacement
|
||||
|
||||
return entities.StageProcessResult(
|
||||
result_type=entities.ResultType.CONTINUE,
|
||||
new_query=query,
|
||||
)
|
||||
|
||||
return entities.StageProcessResult(
|
||||
result_type=entities.ResultType.INTERRUPT,
|
||||
new_query=query
|
||||
)
|
31
pkg/pipeline/resprule/rule.py
Normal file
31
pkg/pipeline/resprule/rule.py
Normal file
|
@ -0,0 +1,31 @@
|
|||
from __future__ import annotations
|
||||
import abc
|
||||
|
||||
import mirai
|
||||
|
||||
from ...core import app
|
||||
from . import entities
|
||||
|
||||
|
||||
class GroupRespondRule(metaclass=abc.ABCMeta):
|
||||
"""群组响应规则的抽象类
|
||||
"""
|
||||
|
||||
ap: app.Application
|
||||
|
||||
def __init__(self, ap: app.Application):
|
||||
self.ap = ap
|
||||
|
||||
async def initialize(self):
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
async def match(
|
||||
self,
|
||||
message_text: str,
|
||||
message_chain: mirai.MessageChain,
|
||||
rule_dict: dict
|
||||
) -> entities.RuleJudgeResult:
|
||||
"""判断消息是否匹配规则
|
||||
"""
|
||||
raise NotImplementedError
|
0
pkg/pipeline/resprule/rules/__init__.py
Normal file
0
pkg/pipeline/resprule/rules/__init__.py
Normal file
28
pkg/pipeline/resprule/rules/atbot.py
Normal file
28
pkg/pipeline/resprule/rules/atbot.py
Normal file
|
@ -0,0 +1,28 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import mirai
|
||||
|
||||
from .. import rule as rule_model
|
||||
from .. import entities
|
||||
|
||||
|
||||
class AtBotRule(rule_model.GroupRespondRule):
|
||||
|
||||
async def match(
|
||||
self,
|
||||
message_text: str,
|
||||
message_chain: mirai.MessageChain,
|
||||
rule_dict: dict
|
||||
) -> entities.RuleJudgeResult:
|
||||
|
||||
if message_chain.has(mirai.At(self.ap.im_mgr.bot_account_id)) and rule_dict['at']:
|
||||
message_chain.remove(mirai.At(self.ap.im_mgr.bot_account_id))
|
||||
return entities.RuleJudgeResult(
|
||||
matching=True,
|
||||
replacement=message_chain,
|
||||
)
|
||||
|
||||
return entities.RuleJudgeResult(
|
||||
matching=False,
|
||||
replacement = message_chain
|
||||
)
|
29
pkg/pipeline/resprule/rules/prefix.py
Normal file
29
pkg/pipeline/resprule/rules/prefix.py
Normal file
|
@ -0,0 +1,29 @@
|
|||
import mirai
|
||||
|
||||
from .. import rule as rule_model
|
||||
from .. import entities
|
||||
|
||||
|
||||
class PrefixRule(rule_model.GroupRespondRule):
|
||||
|
||||
async def match(
|
||||
self,
|
||||
message_text: str,
|
||||
message_chain: mirai.MessageChain,
|
||||
rule_dict: dict
|
||||
) -> entities.RuleJudgeResult:
|
||||
prefixes = rule_dict['prefix']
|
||||
|
||||
for prefix in prefixes:
|
||||
if message_text.startswith(prefix):
|
||||
return entities.RuleJudgeResult(
|
||||
matching=True,
|
||||
replacement=mirai.MessageChain([
|
||||
mirai.Plain(message_text[len(prefix):])
|
||||
]),
|
||||
)
|
||||
|
||||
return entities.RuleJudgeResult(
|
||||
matching=False,
|
||||
replacement=message_chain
|
||||
)
|
22
pkg/pipeline/resprule/rules/random.py
Normal file
22
pkg/pipeline/resprule/rules/random.py
Normal file
|
@ -0,0 +1,22 @@
|
|||
import random
|
||||
|
||||
import mirai
|
||||
|
||||
from .. import rule as rule_model
|
||||
from .. import entities
|
||||
|
||||
|
||||
class RandomRespRule(rule_model.GroupRespondRule):
|
||||
|
||||
async def match(
|
||||
self,
|
||||
message_text: str,
|
||||
message_chain: mirai.MessageChain,
|
||||
rule_dict: dict
|
||||
) -> entities.RuleJudgeResult:
|
||||
random_rate = rule_dict['random_rate']
|
||||
|
||||
return entities.RuleJudgeResult(
|
||||
matching=random.random() < random_rate,
|
||||
replacement=message_chain
|
||||
)
|
31
pkg/pipeline/resprule/rules/regexp.py
Normal file
31
pkg/pipeline/resprule/rules/regexp.py
Normal file
|
@ -0,0 +1,31 @@
|
|||
import re
|
||||
|
||||
import mirai
|
||||
|
||||
from .. import rule as rule_model
|
||||
from .. import entities
|
||||
|
||||
|
||||
class RegExpRule(rule_model.GroupRespondRule):
|
||||
|
||||
async def match(
|
||||
self,
|
||||
message_text: str,
|
||||
message_chain: mirai.MessageChain,
|
||||
rule_dict: dict
|
||||
) -> entities.RuleJudgeResult:
|
||||
regexps = rule_dict['regexp']
|
||||
|
||||
for regexp in regexps:
|
||||
match = re.match(regexp, message_text)
|
||||
|
||||
if match:
|
||||
return entities.RuleJudgeResult(
|
||||
matching=True,
|
||||
replacement=message_chain,
|
||||
)
|
||||
|
||||
return entities.RuleJudgeResult(
|
||||
matching=False,
|
||||
replacement=message_chain
|
||||
)
|
43
pkg/pipeline/stage.py
Normal file
43
pkg/pipeline/stage.py
Normal file
|
@ -0,0 +1,43 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import abc
|
||||
|
||||
from ..core import app, entities as core_entities
|
||||
from . import entities
|
||||
|
||||
|
||||
_stage_classes: dict[str, PipelineStage] = {}
|
||||
|
||||
|
||||
def stage_class(name: str):
|
||||
|
||||
def decorator(cls):
|
||||
_stage_classes[name] = cls
|
||||
return cls
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
class PipelineStage(metaclass=abc.ABCMeta):
|
||||
"""流水线阶段
|
||||
"""
|
||||
|
||||
ap: app.Application
|
||||
|
||||
def __init__(self, ap: app.Application):
|
||||
self.ap = ap
|
||||
|
||||
async def initialize(self):
|
||||
"""初始化
|
||||
"""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
async def process(
|
||||
self,
|
||||
query: core_entities.Query,
|
||||
stage_inst_name: str,
|
||||
) -> entities.StageProcessResult:
|
||||
"""处理
|
||||
"""
|
||||
raise NotImplementedError
|
47
pkg/pipeline/stagemgr.py
Normal file
47
pkg/pipeline/stagemgr.py
Normal file
|
@ -0,0 +1,47 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import pydantic
|
||||
|
||||
from ..core import app
|
||||
from . import stage
|
||||
from .resprule import resprule
|
||||
from .bansess import bansess
|
||||
from .cntfilter import cntfilter
|
||||
from .longtext import longtext
|
||||
|
||||
|
||||
class StageInstContainer():
|
||||
"""阶段实例容器
|
||||
"""
|
||||
|
||||
inst_name: str
|
||||
|
||||
inst: stage.PipelineStage
|
||||
|
||||
def __init__(self, inst_name: str, inst: stage.PipelineStage):
|
||||
self.inst_name = inst_name
|
||||
self.inst = inst
|
||||
|
||||
|
||||
class StageManager:
|
||||
ap: app.Application
|
||||
|
||||
stage_containers: list[StageInstContainer]
|
||||
|
||||
def __init__(self, ap: app.Application):
|
||||
self.ap = ap
|
||||
|
||||
self.stage_containers = []
|
||||
|
||||
async def initialize(self):
|
||||
"""初始化
|
||||
"""
|
||||
|
||||
for name, cls in stage._stage_classes.items():
|
||||
self.stage_containers.append(StageInstContainer(
|
||||
inst_name=name,
|
||||
inst=cls(self.ap)
|
||||
))
|
||||
|
||||
for stage_containers in self.stage_containers:
|
||||
await stage_containers.inst.initialize()
|
|
@ -3,7 +3,7 @@
|
|||
from __future__ import annotations
|
||||
import re
|
||||
|
||||
from ...boot import app
|
||||
from ...core import app
|
||||
from ...config import manager as cfg_mgr
|
||||
|
||||
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from ...boot import app
|
||||
from ...core import app
|
||||
from . import entities
|
||||
from . import filter
|
||||
from .filters import cntignore, banwords, baiduexamine
|
||||
|
|
|
@ -2,7 +2,7 @@
|
|||
from __future__ import annotations
|
||||
import abc
|
||||
|
||||
from ...boot import app
|
||||
from ...core import app
|
||||
from . import entities
|
||||
|
||||
|
||||
|
|
|
@ -5,7 +5,7 @@ import traceback
|
|||
from PIL import Image, ImageDraw, ImageFont
|
||||
from mirai.models.message import MessageComponent, Plain
|
||||
|
||||
from ...boot import app
|
||||
from ...core import app
|
||||
from . import strategy
|
||||
from .strategies import image, forward
|
||||
|
||||
|
|
|
@ -5,7 +5,7 @@ import typing
|
|||
import mirai
|
||||
from mirai.models.message import MessageComponent
|
||||
|
||||
from ...boot import app
|
||||
from ...core import app
|
||||
|
||||
|
||||
class LongTextStrategy(metaclass=abc.ABCMeta):
|
||||
|
|
|
@ -24,7 +24,7 @@ from .cntfilter import cntfilter
|
|||
from .longtext import longtext
|
||||
from .ratelim import ratelim
|
||||
|
||||
from ..boot import app
|
||||
from ..core import app, entities as core_entities
|
||||
|
||||
|
||||
# 控制QQ消息输入输出的类
|
||||
|
@ -91,45 +91,29 @@ class QQBotManager:
|
|||
# Caution: 注册新的事件处理器之后,请务必在unsubscribe_all中编写相应的取消订阅代码
|
||||
async def on_friend_message(event: FriendMessage):
|
||||
|
||||
async def friend_message_handler():
|
||||
# 触发事件
|
||||
args = {
|
||||
"launcher_type": "person",
|
||||
"launcher_id": event.sender.id,
|
||||
"sender_id": event.sender.id,
|
||||
"message_chain": event.message_chain,
|
||||
}
|
||||
plugin_event = plugin_host.emit(plugin_models.PersonMessageReceived, **args)
|
||||
await self.ap.query_pool.add_query(
|
||||
launcher_type=core_entities.LauncherTypes.PERSON,
|
||||
launcher_id=event.sender.id,
|
||||
sender_id=event.sender.id,
|
||||
message_event=event,
|
||||
message_chain=event.message_chain
|
||||
)
|
||||
|
||||
if plugin_event.is_prevented_default():
|
||||
return
|
||||
|
||||
await self.on_person_message(event)
|
||||
|
||||
asyncio.create_task(friend_message_handler())
|
||||
self.adapter.register_listener(
|
||||
FriendMessage,
|
||||
on_friend_message
|
||||
)
|
||||
|
||||
async def on_stranger_message(event: StrangerMessage):
|
||||
|
||||
await self.ap.query_pool.add_query(
|
||||
launcher_type=core_entities.LauncherTypes.PERSON,
|
||||
launcher_id=event.sender.id,
|
||||
sender_id=event.sender.id,
|
||||
message_event=event,
|
||||
message_chain=event.message_chain
|
||||
)
|
||||
|
||||
async def stranger_message_handler():
|
||||
# 触发事件
|
||||
args = {
|
||||
"launcher_type": "person",
|
||||
"launcher_id": event.sender.id,
|
||||
"sender_id": event.sender.id,
|
||||
"message_chain": event.message_chain,
|
||||
}
|
||||
plugin_event = plugin_host.emit(plugin_models.PersonMessageReceived, **args)
|
||||
|
||||
if plugin_event.is_prevented_default():
|
||||
return
|
||||
|
||||
await self.on_person_message(event)
|
||||
|
||||
asyncio.create_task(stranger_message_handler())
|
||||
# nakuru不区分好友和陌生人,故仅为yirimirai注册陌生人事件
|
||||
if config['msg_source_adapter'] == 'yirimirai':
|
||||
self.adapter.register_listener(
|
||||
|
@ -139,49 +123,19 @@ class QQBotManager:
|
|||
|
||||
async def on_group_message(event: GroupMessage):
|
||||
|
||||
async def group_message_handler(event: GroupMessage):
|
||||
# 触发事件
|
||||
args = {
|
||||
"launcher_type": "group",
|
||||
"launcher_id": event.group.id,
|
||||
"sender_id": event.sender.id,
|
||||
"message_chain": event.message_chain,
|
||||
}
|
||||
plugin_event = plugin_host.emit(plugin_models.GroupMessageReceived, **args)
|
||||
|
||||
if plugin_event.is_prevented_default():
|
||||
return
|
||||
|
||||
await self.on_group_message(event)
|
||||
|
||||
asyncio.create_task(group_message_handler(event))
|
||||
await self.ap.query_pool.add_query(
|
||||
launcher_type=core_entities.LauncherTypes.GROUP,
|
||||
launcher_id=event.group.id,
|
||||
sender_id=event.sender.id,
|
||||
message_event=event,
|
||||
message_chain=event.message_chain
|
||||
)
|
||||
|
||||
self.adapter.register_listener(
|
||||
GroupMessage,
|
||||
on_group_message
|
||||
)
|
||||
|
||||
def unsubscribe_all():
|
||||
"""取消所有订阅
|
||||
|
||||
用于在热重载流程中卸载所有事件处理器
|
||||
"""
|
||||
self.adapter.unregister_listener(
|
||||
FriendMessage,
|
||||
on_friend_message
|
||||
)
|
||||
if config['msg_source_adapter'] == 'yirimirai':
|
||||
self.adapter.unregister_listener(
|
||||
StrangerMessage,
|
||||
on_stranger_message
|
||||
)
|
||||
self.adapter.unregister_listener(
|
||||
GroupMessage,
|
||||
on_group_message
|
||||
)
|
||||
|
||||
self.unsubscribe_all = unsubscribe_all
|
||||
|
||||
async def send(self, event, msg, check_quote=True, check_at_sender=True):
|
||||
config = context.get_config_manager().data
|
||||
|
||||
|
|
|
@ -14,7 +14,7 @@ from ..utils import context
|
|||
from ..plugin import host as plugin_host
|
||||
from ..plugin import models as plugin_models
|
||||
import tips as tips_custom
|
||||
from ..boot import app
|
||||
from ..core import app
|
||||
from .cntfilter import entities
|
||||
|
||||
processing = []
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
from __future__ import annotations
|
||||
import abc
|
||||
|
||||
from ...boot import app
|
||||
from ...core import app
|
||||
|
||||
|
||||
class ReteLimitAlgo(metaclass=abc.ABCMeta):
|
||||
|
|
|
@ -2,7 +2,7 @@ from __future__ import annotations
|
|||
|
||||
from . import algo
|
||||
from .algos import fixedwin
|
||||
from ...boot import app
|
||||
from ...core import app
|
||||
|
||||
|
||||
class RateLimiter:
|
||||
|
|
|
@ -2,7 +2,7 @@ from __future__ import annotations
|
|||
|
||||
import mirai
|
||||
|
||||
from ...boot import app
|
||||
from ...core import app
|
||||
from . import entities, rule
|
||||
from .rules import atbot, prefix, regexp, random
|
||||
|
||||
|
|
|
@ -3,7 +3,7 @@ import abc
|
|||
|
||||
import mirai
|
||||
|
||||
from ...boot import app
|
||||
from ...core import app
|
||||
from . import entities
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user