refactor: 将 yirimirai 的组件集成进 platform 包

This commit is contained in:
RockChinQ 2024-09-26 00:23:03 +08:00
parent ee0d6dcdae
commit 1c4a700d92
No known key found for this signature in database
GPG Key ID: 8AC0BEFE1743A015
36 changed files with 1580 additions and 342 deletions

View File

@ -3,10 +3,11 @@ from __future__ import annotations
import typing
import pydantic
import mirai
# import mirai
from ..core import app, entities as core_entities
from . import errors, operator
from ..platform.types import message as platform_message
class CommandReturn(pydantic.BaseModel):
@ -17,7 +18,7 @@ class CommandReturn(pydantic.BaseModel):
"""文本
"""
image: typing.Optional[mirai.Image] = None
image: typing.Optional[platform_message.Image] = None
"""弃用"""
image_url: typing.Optional[str] = None

View File

@ -6,13 +6,16 @@ import datetime
import asyncio
import pydantic
import mirai
# import mirai
from ..provider import entities as llm_entities
from ..provider.modelmgr import entities
from ..provider.sysprompt import entities as sysprompt_entities
from ..provider.tools import entities as tools_entities
from ..platform import adapter as msadapter
from ..platform.types import message as platform_message
from ..platform.types import events as platform_events
from ..platform.types import entities as platform_entities
class LauncherTypes(enum.Enum):
@ -40,10 +43,10 @@ class Query(pydantic.BaseModel):
sender_id: int
"""发送者IDplatform处理阶段设置"""
message_event: mirai.MessageEvent
message_event: platform_events.MessageEvent
"""事件platform收到的原始事件"""
message_chain: mirai.MessageChain
message_chain: platform_message.MessageChain
"""消息链platform收到的原始消息链"""
adapter: msadapter.MessageSourceAdapter
@ -67,10 +70,10 @@ class Query(pydantic.BaseModel):
use_funcs: typing.Optional[list[tools_entities.LLMFunction]] = None
"""使用的函数,由前置处理器阶段设置"""
resp_messages: typing.Optional[list[llm_entities.Message]] | typing.Optional[list[mirai.MessageChain]] = []
resp_messages: typing.Optional[list[llm_entities.Message]] | typing.Optional[list[platform_message.MessageChain]] = []
"""由Process阶段生成的回复消息对象列表"""
resp_message_chain: typing.Optional[list[mirai.MessageChain]] = None
resp_message_chain: typing.Optional[list[platform_message.MessageChain]] = None
"""回复消息链从resp_messages包装而得"""
# ======= 内部保留 =======

View File

@ -1,8 +1,8 @@
from __future__ import annotations
import mirai
import mirai.models
import mirai.models.message
# import mirai
# import mirai.models
# import mirai.models.message
from ...core import app
@ -12,6 +12,9 @@ from ...config import manager as cfg_mgr
from . import filter as filter_model, entities as filter_entities
from .filters import cntignore, banwords, baiduexamine
from ...provider import entities as llm_entities
from ...platform.types import message as platform_message
from ...platform.types import events as platform_events
from ...platform.types import entities as platform_entities
@stage.stage_class('PostContentFilterStage')
@ -89,8 +92,8 @@ class ContentFilterStage(stage.PipelineStage):
elif result.level == filter_entities.ResultLevel.PASS: # 传到下一个
message = result.replacement
query.message_chain = mirai.MessageChain(
mirai.Plain(message)
query.message_chain = platform_message.MessageChain(
platform_message.Plain(message)
)
return entities.StageProcessResult(
@ -148,7 +151,7 @@ class ContentFilterStage(stage.PipelineStage):
contain_non_text = False
text_components = [mirai.Plain, mirai.models.message.Source]
text_components = [platform_message.Plain, platform_message.Source]
for me in query.message_chain:
if type(me) not in text_components:

View File

@ -4,11 +4,12 @@ import asyncio
import typing
import traceback
import mirai
# import mirai
from ..core import app, entities
from . import entities as pipeline_entities
from ..plugin import events
from ..platform.types import message as platform_message
class Controller:
@ -73,11 +74,11 @@ class Controller:
# 处理str类型
if isinstance(result.user_notice, str):
result.user_notice = mirai.MessageChain(
mirai.Plain(result.user_notice)
result.user_notice = platform_message.MessageChain(
platform_message.Plain(result.user_notice)
)
elif isinstance(result.user_notice, list):
result.user_notice = mirai.MessageChain(
result.user_notice = platform_message.MessageChain(
*result.user_notice
)

View File

@ -4,8 +4,8 @@ import enum
import typing
import pydantic
import mirai
import mirai.models.message as mirai_message
# import mirai
from ..platform.types import message as platform_message
from ..core import entities
@ -25,7 +25,7 @@ class StageProcessResult(pydantic.BaseModel):
new_query: entities.Query
user_notice: typing.Optional[typing.Union[str, list[mirai_message.MessageComponent], mirai.MessageChain, None]] = []
user_notice: typing.Optional[typing.Union[str, list[platform_message.MessageComponent], platform_message.MessageChain, None]] = []
"""只要设置了就会发送给用户"""
# TODO delete

View File

@ -3,7 +3,7 @@ import os
import traceback
from PIL import Image, ImageDraw, ImageFont
from mirai.models.message import MessageComponent, Plain, MessageChain
# from mirai.models.message import MessageComponent, Plain, MessageChain
from ...core import app
from . import strategy
@ -11,6 +11,7 @@ from .strategies import image, forward
from .. import stage, entities, stagemgr
from ...core import entities as core_entities
from ...config import manager as cfg_mgr
from ...platform.types import message as platform_message
@stage.stage_class("LongTextProcessStage")
@ -63,14 +64,14 @@ class LongTextProcessStage(stage.PipelineStage):
contains_non_plain = False
for msg in query.resp_message_chain[-1]:
if not isinstance(msg, Plain):
if not isinstance(msg, platform_message.Plain):
contains_non_plain = True
break
if contains_non_plain:
self.ap.logger.debug("消息中包含非 Plain 组件,跳过长消息处理。")
elif len(str(query.resp_message_chain[-1])) > self.ap.platform_cfg.data['long-text-process']['threshold']:
query.resp_message_chain[-1] = MessageChain(await self.strategy_impl.process(str(query.resp_message_chain[-1]), query))
query.resp_message_chain[-1] = platform_message.MessageChain(await self.strategy_impl.process(str(query.resp_message_chain[-1]), query))
return entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE,

View File

@ -2,15 +2,17 @@
from __future__ import annotations
import typing
from mirai.models import MessageChain
from mirai.models.message import MessageComponent, ForwardMessageNode
from mirai.models.base import MiraiBaseModel
# from mirai.models import MessageChain
# from mirai.models.message import MessageComponent, ForwardMessageNode
# from mirai.models.base import MiraiBaseModel
import pydantic
from .. import strategy as strategy_model
from ....core import entities as core_entities
from ....platform.types import message as platform_message
class ForwardMessageDiaplay(MiraiBaseModel):
class ForwardMessageDiaplay(pydantic.BaseModel):
title: str = "群聊的聊天记录"
brief: str = "[聊天记录]"
source: str = "聊天记录"
@ -18,13 +20,13 @@ class ForwardMessageDiaplay(MiraiBaseModel):
summary: str = "查看x条转发消息"
class Forward(MessageComponent):
class Forward(platform_message.MessageComponent):
"""合并转发。"""
type: str = "Forward"
"""消息组件类型。"""
display: ForwardMessageDiaplay
"""显示信息"""
node_list: typing.List[ForwardMessageNode]
node_list: typing.List[platform_message.ForwardMessageNode]
"""转发消息节点列表。"""
def __init__(self, *args, **kwargs):
if len(args) == 1:
@ -39,7 +41,7 @@ class Forward(MessageComponent):
@strategy_model.strategy_class("forward")
class ForwardComponentStrategy(strategy_model.LongTextStrategy):
async def process(self, message: str, query: core_entities.Query) -> list[MessageComponent]:
async def process(self, message: str, query: core_entities.Query) -> list[platform_message.MessageComponent]:
display = ForwardMessageDiaplay(
title="群聊的聊天记录",
brief="[聊天记录]",
@ -49,10 +51,10 @@ class ForwardComponentStrategy(strategy_model.LongTextStrategy):
)
node_list = [
ForwardMessageNode(
platform_message.ForwardMessageNode(
sender_id=query.adapter.bot_account_id,
sender_name='QQ用户',
message_chain=MessageChain([message])
message_chain=platform_message.MessageChain([message])
)
]

View File

@ -8,8 +8,9 @@ import re
from PIL import Image, ImageDraw, ImageFont
from mirai.models import MessageChain, Image as ImageComponent
from mirai.models.message import MessageComponent
# from mirai.models import MessageChain, Image as ImageComponent
# from mirai.models.message import MessageComponent
from ....platform.types import message as platform_message
from .. import strategy as strategy_model
from ....core import entities as core_entities
@ -23,7 +24,7 @@ class Text2ImageStrategy(strategy_model.LongTextStrategy):
async def initialize(self):
self.text_render_font = ImageFont.truetype(self.ap.platform_cfg.data['long-text-process']['font-path'], 32, encoding="utf-8")
async def process(self, message: str, query: core_entities.Query) -> list[MessageComponent]:
async def process(self, message: str, query: core_entities.Query) -> list[platform_message.MessageComponent]:
img_path = self.text_to_image(
text_str=message,
save_as='temp/{}.png'.format(int(time.time()))
@ -46,7 +47,7 @@ class Text2ImageStrategy(strategy_model.LongTextStrategy):
os.remove(compressed_path)
return [
ImageComponent(
platform_message.Image(
base64=b64.decode('utf-8'),
)
]

View File

@ -2,11 +2,12 @@ from __future__ import annotations
import abc
import typing
import mirai
from mirai.models.message import MessageComponent
# import mirai
# from mirai.models.message import MessageComponent
from ...core import app
from ...core import entities as core_entities
from ...platform.types import message as platform_message
preregistered_strategies: list[typing.Type[LongTextStrategy]] = []
@ -51,7 +52,7 @@ class LongTextStrategy(metaclass=abc.ABCMeta):
pass
@abc.abstractmethod
async def process(self, message: str, query: core_entities.Query) -> list[MessageComponent]:
async def process(self, message: str, query: core_entities.Query) -> list[platform_message.MessageComponent]:
"""处理长文本
platform.json 中配置 long-text-process 字段只要 文本长度超过了 threshold 就会调用此方法
@ -61,6 +62,6 @@ class LongTextStrategy(metaclass=abc.ABCMeta):
query (core_entities.Query): 此次请求的上下文对象
Returns:
list[mirai.models.messages.MessageComponent]: 转换后的 YiriMirai 消息组件列表
list[platform_message.MessageComponent]: 转换后的 平台 消息组件列表
"""
return []

View File

@ -2,10 +2,12 @@ from __future__ import annotations
import asyncio
import mirai
# import mirai
from ..core import entities
from ..platform import adapter as msadapter
from ..platform.types import message as platform_message
from ..platform.types import events as platform_events
class QueryPool:
@ -30,8 +32,8 @@ class QueryPool:
launcher_type: entities.LauncherTypes,
launcher_id: int,
sender_id: int,
message_event: mirai.MessageEvent,
message_chain: mirai.MessageChain,
message_event: platform_events.MessageEvent,
message_chain: platform_message.MessageChain,
adapter: msadapter.MessageSourceAdapter
) -> entities.Query:
async with self.condition:

View File

@ -1,11 +1,12 @@
from __future__ import annotations
import mirai
# import mirai
from .. import stage, entities, stagemgr
from ...core import entities as core_entities
from ...provider import entities as llm_entities
from ...plugin import events
from ...platform.types import message as platform_message
@stage.stage_class("PreProcessor")
@ -55,11 +56,11 @@ class PreProcessor(stage.PipelineStage):
content_list = []
for me in query.message_chain:
if isinstance(me, mirai.Plain):
if isinstance(me, platform_message.Plain):
content_list.append(
llm_entities.ContentElement.from_text(me.text)
)
elif isinstance(me, mirai.Image):
elif isinstance(me, platform_message.Image):
if self.ap.provider_cfg.data['enable-vision'] and query.use_model.vision_supported:
if me.url is not None:
content_list.append(

View File

@ -5,7 +5,7 @@ import time
import traceback
import json
import mirai
# import mirai
from .. import handler
from ... import entities
@ -13,6 +13,8 @@ from ....core import entities as core_entities
from ....provider import entities as llm_entities, runnermgr
from ....plugin import events
from ....platform.types import message as platform_message
class ChatMessageHandler(handler.MessageHandler):
@ -40,7 +42,7 @@ class ChatMessageHandler(handler.MessageHandler):
if event_ctx.is_prevented_default():
if event_ctx.event.reply is not None:
mc = mirai.MessageChain(event_ctx.event.reply)
mc = platform_message.MessageChain(event_ctx.event.reply)
query.resp_messages.append(mc)

View File

@ -1,13 +1,14 @@
from __future__ import annotations
import typing
import mirai
# import mirai
from .. import handler
from ... import entities
from ....core import entities as core_entities
from ....provider import entities as llm_entities
from ....plugin import events
from ....platform.types import message as platform_message
class CommandHandler(handler.MessageHandler):
@ -46,7 +47,7 @@ class CommandHandler(handler.MessageHandler):
if event_ctx.is_prevented_default():
if event_ctx.event.reply is not None:
mc = mirai.MessageChain(event_ctx.event.reply)
mc = platform_message.MessageChain(event_ctx.event.reply)
query.resp_messages.append(mc)
@ -63,8 +64,8 @@ class CommandHandler(handler.MessageHandler):
else:
if event_ctx.event.alter is not None:
query.message_chain = mirai.MessageChain([
mirai.Plain(event_ctx.event.alter)
query.message_chain = platform_message.MessageChain([
platform_message.Plain(event_ctx.event.alter)
])
session = await self.ap.sess_mgr.get_session(query)

View File

@ -3,7 +3,7 @@ from __future__ import annotations
import random
import asyncio
import mirai
# import mirai
from ...core import app

View File

@ -1,9 +1,11 @@
import pydantic
import mirai
# import mirai
from ...platform.types import message as platform_message
class RuleJudgeResult(pydantic.BaseModel):
matching: bool = False
replacement: mirai.MessageChain = None
replacement: platform_message.MessageChain = None

View File

@ -1,6 +1,6 @@
from __future__ import annotations
import mirai
# import mirai
from ...core import app
from . import entities as rule_entities, rule

View File

@ -2,11 +2,13 @@ from __future__ import annotations
import abc
import typing
import mirai
# import mirai
from ...core import app, entities as core_entities
from . import entities
from ...platform.types import message as platform_message
preregisetered_rules: list[typing.Type[GroupRespondRule]] = []
@ -35,7 +37,7 @@ class GroupRespondRule(metaclass=abc.ABCMeta):
async def match(
self,
message_text: str,
message_chain: mirai.MessageChain,
message_chain: platform_message.MessageChain,
rule_dict: dict,
query: core_entities.Query
) -> entities.RuleJudgeResult:

View File

@ -1,10 +1,11 @@
from __future__ import annotations
import mirai
# import mirai
from .. import rule as rule_model
from .. import entities
from ....core import entities as core_entities
from ....platform.types import message as platform_message
@rule_model.rule_class("at-bot")
@ -13,16 +14,16 @@ class AtBotRule(rule_model.GroupRespondRule):
async def match(
self,
message_text: str,
message_chain: mirai.MessageChain,
message_chain: platform_message.MessageChain,
rule_dict: dict,
query: core_entities.Query
) -> entities.RuleJudgeResult:
if message_chain.has(mirai.At(query.adapter.bot_account_id)) and rule_dict['at']:
message_chain.remove(mirai.At(query.adapter.bot_account_id))
if message_chain.has(platform_message.At(query.adapter.bot_account_id)) and rule_dict['at']:
message_chain.remove(platform_message.At(query.adapter.bot_account_id))
if message_chain.has(mirai.At(query.adapter.bot_account_id)): # 回复消息时会at两次检查并删除重复的
message_chain.remove(mirai.At(query.adapter.bot_account_id))
if message_chain.has(platform_message.At(query.adapter.bot_account_id)): # 回复消息时会at两次检查并删除重复的
message_chain.remove(platform_message.At(query.adapter.bot_account_id))
return entities.RuleJudgeResult(
matching=True,

View File

@ -1,8 +1,9 @@
import mirai
# import mirai
from .. import rule as rule_model
from .. import entities
from ....core import entities as core_entities
from ....platform.types import message as platform_message
@rule_model.rule_class("prefix")
@ -11,7 +12,7 @@ class PrefixRule(rule_model.GroupRespondRule):
async def match(
self,
message_text: str,
message_chain: mirai.MessageChain,
message_chain: platform_message.MessageChain,
rule_dict: dict,
query: core_entities.Query
) -> entities.RuleJudgeResult:
@ -22,7 +23,7 @@ class PrefixRule(rule_model.GroupRespondRule):
# 查找第一个plain元素
for me in message_chain:
if isinstance(me, mirai.Plain):
if isinstance(me, platform_message.Plain):
me.text = me.text[len(prefix):]
return entities.RuleJudgeResult(

View File

@ -1,10 +1,11 @@
import random
import mirai
# import mirai
from .. import rule as rule_model
from .. import entities
from ....core import entities as core_entities
from ....platform.types import message as platform_message
@rule_model.rule_class("random")
@ -13,7 +14,7 @@ class RandomRespRule(rule_model.GroupRespondRule):
async def match(
self,
message_text: str,
message_chain: mirai.MessageChain,
message_chain: platform_message.MessageChain,
rule_dict: dict,
query: core_entities.Query
) -> entities.RuleJudgeResult:

View File

@ -1,10 +1,11 @@
import re
import mirai
# import mirai
from .. import rule as rule_model
from .. import entities
from ....core import entities as core_entities
from ....platform.types import message as platform_message
@rule_model.rule_class("regexp")
@ -13,7 +14,7 @@ class RegExpRule(rule_model.GroupRespondRule):
async def match(
self,
message_text: str,
message_chain: mirai.MessageChain,
message_chain: platform_message.MessageChain,
rule_dict: dict,
query: core_entities.Query
) -> entities.RuleJudgeResult:

View File

@ -2,7 +2,7 @@ from __future__ import annotations
import typing
import mirai
# import mirai
from ...core import app, entities as core_entities
from .. import entities
@ -10,6 +10,7 @@ from .. import stage, entities, stagemgr
from ...core import entities as core_entities
from ...config import manager as cfg_mgr
from ...plugin import events
from ...platform.types import message as platform_message
@stage.stage_class("ResponseWrapper")
@ -34,7 +35,7 @@ class ResponseWrapper(stage.PipelineStage):
"""
# 如果 resp_messages[-1] 已经是 MessageChain 了
if isinstance(query.resp_messages[-1], mirai.MessageChain):
if isinstance(query.resp_messages[-1], platform_message.MessageChain):
query.resp_message_chain.append(query.resp_messages[-1])
yield entities.StageProcessResult(
@ -96,7 +97,7 @@ class ResponseWrapper(stage.PipelineStage):
else:
if event_ctx.event.reply is not None:
query.resp_message_chain.append(mirai.MessageChain(event_ctx.event.reply))
query.resp_message_chain.append(platform_message.MessageChain(event_ctx.event.reply))
else:
@ -113,7 +114,7 @@ class ResponseWrapper(stage.PipelineStage):
reply_text = f'调用函数 {".".join(function_names)}...'
query.resp_message_chain.append(mirai.MessageChain([mirai.Plain(reply_text)]))
query.resp_message_chain.append(platform_message.MessageChain([platform_message.Plain(reply_text)]))
if self.ap.platform_cfg.data['track-function-calls']:
@ -139,11 +140,11 @@ class ResponseWrapper(stage.PipelineStage):
else:
if event_ctx.event.reply is not None:
query.resp_message_chain.append(mirai.MessageChain(event_ctx.event.reply))
query.resp_message_chain.append(platform_message.MessageChain(event_ctx.event.reply))
else:
query.resp_message_chain.append(mirai.MessageChain([mirai.Plain(reply_text)]))
query.resp_message_chain.append(platform_message.MessageChain([platform_message.Plain(reply_text)]))
yield entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE,

View File

@ -4,9 +4,11 @@ from __future__ import annotations
import typing
import abc
import mirai
# import mirai
from ..core import app
from .types import message as platform_message
from .types import events as platform_events
preregistered_adapters: list[typing.Type[MessageSourceAdapter]] = []
@ -55,7 +57,7 @@ class MessageSourceAdapter(metaclass=abc.ABCMeta):
self,
target_type: str,
target_id: str,
message: mirai.MessageChain
message: platform_message.MessageChain
):
"""主动发送消息
@ -68,8 +70,8 @@ class MessageSourceAdapter(metaclass=abc.ABCMeta):
async def reply_message(
self,
message_source: mirai.MessageEvent,
message: mirai.MessageChain,
message_source: platform_events.MessageEvent,
message: platform_message.MessageChain,
quote_origin: bool = False
):
"""回复消息
@ -87,8 +89,8 @@ class MessageSourceAdapter(metaclass=abc.ABCMeta):
def register_listener(
self,
event_type: typing.Type[mirai.Event],
callback: typing.Callable[[mirai.Event, MessageSourceAdapter], None]
event_type: typing.Type[platform_message.Event],
callback: typing.Callable[[platform_message.Event, MessageSourceAdapter], None]
):
"""注册事件监听器
@ -100,8 +102,8 @@ class MessageSourceAdapter(metaclass=abc.ABCMeta):
def unregister_listener(
self,
event_type: typing.Type[mirai.Event],
callback: typing.Callable[[mirai.Event, MessageSourceAdapter], None]
event_type: typing.Type[platform_message.Event],
callback: typing.Callable[[platform_message.Event, MessageSourceAdapter], None]
):
"""注销事件监听器
@ -127,7 +129,7 @@ class MessageSourceAdapter(metaclass=abc.ABCMeta):
class MessageConverter:
"""消息链转换器基类"""
@staticmethod
def yiri2target(message_chain: mirai.MessageChain):
def yiri2target(message_chain: platform_message.MessageChain):
"""将YiriMirai消息链转换为目标消息链
Args:
@ -139,7 +141,7 @@ class MessageConverter:
raise NotImplementedError
@staticmethod
def target2yiri(message_chain: typing.Any) -> mirai.MessageChain:
def target2yiri(message_chain: typing.Any) -> platform_message.MessageChain:
"""将目标消息链转换为YiriMirai消息链
Args:
@ -155,7 +157,7 @@ class EventConverter:
"""事件转换器基类"""
@staticmethod
def yiri2target(event: typing.Type[mirai.Event]):
def yiri2target(event: typing.Type[platform_message.Event]):
"""将YiriMirai事件转换为目标事件
Args:
@ -167,7 +169,7 @@ class EventConverter:
raise NotImplementedError
@staticmethod
def target2yiri(event: typing.Any) -> mirai.Event:
def target2yiri(event: typing.Any) -> platform_message.Event:
"""将目标事件的调用参数转换为YiriMirai的事件参数对象
Args:

View File

@ -6,13 +6,17 @@ import logging
import asyncio
import traceback
from mirai import At, GroupMessage, MessageEvent, StrangerMessage, \
FriendMessage, Image, MessageChain, Plain
import mirai
# from mirai import At, GroupMessage, MessageEvent, StrangerMessage, \
# FriendMessage, Image, MessageChain, Plain
# import mirai
from ..platform import adapter as msadapter
from ..core import app, entities as core_entities
from ..plugin import events
from .types import message as platform_message
from .types import events as platform_events
from .types import entities as platform_entities
# 控制QQ消息输入输出的类
class PlatformManager:
@ -32,7 +36,7 @@ class PlatformManager:
from .sources import yirimirai, nakuru, aiocqhttp, qqbotpy
async def on_friend_message(event: FriendMessage, adapter: msadapter.MessageSourceAdapter):
async def on_friend_message(event: platform_events.FriendMessage, adapter: msadapter.MessageSourceAdapter):
event_ctx = await self.ap.plugin_mgr.emit_event(
event=events.PersonMessageReceived(
@ -55,7 +59,7 @@ class PlatformManager:
adapter=adapter
)
async def on_stranger_message(event: StrangerMessage, adapter: msadapter.MessageSourceAdapter):
async def on_stranger_message(event: platform_events.StrangerMessage, adapter: msadapter.MessageSourceAdapter):
event_ctx = await self.ap.plugin_mgr.emit_event(
event=events.PersonMessageReceived(
@ -78,7 +82,7 @@ class PlatformManager:
adapter=adapter
)
async def on_group_message(event: GroupMessage, adapter: msadapter.MessageSourceAdapter):
async def on_group_message(event: platform_events.GroupMessage, adapter: msadapter.MessageSourceAdapter):
event_ctx = await self.ap.plugin_mgr.emit_event(
event=events.GroupMessageReceived(
@ -127,16 +131,16 @@ class PlatformManager:
if adapter_name == 'yiri-mirai':
adapter_inst.register_listener(
StrangerMessage,
platform_events.StrangerMessage,
on_stranger_message
)
adapter_inst.register_listener(
FriendMessage,
platform_events.FriendMessage,
on_friend_message
)
adapter_inst.register_listener(
GroupMessage,
platform_events.GroupMessage,
on_group_message
)
@ -146,13 +150,13 @@ class PlatformManager:
if len(self.adapters) == 0:
self.ap.logger.warning('未运行平台适配器,请根据文档配置并启用平台适配器。')
async def send(self, event: mirai.MessageEvent, msg: mirai.MessageChain, adapter: msadapter.MessageSourceAdapter):
async def send(self, event: platform_events.MessageEvent, msg: platform_message.MessageChain, adapter: msadapter.MessageSourceAdapter):
if self.ap.platform_cfg.data['at-sender'] and isinstance(event, GroupMessage):
if self.ap.platform_cfg.data['at-sender'] and isinstance(event, platform_events.GroupMessage):
msg.insert(
0,
At(
platform_message.At(
event.sender.id
)
)

View File

@ -5,31 +5,34 @@ import traceback
import time
import datetime
import mirai
import mirai.models.message as yiri_message
# import mirai
# import mirai.models.message as yiri_message
import aiocqhttp
from .. import adapter
from ...pipeline.longtext.strategies import forward
from ...core import app
from ..types import message as platform_message
from ..types import events as platform_events
from ..types import entities as platform_entities
class AiocqhttpMessageConverter(adapter.MessageConverter):
@staticmethod
def yiri2target(message_chain: mirai.MessageChain) -> typing.Tuple[list, int, datetime.datetime]:
def yiri2target(message_chain: platform_message.MessageChain) -> typing.Tuple[list, int, datetime.datetime]:
msg_list = aiocqhttp.Message()
msg_id = 0
msg_time = None
for msg in message_chain:
if type(msg) is mirai.Plain:
if type(msg) is platform_message.Plain:
msg_list.append(aiocqhttp.MessageSegment.text(msg.text))
elif type(msg) is yiri_message.Source:
elif type(msg) is platform_message.Source:
msg_id = msg.id
msg_time = msg.time
elif type(msg) is mirai.Image:
elif type(msg) is platform_message.Image:
arg = ''
if msg.base64:
arg = msg.base64
@ -40,13 +43,11 @@ class AiocqhttpMessageConverter(adapter.MessageConverter):
elif msg.path:
arg = msg.path
msg_list.append(aiocqhttp.MessageSegment.image(arg))
elif type(msg) is mirai.At:
elif type(msg) is platform_message.At:
msg_list.append(aiocqhttp.MessageSegment.at(msg.target))
elif type(msg) is mirai.AtAll:
elif type(msg) is platform_message.AtAll:
msg_list.append(aiocqhttp.MessageSegment.at("all"))
elif type(msg) is mirai.Face:
msg_list.append(aiocqhttp.MessageSegment.face(msg.face_id))
elif type(msg) is mirai.Voice:
elif type(msg) is platform_message.Voice:
arg = ''
if msg.base64:
arg = msg.base64
@ -74,25 +75,25 @@ class AiocqhttpMessageConverter(adapter.MessageConverter):
yiri_msg_list = []
yiri_msg_list.append(
yiri_message.Source(id=message_id, time=datetime.datetime.now())
platform_message.Source(id=message_id, time=datetime.datetime.now())
)
for msg in message:
if msg.type == "at":
if msg.data["qq"] == "all":
yiri_msg_list.append(yiri_message.AtAll())
yiri_msg_list.append(platform_message.AtAll())
else:
yiri_msg_list.append(
yiri_message.At(
platform_message.At(
target=msg.data["qq"],
)
)
elif msg.type == "text":
yiri_msg_list.append(yiri_message.Plain(text=msg.data["text"]))
yiri_msg_list.append(platform_message.Plain(text=msg.data["text"]))
elif msg.type == "image":
yiri_msg_list.append(yiri_message.Image(url=msg.data["url"]))
yiri_msg_list.append(platform_message.Image(url=msg.data["url"]))
chain = mirai.MessageChain(yiri_msg_list)
chain = platform_message.MessageChain(yiri_msg_list)
return chain
@ -100,11 +101,11 @@ class AiocqhttpMessageConverter(adapter.MessageConverter):
class AiocqhttpEventConverter(adapter.EventConverter):
@staticmethod
def yiri2target(event: mirai.Event, bot_account_id: int):
def yiri2target(event: platform_events.Event, bot_account_id: int):
msg, msg_id, msg_time = AiocqhttpMessageConverter.yiri2target(event.message_chain)
if type(event) is mirai.GroupMessage:
if type(event) is platform_events.GroupMessage:
role = "member"
if event.sender.permission == "ADMINISTRATOR":
@ -140,7 +141,7 @@ class AiocqhttpEventConverter(adapter.EventConverter):
}
return aiocqhttp.Event.from_payload(payload)
elif type(event) is mirai.FriendMessage:
elif type(event) is platform_events.FriendMessage:
payload = {
"post_type": "message",
@ -177,15 +178,15 @@ class AiocqhttpEventConverter(adapter.EventConverter):
permission = "ADMINISTRATOR"
elif event.sender["role"] == "owner":
permission = "OWNER"
converted_event = mirai.GroupMessage(
sender=mirai.models.entities.GroupMember(
converted_event = platform_events.GroupMessage(
sender=platform_entities.GroupMember(
id=event.sender["user_id"], # message_seq 放哪?
member_name=event.sender["nickname"],
permission=permission,
group=mirai.models.entities.Group(
group=platform_entities.Group(
id=event.group_id,
name=event.sender["nickname"],
permission=mirai.models.entities.Permission.Member,
permission=platform_entities.Permission.Member,
),
special_title=event.sender["title"] if "title" in event.sender else "",
join_timestamp=0,
@ -197,8 +198,8 @@ class AiocqhttpEventConverter(adapter.EventConverter):
)
return converted_event
elif event.message_type == "private":
return mirai.FriendMessage(
sender=mirai.models.entities.Friend(
return platform_events.FriendMessage(
sender=platform_entities.Friend(
id=event.sender["user_id"],
nickname=event.sender["nickname"],
remark="",
@ -240,7 +241,7 @@ class AiocqhttpAdapter(adapter.MessageSourceAdapter):
self.bot = aiocqhttp.CQHttp()
async def send_message(
self, target_type: str, target_id: str, message: mirai.MessageChain
self, target_type: str, target_id: str, message: platform_message.MessageChain
):
aiocq_msg = AiocqhttpMessageConverter.yiri2target(message)[0]
@ -251,8 +252,8 @@ class AiocqhttpAdapter(adapter.MessageSourceAdapter):
async def reply_message(
self,
message_source: mirai.MessageEvent,
message: mirai.MessageChain,
message_source: platform_events.MessageEvent,
message: platform_message.MessageChain,
quote_origin: bool = False,
):
aiocq_event = AiocqhttpEventConverter.yiri2target(message_source, self.bot_account_id)
@ -270,8 +271,8 @@ class AiocqhttpAdapter(adapter.MessageSourceAdapter):
def register_listener(
self,
event_type: typing.Type[mirai.Event],
callback: typing.Callable[[mirai.Event, adapter.MessageSourceAdapter], None],
event_type: typing.Type[platform_events.Event],
callback: typing.Callable[[platform_events.Event, adapter.MessageSourceAdapter], None],
):
async def on_message(event: aiocqhttp.Event):
self.bot_account_id = event.self_id
@ -280,15 +281,15 @@ class AiocqhttpAdapter(adapter.MessageSourceAdapter):
except:
traceback.print_exc()
if event_type == mirai.GroupMessage:
if event_type == platform_events.GroupMessage:
self.bot.on_message("group")(on_message)
elif event_type == mirai.FriendMessage:
elif event_type == platform_events.FriendMessage:
self.bot.on_message("private")(on_message)
def unregister_listener(
self,
event_type: typing.Type[mirai.Event],
callback: typing.Callable[[mirai.Event, adapter.MessageSourceAdapter], None],
event_type: typing.Type[platform_events.Event],
callback: typing.Callable[[platform_events.Event, adapter.MessageSourceAdapter], None],
):
return super().unregister_listener(event_type, callback)

View File

@ -6,26 +6,29 @@ import typing
import traceback
import logging
import mirai
# import mirai
import nakuru
import nakuru.entities.components as nkc
from .. import adapter as adapter_model
from ...pipeline.longtext.strategies import forward
from ...platform.types import message as platform_message
from ...platform.types import entities as platform_entities
from ...platform.types import events as platform_events
class NakuruProjectMessageConverter(adapter_model.MessageConverter):
"""消息转换器"""
@staticmethod
def yiri2target(message_chain: mirai.MessageChain) -> list:
def yiri2target(message_chain: platform_message.MessageChain) -> list:
msg_list = []
if type(message_chain) is mirai.MessageChain:
if type(message_chain) is platform_message.MessageChain:
msg_list = message_chain.__root__
elif type(message_chain) is list:
msg_list = message_chain
elif type(message_chain) is str:
msg_list = [mirai.Plain(message_chain)]
msg_list = [platform_message.Plain(message_chain)]
else:
raise Exception("Unknown message type: " + str(message_chain) + str(type(message_chain)))
@ -33,22 +36,20 @@ class NakuruProjectMessageConverter(adapter_model.MessageConverter):
# 遍历并转换
for component in msg_list:
if type(component) is mirai.Plain:
if type(component) is platform_message.Plain:
nakuru_msg_list.append(nkc.Plain(component.text, False))
elif type(component) is mirai.Image:
elif type(component) is platform_message.Image:
if component.url is not None:
nakuru_msg_list.append(nkc.Image.fromURL(component.url))
elif component.base64 is not None:
nakuru_msg_list.append(nkc.Image.fromBase64(component.base64))
elif component.path is not None:
nakuru_msg_list.append(nkc.Image.fromFileSystem(component.path))
elif type(component) is mirai.Face:
nakuru_msg_list.append(nkc.Face(id=component.face_id))
elif type(component) is mirai.At:
elif type(component) is platform_message.At:
nakuru_msg_list.append(nkc.At(qq=component.target))
elif type(component) is mirai.AtAll:
elif type(component) is platform_message.AtAll:
nakuru_msg_list.append(nkc.AtAll())
elif type(component) is mirai.Voice:
elif type(component) is platform_message.Voice:
if component.url is not None:
nakuru_msg_list.append(nkc.Record.fromURL(component.url))
elif component.path is not None:
@ -80,49 +81,47 @@ class NakuruProjectMessageConverter(adapter_model.MessageConverter):
return nakuru_msg_list
@staticmethod
def target2yiri(message_chain: typing.Any, message_id: int = -1) -> mirai.MessageChain:
def target2yiri(message_chain: typing.Any, message_id: int = -1) -> platform_message.MessageChain:
"""将Yiri的消息链转换为YiriMirai的消息链"""
assert type(message_chain) is list
yiri_msg_list = []
import datetime
# 添加Source组件以标记message_id等信息
yiri_msg_list.append(mirai.models.message.Source(id=message_id, time=datetime.datetime.now()))
yiri_msg_list.append(platform_message.Source(id=message_id, time=datetime.datetime.now()))
for component in message_chain:
if type(component) is nkc.Plain:
yiri_msg_list.append(mirai.Plain(text=component.text))
yiri_msg_list.append(platform_message.Plain(text=component.text))
elif type(component) is nkc.Image:
yiri_msg_list.append(mirai.Image(url=component.url))
elif type(component) is nkc.Face:
yiri_msg_list.append(mirai.Face(face_id=component.id))
yiri_msg_list.append(platform_message.Image(url=component.url))
elif type(component) is nkc.At:
yiri_msg_list.append(mirai.At(target=component.qq))
yiri_msg_list.append(platform_message.At(target=component.qq))
elif type(component) is nkc.AtAll:
yiri_msg_list.append(mirai.AtAll())
yiri_msg_list.append(platform_message.AtAll())
else:
pass
# logging.debug("转换后的消息链: " + str(yiri_msg_list))
chain = mirai.MessageChain(yiri_msg_list)
chain = platform_message.MessageChain(yiri_msg_list)
return chain
class NakuruProjectEventConverter(adapter_model.EventConverter):
"""事件转换器"""
@staticmethod
def yiri2target(event: typing.Type[mirai.Event]):
if event is mirai.GroupMessage:
def yiri2target(event: typing.Type[platform_events.Event]):
if event is platform_events.GroupMessage:
return nakuru.GroupMessage
elif event is mirai.FriendMessage:
elif event is platform_events.FriendMessage:
return nakuru.FriendMessage
else:
raise Exception("未支持转换的事件类型: " + str(event))
@staticmethod
def target2yiri(event: typing.Any) -> mirai.Event:
def target2yiri(event: typing.Any) -> platform_events.Event:
yiri_chain = NakuruProjectMessageConverter.target2yiri(event.message, event.message_id)
if type(event) is nakuru.FriendMessage: # 私聊消息事件
return mirai.FriendMessage(
sender=mirai.models.entities.Friend(
return platform_events.FriendMessage(
sender=platform_entities.Friend(
id=event.sender.user_id,
nickname=event.sender.nickname,
remark=event.sender.nickname
@ -138,16 +137,15 @@ class NakuruProjectEventConverter(adapter_model.EventConverter):
elif event.sender.role == "owner":
permission = "OWNER"
import mirai.models.entities as entities
return mirai.GroupMessage(
sender=mirai.models.entities.GroupMember(
return platform_events.GroupMessage(
sender=platform_entities.GroupMember(
id=event.sender.user_id,
member_name=event.sender.nickname,
permission=permission,
group=mirai.models.entities.Group(
group=platform_entities.Group(
id=event.group_id,
name=event.sender.nickname,
permission=entities.Permission.Member
permission=platform_entities.Permission.Member
),
special_title=event.sender.title,
join_timestamp=0,
@ -189,7 +187,7 @@ class NakuruProjectAdapter(adapter_model.MessageSourceAdapter):
self,
target_type: str,
target_id: str,
message: typing.Union[mirai.MessageChain, list],
message: typing.Union[platform_message.MessageChain, list],
converted: bool = False
):
task = None
@ -222,8 +220,8 @@ class NakuruProjectAdapter(adapter_model.MessageSourceAdapter):
async def reply_message(
self,
message_source: mirai.MessageEvent,
message: mirai.MessageChain,
message_source: platform_events.MessageEvent,
message: platform_message.MessageChain,
quote_origin: bool = False
):
message = self.message_converter.yiri2target(message)
@ -233,14 +231,14 @@ class NakuruProjectAdapter(adapter_model.MessageSourceAdapter):
id=message_source.message_chain.message_id,
)
)
if type(message_source) is mirai.GroupMessage:
if type(message_source) is platform_events.GroupMessage:
await self.send_message(
"group",
message_source.sender.group.id,
message,
converted=True
)
elif type(message_source) is mirai.FriendMessage:
elif type(message_source) is platform_events.FriendMessage:
await self.send_message(
"person",
message_source.sender.id,
@ -258,8 +256,8 @@ class NakuruProjectAdapter(adapter_model.MessageSourceAdapter):
def register_listener(
self,
event_type: typing.Type[mirai.Event],
callback: typing.Callable[[mirai.Event, adapter_model.MessageSourceAdapter], None]
event_type: typing.Type[platform_events.Event],
callback: typing.Callable[[platform_events.Event, adapter_model.MessageSourceAdapter], None]
):
try:
@ -286,8 +284,8 @@ class NakuruProjectAdapter(adapter_model.MessageSourceAdapter):
def unregister_listener(
self,
event_type: typing.Type[mirai.Event],
callback: typing.Callable[[mirai.Event, adapter_model.MessageSourceAdapter], None]
event_type: typing.Type[platform_events.Event],
callback: typing.Callable[[platform_events.Event, adapter_model.MessageSourceAdapter], None]
):
nakuru_event_name = self.event_converter.yiri2target(event_type).__name__

View File

@ -6,7 +6,7 @@ import datetime
import re
import traceback
import mirai
# import mirai
import botpy
import botpy.message as botpy_message
import botpy.types.message as botpy_message_type
@ -17,17 +17,21 @@ from .. import adapter as adapter_model
from ...pipeline.longtext.strategies import forward
from ...core import app
from ...config import manager as cfg_mgr
from ...platform.types import entities as platform_entities
from ...platform.types import events as platform_events
from ...platform.types import message as platform_message
class OfficialGroupMessage(mirai.GroupMessage):
class OfficialGroupMessage(platform_events.GroupMessage):
pass
class OfficialFriendMessage(mirai.FriendMessage):
class OfficialFriendMessage(platform_events.FriendMessage):
pass
event_handler_mapping = {
mirai.GroupMessage: ["on_at_message_create", "on_group_at_message_create"],
mirai.FriendMessage: ["on_direct_message_create", "on_c2c_message_create"],
platform_events.GroupMessage: ["on_at_message_create", "on_group_at_message_create"],
platform_events.FriendMessage: ["on_direct_message_create", "on_c2c_message_create"],
}
@ -123,16 +127,16 @@ class OfficialMessageConverter(adapter_model.MessageConverter):
"""QQ 官方消息转换器"""
@staticmethod
def yiri2target(message_chain: mirai.MessageChain):
def yiri2target(message_chain: platform_message.MessageChain):
"""将 YiriMirai 的消息链转换为 QQ 官方消息"""
msg_list = []
if type(message_chain) is mirai.MessageChain:
if type(message_chain) is platform_message.MessageChain:
msg_list = message_chain.__root__
elif type(message_chain) is list:
msg_list = message_chain
elif type(message_chain) is str:
msg_list = [mirai.Plain(text=message_chain)]
msg_list = [platform_message.Plain(text=message_chain)]
else:
raise Exception(
"Unknown message type: " + str(message_chain) + str(type(message_chain))
@ -153,22 +157,22 @@ class OfficialMessageConverter(adapter_model.MessageConverter):
# 遍历并转换
for component in msg_list:
if type(component) is mirai.Plain:
if type(component) is platform_message.Plain:
offcial_messages.append({"type": "text", "content": component.text})
elif type(component) is mirai.Image:
elif type(component) is platform_message.Image:
if component.url is not None:
offcial_messages.append({"type": "image", "content": component.url})
elif component.path is not None:
offcial_messages.append(
{"type": "file_image", "content": component.path}
)
elif type(component) is mirai.At:
elif type(component) is platform_message.At:
offcial_messages.append({"type": "at", "content": ""})
elif type(component) is mirai.AtAll:
elif type(component) is platform_message.AtAll:
print(
"上层组件要求发送 AtAll 消息,但 QQ 官方 API 不支持此消息类型,忽略此消息。"
)
elif type(component) is mirai.Voice:
elif type(component) is platform_message.Voice:
print(
"上层组件要求发送 Voice 消息,但 QQ 官方 API 不支持此消息类型,忽略此消息。"
)
@ -197,29 +201,29 @@ class OfficialMessageConverter(adapter_model.MessageConverter):
message: typing.Union[botpy_message.Message, botpy_message.DirectMessage, botpy_message.GroupMessage, botpy_message.C2CMessage],
message_id: str = None,
bot_account_id: int = 0,
) -> mirai.MessageChain:
) -> platform_message.MessageChain:
yiri_msg_list = []
# 存id
yiri_msg_list.append(
mirai.models.message.Source(
platform_message.Source(
id=save_msg_id(message_id), time=datetime.datetime.now()
)
)
if type(message) not in [botpy_message.DirectMessage, botpy_message.C2CMessage]:
yiri_msg_list.append(mirai.At(target=bot_account_id))
yiri_msg_list.append(platform_message.At(target=bot_account_id))
if hasattr(message, "mentions"):
for mention in message.mentions:
if mention.bot:
continue
yiri_msg_list.append(mirai.At(target=mention.id))
yiri_msg_list.append(platform_message.At(target=mention.id))
for attachment in message.attachments:
if attachment.content_type.startswith("image"):
yiri_msg_list.append(mirai.Image(url=attachment.url))
yiri_msg_list.append(platform_message.Image(url=attachment.url))
else:
logging.warning(
"不支持的附件类型:" + attachment.content_type + ",忽略此附件。"
@ -227,9 +231,9 @@ class OfficialMessageConverter(adapter_model.MessageConverter):
content = re.sub(r"<@!\d+>", "", str(message.content))
if content.strip() != "":
yiri_msg_list.append(mirai.Plain(text=content))
yiri_msg_list.append(platform_message.Plain(text=content))
chain = mirai.MessageChain(yiri_msg_list)
chain = platform_message.MessageChain(yiri_msg_list)
return chain
@ -244,10 +248,10 @@ class OfficialEventConverter(adapter_model.EventConverter):
self.member_openid_mapping = member_openid_mapping
self.group_openid_mapping = group_openid_mapping
def yiri2target(self, event: typing.Type[mirai.Event]):
if event == mirai.GroupMessage:
def yiri2target(self, event: typing.Type[platform_events.Event]):
if event == platform_events.GroupMessage:
return botpy_message.Message
elif event == mirai.FriendMessage:
elif event == platform_events.FriendMessage:
return botpy_message.DirectMessage
else:
raise Exception(
@ -257,8 +261,7 @@ class OfficialEventConverter(adapter_model.EventConverter):
def target2yiri(
self,
event: typing.Union[botpy_message.Message, botpy_message.DirectMessage, botpy_message.GroupMessage, botpy_message.C2CMessage],
) -> mirai.Event:
import mirai.models.entities as mirai_entities
) -> platform_events.Event:
if type(event) == botpy_message.Message: # 频道内,转群聊事件
permission = "MEMBER"
@ -268,15 +271,15 @@ class OfficialEventConverter(adapter_model.EventConverter):
elif "4" in event.member.roles:
permission = "OWNER"
return mirai.GroupMessage(
sender=mirai_entities.GroupMember(
return platform_events.GroupMessage(
sender=platform_entities.GroupMember(
id=event.author.id,
member_name=event.author.username,
permission=permission,
group=mirai_entities.Group(
group=platform_entities.Group(
id=event.channel_id,
name=event.author.username,
permission=mirai_entities.Permission.Member,
permission=platform_entities.Permission.Member,
),
special_title="",
join_timestamp=int(
@ -297,8 +300,8 @@ class OfficialEventConverter(adapter_model.EventConverter):
),
)
elif type(event) == botpy_message.DirectMessage: # 频道私聊,转私聊事件
return mirai.FriendMessage(
sender=mirai_entities.Friend(
return platform_events.FriendMessage(
sender=platform_entities.Friend(
id=event.guild_id,
nickname=event.author.username,
remark=event.author.username,
@ -317,14 +320,14 @@ class OfficialEventConverter(adapter_model.EventConverter):
replacing_member_id = self.member_openid_mapping.save_openid(event.author.member_openid)
return OfficialGroupMessage(
sender=mirai_entities.GroupMember(
sender=platform_entities.GroupMember(
id=replacing_member_id,
member_name=replacing_member_id,
permission="MEMBER",
group=mirai_entities.Group(
group=platform_entities.Group(
id=self.group_openid_mapping.save_openid(event.group_openid),
name=replacing_member_id,
permission=mirai_entities.Permission.Member,
permission=platform_entities.Permission.Member,
),
special_title="",
join_timestamp=int(0),
@ -345,7 +348,7 @@ class OfficialEventConverter(adapter_model.EventConverter):
user_id_alter = self.member_openid_mapping.save_openid(event.author.user_openid) # 实测这里的user_openid与group的member_openid是一样的
return OfficialFriendMessage(
sender=mirai_entities.Friend(
sender=platform_entities.Friend(
id=user_id_alter,
nickname=user_id_alter,
remark=user_id_alter,
@ -410,7 +413,7 @@ class OfficialAdapter(adapter_model.MessageSourceAdapter):
self.bot = botpy.Client(intents=intents)
async def send_message(
self, target_type: str, target_id: str, message: mirai.MessageChain
self, target_type: str, target_id: str, message: platform_message.MessageChain
):
message_list = self.message_converter.yiri2target(message)
@ -437,8 +440,8 @@ class OfficialAdapter(adapter_model.MessageSourceAdapter):
async def reply_message(
self,
message_source: mirai.MessageEvent,
message: mirai.MessageChain,
message_source: platform_events.MessageEvent,
message: platform_message.MessageChain,
quote_origin: bool = False,
):
@ -463,13 +466,13 @@ class OfficialAdapter(adapter_model.MessageSourceAdapter):
]
)
if type(message_source) == mirai.GroupMessage:
if type(message_source) == platform_events.GroupMessage:
args["channel_id"] = str(message_source.sender.group.id)
args["msg_id"] = cached_message_ids[
str(message_source.message_chain.message_id)
]
await self.bot.api.post_message(**args)
elif type(message_source) == mirai.FriendMessage:
elif type(message_source) == platform_events.FriendMessage:
args["guild_id"] = str(message_source.sender.id)
args["msg_id"] = cached_message_ids[
str(message_source.message_chain.message_id)
@ -534,9 +537,9 @@ class OfficialAdapter(adapter_model.MessageSourceAdapter):
def register_listener(
self,
event_type: typing.Type[mirai.Event],
event_type: typing.Type[platform_events.Event],
callback: typing.Callable[
[mirai.Event, adapter_model.MessageSourceAdapter], None
[platform_events.Event, adapter_model.MessageSourceAdapter], None
],
):
@ -560,9 +563,9 @@ class OfficialAdapter(adapter_model.MessageSourceAdapter):
def unregister_listener(
self,
event_type: typing.Type[mirai.Event],
event_type: typing.Type[platform_events.Event],
callback: typing.Callable[
[mirai.Event, adapter_model.MessageSourceAdapter], None
[platform_events.Event, adapter_model.MessageSourceAdapter], None
],
):
delattr(self.bot, event_handler_mapping[event_type])

View File

@ -1,124 +1,124 @@
import asyncio
import typing
# import asyncio
# import typing
import mirai
import mirai.models.bus
from mirai.bot import MiraiRunner
# import mirai
# import mirai.models.bus
# from mirai.bot import MiraiRunner
from .. import adapter as adapter_model
from ...core import app
# from .. import adapter as adapter_model
# from ...core import app
@adapter_model.adapter_class("yiri-mirai")
class YiriMiraiAdapter(adapter_model.MessageSourceAdapter):
"""YiriMirai适配器"""
bot: mirai.Mirai
# @adapter_model.adapter_class("yiri-mirai")
# class YiriMiraiAdapter(adapter_model.MessageSourceAdapter):
# """YiriMirai适配器"""
# bot: mirai.Mirai
def __init__(self, config: dict, ap: app.Application):
"""初始化YiriMirai的对象"""
self.ap = ap
self.config = config
if 'adapter' not in config or \
config['adapter'] == 'WebSocketAdapter':
self.bot = mirai.Mirai(
qq=config['qq'],
adapter=mirai.WebSocketAdapter(
host=config['host'],
port=config['port'],
verify_key=config['verifyKey']
)
)
elif config['adapter'] == 'HTTPAdapter':
self.bot = mirai.Mirai(
qq=config['qq'],
adapter=mirai.HTTPAdapter(
host=config['host'],
port=config['port'],
verify_key=config['verifyKey']
)
)
else:
raise Exception('Unknown adapter for YiriMirai: ' + config['adapter'])
# def __init__(self, config: dict, ap: app.Application):
# """初始化YiriMirai的对象"""
# self.ap = ap
# self.config = config
# if 'adapter' not in config or \
# config['adapter'] == 'WebSocketAdapter':
# self.bot = mirai.Mirai(
# qq=config['qq'],
# adapter=mirai.WebSocketAdapter(
# host=config['host'],
# port=config['port'],
# verify_key=config['verifyKey']
# )
# )
# elif config['adapter'] == 'HTTPAdapter':
# self.bot = mirai.Mirai(
# qq=config['qq'],
# adapter=mirai.HTTPAdapter(
# host=config['host'],
# port=config['port'],
# verify_key=config['verifyKey']
# )
# )
# else:
# raise Exception('Unknown adapter for YiriMirai: ' + config['adapter'])
async def send_message(
self,
target_type: str,
target_id: str,
message: mirai.MessageChain
):
"""发送消息
# async def send_message(
# self,
# target_type: str,
# target_id: str,
# message: mirai.MessageChain
# ):
# """发送消息
Args:
target_type (str): 目标类型`person``group`
target_id (str): 目标ID
message (mirai.MessageChain): YiriMirai库的消息链
"""
task = None
if target_type == 'person':
task = self.bot.send_friend_message(int(target_id), message)
elif target_type == 'group':
task = self.bot.send_group_message(int(target_id), message)
else:
raise Exception('Unknown target type: ' + target_type)
# Args:
# target_type (str): 目标类型,`person`或`group`
# target_id (str): 目标ID
# message (mirai.MessageChain): YiriMirai库的消息链
# """
# task = None
# if target_type == 'person':
# task = self.bot.send_friend_message(int(target_id), message)
# elif target_type == 'group':
# task = self.bot.send_group_message(int(target_id), message)
# else:
# raise Exception('Unknown target type: ' + target_type)
await task
# await task
async def reply_message(
self,
message_source: mirai.MessageEvent,
message: mirai.MessageChain,
quote_origin: bool = False
):
"""回复消息
# async def reply_message(
# self,
# message_source: mirai.MessageEvent,
# message: mirai.MessageChain,
# quote_origin: bool = False
# ):
# """回复消息
Args:
message_source (mirai.MessageEvent): YiriMirai消息源事件
message (mirai.MessageChain): YiriMirai库的消息链
quote_origin (bool, optional): 是否引用原消息. Defaults to False.
"""
await self.bot.send(message_source, message, quote_origin)
# Args:
# message_source (mirai.MessageEvent): YiriMirai消息源事件
# message (mirai.MessageChain): YiriMirai库的消息链
# quote_origin (bool, optional): 是否引用原消息. Defaults to False.
# """
# await self.bot.send(message_source, message, quote_origin)
async def is_muted(self, group_id: int) -> bool:
result = await self.bot.member_info(target=group_id, member_id=self.bot.qq).get()
if result.mute_time_remaining > 0:
return True
return False
# async def is_muted(self, group_id: int) -> bool:
# result = await self.bot.member_info(target=group_id, member_id=self.bot.qq).get()
# if result.mute_time_remaining > 0:
# return True
# return False
def register_listener(
self,
event_type: typing.Type[mirai.Event],
callback: typing.Callable[[mirai.Event, adapter_model.MessageSourceAdapter], None]
):
"""注册事件监听器
# def register_listener(
# self,
# event_type: typing.Type[mirai.Event],
# callback: typing.Callable[[mirai.Event, adapter_model.MessageSourceAdapter], None]
# ):
# """注册事件监听器
Args:
event_type (typing.Type[mirai.Event]): YiriMirai事件类型
callback (typing.Callable[[mirai.Event], None]): 回调函数接收一个参数为YiriMirai事件
"""
async def wrapper(event: mirai.Event):
await callback(event, self)
self.bot.on(event_type)(wrapper)
# Args:
# event_type (typing.Type[mirai.Event]): YiriMirai事件类型
# callback (typing.Callable[[mirai.Event], None]): 回调函数接收一个参数为YiriMirai事件
# """
# async def wrapper(event: mirai.Event):
# await callback(event, self)
# self.bot.on(event_type)(wrapper)
def unregister_listener(
self,
event_type: typing.Type[mirai.Event],
callback: typing.Callable[[mirai.Event, adapter_model.MessageSourceAdapter], None]
):
"""注销事件监听器
# def unregister_listener(
# self,
# event_type: typing.Type[mirai.Event],
# callback: typing.Callable[[mirai.Event, adapter_model.MessageSourceAdapter], None]
# ):
# """注销事件监听器
Args:
event_type (typing.Type[mirai.Event]): YiriMirai事件类型
callback (typing.Callable[[mirai.Event], None]): 回调函数接收一个参数为YiriMirai事件
"""
assert isinstance(self.bot, mirai.Mirai)
bus = self.bot.bus
assert isinstance(bus, mirai.models.bus.ModelEventBus)
# Args:
# event_type (typing.Type[mirai.Event]): YiriMirai事件类型
# callback (typing.Callable[[mirai.Event], None]): 回调函数接收一个参数为YiriMirai事件
# """
# assert isinstance(self.bot, mirai.Mirai)
# bus = self.bot.bus
# assert isinstance(bus, mirai.models.bus.ModelEventBus)
bus.unsubscribe(event_type, callback)
# bus.unsubscribe(event_type, callback)
async def run_async(self):
self.bot_account_id = self.bot.qq
return await MiraiRunner(self.bot)._run()
# async def run_async(self):
# self.bot_account_id = self.bot.qq
# return await MiraiRunner(self.bot)._run()
async def kill(self) -> bool:
return False
# async def kill(self) -> bool:
# return False

View File

105
pkg/platform/types/base.py Normal file
View File

@ -0,0 +1,105 @@
from typing import Dict, List, Type
import pydantic.main as pdm
from pydantic import BaseModel
class PlatformMetaclass(pdm.ModelMetaclass):
"""此类是 YiriMirai 中使用的 pydantic 模型的元类的基类。"""
def to_camel(name: str) -> str:
"""将下划线命名风格转换为小驼峰命名。"""
if name[:2] == '__': # 不处理双下划线开头的特殊命名。
return name
name_parts = name.split('_')
return ''.join(name_parts[:1] + [x.title() for x in name_parts[1:]])
class PlatformBaseModel(BaseModel, metaclass=PlatformMetaclass):
"""模型基类。
启用了三项配置
1. 允许解析时传入额外的值并将额外值保存在模型中
2. 允许通过别名访问字段
3. 自动生成小驼峰风格的别名以符合 mirai-api-http 的命名
"""
def __init__(self, *args, **kwargs):
""""""
super().__init__(*args, **kwargs)
def __repr__(self) -> str:
return self.__class__.__name__ + '(' + ', '.join(
(f'{k}={repr(v)}' for k, v in self.__dict__.items() if v)
) + ')'
class Config:
extra = 'allow'
allow_population_by_field_name = True
alias_generator = to_camel
class PlatformIndexedMetaclass(PlatformMetaclass):
"""可以通过子类名获取子类的类的元类。"""
__indexedbases__: List[Type['PlatformIndexedModel']] = []
__indexedmodel__ = None
def __new__(cls, name, bases, attrs, **kwargs):
new_cls = super().__new__(cls, name, bases, attrs, **kwargs)
# 第一类MiraiIndexedModel
if name == 'PlatformIndexedModel':
cls.__indexedmodel__ = new_cls
new_cls.__indexes__ = {}
return new_cls
# 第二类MiraiIndexedModel 的直接子类,这些是可以通过子类名获取子类的类。
if cls.__indexedmodel__ in bases:
cls.__indexedbases__.append(new_cls)
new_cls.__indexes__ = {}
return new_cls
# 第三类MiraiIndexedModel 的直接子类的子类,这些添加到直接子类的索引中。
for base in cls.__indexedbases__:
if issubclass(new_cls, base):
base.__indexes__[name] = new_cls
return new_cls
def __getitem__(cls, name):
return cls.get_subtype(name)
class PlatformIndexedModel(PlatformBaseModel, metaclass=PlatformIndexedMetaclass):
"""可以通过子类名获取子类的类。"""
__indexes__: Dict[str, Type['PlatformIndexedModel']]
@classmethod
def get_subtype(cls, name: str) -> Type['PlatformIndexedModel']:
"""根据类名称,获取相应的子类类型。
Args:
name: 类名称
Returns:
Type['MiraiIndexedModel']: 子类类型
"""
try:
type_ = cls.__indexes__.get(name)
if not (type_ and issubclass(type_, cls)):
raise ValueError(f'`{name}` 不是 `{cls.__name__}` 的子类!')
return type_
except AttributeError as e:
raise ValueError(f'`{name}` 不是 `{cls.__name__}` 的子类!') from None
@classmethod
def parse_subtype(cls, obj: dict) -> 'PlatformIndexedModel':
"""通过字典,构造对应的模型对象。
Args:
obj: 一个字典包含了模型对象的属性
Returns:
MiraiIndexedModel: 构造的对象
"""
if cls in PlatformIndexedModel.__subclasses__():
ModelType = cls.get_subtype(obj['type'])
return ModelType.parse_obj(obj)
return super().parse_obj(obj)

View File

@ -0,0 +1,143 @@
# -*- coding: utf-8 -*-
"""
此模块提供实体和配置项模型
"""
import abc
from datetime import datetime
from enum import Enum
import typing
import pydantic
class Entity(pydantic.BaseModel):
"""实体,表示一个用户或群。"""
id: int
"""QQ 号或群号。"""
@abc.abstractmethod
def get_avatar_url(self) -> str:
"""头像图片链接。"""
@abc.abstractmethod
def get_name(self) -> str:
"""名称。"""
class Friend(Entity):
"""好友。"""
id: int
"""QQ 号。"""
nickname: typing.Optional[str]
"""昵称。"""
remark: typing.Optional[str]
"""备注。"""
def get_avatar_url(self) -> str:
return f'http://q4.qlogo.cn/g?b=qq&nk={self.id}&s=140'
def get_name(self) -> str:
return self.nickname or self.remark or ''
class Permission(str, Enum):
"""群成员身份权限。"""
Member = "MEMBER"
"""成员。"""
Administrator = "ADMINISTRATOR"
"""管理员。"""
Owner = "OWNER"
"""群主。"""
def __repr__(self) -> str:
return repr(self.value)
class Group(Entity):
"""群。"""
id: int
"""群号。"""
name: str
"""群名称。"""
permission: Permission
"""Bot 在群中的权限。"""
def get_avatar_url(self) -> str:
return f'https://p.qlogo.cn/gh/{self.id}/{self.id}/'
def get_name(self) -> str:
return self.name
class GroupMember(Entity):
"""群成员。"""
id: int
"""QQ 号。"""
member_name: str
"""群成员名称。"""
permission: Permission
"""Bot 在群中的权限。"""
group: Group
"""群。"""
special_title: str = ''
"""群头衔。"""
join_timestamp: datetime = datetime.utcfromtimestamp(0)
"""加入群的时间。"""
last_speak_timestamp: datetime = datetime.utcfromtimestamp(0)
"""最后一次发言的时间。"""
mute_time_remaining: int = 0
"""禁言剩余时间。"""
def get_avatar_url(self) -> str:
return f'http://q4.qlogo.cn/g?b=qq&nk={self.id}&s=140'
def get_name(self) -> str:
return self.member_name
class Client(Entity):
"""来自其他客户端的用户。"""
id: int
"""识别 id。"""
platform: str
"""来源平台。"""
def get_avatar_url(self) -> str:
raise NotImplementedError
def get_name(self) -> str:
return self.platform
class Subject(pydantic.BaseModel):
"""另一种实体类型表示。"""
id: int
"""QQ 号或群号。"""
kind: typing.Literal['Friend', 'Group', 'Stranger']
"""类型。"""
class Config(pydantic.BaseModel):
"""配置项类型。"""
def modify(self, **kwargs) -> 'Config':
"""修改部分设置。"""
for k, v in kwargs.items():
if k in self.__fields__:
setattr(self, k, v)
else:
raise ValueError(f'未知配置项: {k}')
return self
class GroupConfigModel(Config):
"""群配置。"""
name: str
"""群名称。"""
confess_talk: bool
"""是否允许坦白说。"""
allow_member_invite: bool
"""是否允许成员邀请好友入群。"""
auto_approve: bool
"""是否开启自动审批入群。"""
anonymous_chat: bool
"""是否开启匿名聊天。"""
announcement: str = ''
"""群公告。"""
class MemberInfoModel(Config, GroupMember):
"""群成员信息。"""

View File

@ -0,0 +1,124 @@
# -*- coding: utf-8 -*-
"""
此模块提供事件模型
"""
from datetime import datetime
from enum import Enum
import typing
import pydantic
from . import entities as platform_entities
from . import message as platform_message
class Event(pydantic.BaseModel):
"""事件基类。
Args:
type: 事件名
"""
type: str
"""事件名。"""
def __repr__(self):
return self.__class__.__name__ + '(' + ', '.join(
(
f'{k}={repr(v)}'
for k, v in self.__dict__.items() if k != 'type' and v
)
) + ')'
@classmethod
def parse_subtype(cls, obj: dict) -> 'Event':
try:
return typing.cast(Event, super().parse_subtype(obj))
except ValueError:
return Event(type=obj['type'])
@classmethod
def get_subtype(cls, name: str) -> typing.Type['Event']:
try:
return typing.cast(typing.Type[Event], super().get_subtype(name))
except ValueError:
return Event
###############################
# Bot Event
class BotEvent(Event):
"""Bot 自身事件。
Args:
type: 事件名
qq: Bot QQ
"""
type: str
"""事件名。"""
qq: int
"""Bot 的 QQ 号。"""
###############################
# Message Event
class MessageEvent(Event):
"""消息事件。
Args:
type: 事件名
message_chain: 消息内容
"""
type: str
"""事件名。"""
message_chain: platform_message.MessageChain
"""消息内容。"""
class FriendMessage(MessageEvent):
"""好友消息。
Args:
type: 事件名
sender: 发送消息的好友
message_chain: 消息内容
"""
type: str = 'FriendMessage'
"""事件名。"""
sender: platform_entities.Friend
"""发送消息的好友。"""
message_chain: platform_message.MessageChain
"""消息内容。"""
class GroupMessage(MessageEvent):
"""群消息。
Args:
type: 事件名
sender: 发送消息的群成员
message_chain: 消息内容
"""
type: str = 'GroupMessage'
"""事件名。"""
sender: platform_entities.GroupMember
"""发送消息的群成员。"""
message_chain: platform_message.MessageChain
"""消息内容。"""
@property
def group(self) -> platform_entities.Group:
return self.sender.group
class StrangerMessage(MessageEvent):
"""陌生人消息。
Args:
type: 事件名
sender: 发送消息的人
message_chain: 消息内容
"""
type: str = 'StrangerMessage'
"""事件名。"""
sender: platform_entities.Friend
"""发送消息的人。"""
message_chain: platform_message.MessageChain
"""消息内容。"""

View File

@ -0,0 +1,826 @@
import itertools
import logging
from datetime import datetime
from enum import Enum
from pathlib import Path
import typing
import pydantic
import pydantic.main
from . import entities as platform_entities
from .base import PlatformBaseModel, PlatformIndexedMetaclass, PlatformIndexedModel
logger = logging.getLogger(__name__)
class MessageComponentMetaclass(PlatformIndexedMetaclass):
"""消息组件元类。"""
__message_component__ = None
def __new__(cls, name, bases, attrs, **kwargs):
new_cls = super().__new__(cls, name, bases, attrs, **kwargs)
if name == 'MessageComponent':
cls.__message_component__ = new_cls
if not cls.__message_component__:
return new_cls
for base in bases:
if issubclass(base, cls.__message_component__):
# 获取字段名
if hasattr(new_cls, '__fields__'):
# 忽略 type 字段
new_cls.__parameter_names__ = list(new_cls.__fields__)[1:]
else:
new_cls.__parameter_names__ = []
break
return new_cls
class MessageComponent(PlatformIndexedModel, metaclass=MessageComponentMetaclass):
"""消息组件。"""
type: str
"""消息组件类型。"""
def __str__(self):
return ''
def __repr__(self):
return self.__class__.__name__ + '(' + ', '.join(
(
f'{k}={repr(v)}'
for k, v in self.__dict__.items() if k != 'type' and v
)
) + ')'
def __init__(self, *args, **kwargs):
# 解析参数列表,将位置参数转化为具名参数
parameter_names = self.__parameter_names__
if len(args) > len(parameter_names):
raise TypeError(
f'`{self.type}`需要{len(parameter_names)}个参数,但传入了{len(args)}个。'
)
for name, value in zip(parameter_names, args):
if name in kwargs:
raise TypeError(f'在 `{self.type}` 中,具名参数 `{name}` 与位置参数重复。')
kwargs[name] = value
super().__init__(**kwargs)
TMessageComponent = typing.TypeVar('TMessageComponent', bound=MessageComponent)
class MessageChain(PlatformBaseModel):
"""消息链。
一个构造消息链的例子
```py
message_chain = MessageChain([
AtAll(),
Plain("Hello World!"),
])
```
`Plain` 可以省略
```py
message_chain = MessageChain([
AtAll(),
"Hello World!",
])
```
在调用 API 参数中需要 MessageChain 也可以使用 `List[MessageComponent]` 代替
例如以下两种写法是等价的
```py
await bot.send_friend_message(12345678, [
Plain("Hello World!")
])
```
```py
await bot.send_friend_message(12345678, MessageChain([
Plain("Hello World!")
]))
```
可以使用 `in` 运算检查消息链中
1. 是否有某个消息组件
2. 是否有某个类型的消息组件
```py
if AtAll in message_chain:
print('AtAll')
if At(bot.qq) in message_chain:
print('At Me')
```
消息链对索引操作进行了增强以消息组件类型为索引获取消息链中的全部该类型的消息组件
```py
plain_list = message_chain[Plain]
'[Plain("Hello World!")]'
```
可以用加号连接两个消息链
```py
MessageChain(['Hello World!']) + MessageChain(['Goodbye World!'])
# 返回 MessageChain([Plain("Hello World!"), Plain("Goodbye World!")])
```
"""
__root__: typing.List[MessageComponent]
@staticmethod
def _parse_message_chain(msg_chain: typing.Iterable):
result = []
for msg in msg_chain:
if isinstance(msg, dict):
result.append(MessageComponent.parse_subtype(msg))
elif isinstance(msg, MessageComponent):
result.append(msg)
elif isinstance(msg, str):
result.append(Plain(msg))
else:
raise TypeError(
f"消息链中元素需为 dict 或 str 或 MessageComponent当前类型{type(msg)}"
)
return result
@pydantic.validator('__root__', always=True, pre=True)
def _parse_component(cls, msg_chain):
if isinstance(msg_chain, (str, MessageComponent)):
msg_chain = [msg_chain]
if not msg_chain:
msg_chain = []
return cls._parse_message_chain(msg_chain)
@classmethod
def parse_obj(cls, msg_chain: typing.Iterable):
"""通过列表形式的消息链,构造对应的 `MessageChain` 对象。
Args:
msg_chain: 列表形式的消息链
"""
result = cls._parse_message_chain(msg_chain)
return cls(__root__=result)
def __init__(self, __root__: typing.Iterable[MessageComponent] = None):
super().__init__(__root__=__root__)
def __str__(self):
return "".join(str(component) for component in self.__root__)
def __repr__(self):
return f'{self.__class__.__name__}({self.__root__!r})'
def __iter__(self):
yield from self.__root__
def get_first(self,
t: typing.Type[TMessageComponent]) -> typing.Optional[TMessageComponent]:
"""获取消息链中第一个符合类型的消息组件。"""
for component in self:
if isinstance(component, t):
return component
return None
@typing.overload
def __getitem__(self, index: int) -> MessageComponent:
...
@typing.overload
def __getitem__(self, index: slice) -> typing.List[MessageComponent]:
...
@typing.overload
def __getitem__(self,
index: typing.Type[TMessageComponent]) -> typing.List[TMessageComponent]:
...
@typing.overload
def __getitem__(
self, index: typing.Tuple[typing.Type[TMessageComponent], int]
) -> typing.List[TMessageComponent]:
...
def __getitem__(
self, index: typing.Union[int, slice, typing.Type[TMessageComponent],
typing.Tuple[typing.Type[TMessageComponent], int]]
) -> typing.Union[MessageComponent, typing.List[MessageComponent],
typing.List[TMessageComponent]]:
return self.get(index)
def __setitem__(
self, key: typing.Union[int, slice],
value: typing.Union[MessageComponent, str, typing.Iterable[typing.Union[MessageComponent,
str]]]
):
if isinstance(value, str):
value = Plain(value)
if isinstance(value, typing.Iterable):
value = (Plain(c) if isinstance(c, str) else c for c in value)
self.__root__[key] = value # type: ignore
def __delitem__(self, key: typing.Union[int, slice]):
del self.__root__[key]
def __reversed__(self) -> typing.Iterable[MessageComponent]:
return reversed(self.__root__)
def has(
self, sub: typing.Union[MessageComponent, typing.Type[MessageComponent],
'MessageChain', str]
) -> bool:
"""判断消息链中:
1. 是否有某个消息组件
2. 是否有某个类型的消息组件
Args:
sub (`Union[MessageComponent, Type[MessageComponent], 'MessageChain', str]`):
若为 `MessageComponent`则判断该组件是否在消息链中
若为 `Type[MessageComponent]`则判断该组件类型是否在消息链中
Returns:
bool: 是否找到
"""
if isinstance(sub, type): # 检测消息链中是否有某种类型的对象
for i in self:
if type(i) is sub:
return True
return False
if isinstance(sub, MessageComponent): # 检查消息链中是否有某个组件
for i in self:
if i == sub:
return True
return False
raise TypeError(f"类型不匹配,当前类型:{type(sub)}")
def __contains__(self, sub) -> bool:
return self.has(sub)
def __ge__(self, other):
return other in self
def __len__(self) -> int:
return len(self.__root__)
def __add__(
self, other: typing.Union['MessageChain', MessageComponent, str]
) -> 'MessageChain':
if isinstance(other, MessageChain):
return self.__class__(self.__root__ + other.__root__)
if isinstance(other, str):
return self.__class__(self.__root__ + [Plain(other)])
if isinstance(other, MessageComponent):
return self.__class__(self.__root__ + [other])
return NotImplemented
def __radd__(self, other: typing.Union[MessageComponent, str]) -> 'MessageChain':
if isinstance(other, MessageComponent):
return self.__class__([other] + self.__root__)
if isinstance(other, str):
return self.__class__(
[typing.cast(MessageComponent, Plain(other))] + self.__root__
)
return NotImplemented
def __mul__(self, other: int):
if isinstance(other, int):
return self.__class__(self.__root__ * other)
return NotImplemented
def __rmul__(self, other: int):
return self.__mul__(other)
def __iadd__(self, other: typing.Iterable[typing.Union[MessageComponent, str]]):
self.extend(other)
def __imul__(self, other: int):
if isinstance(other, int):
self.__root__ *= other
return NotImplemented
def index(
self,
x: typing.Union[MessageComponent, typing.Type[MessageComponent]],
i: int = 0,
j: int = -1
) -> int:
"""返回 x 在消息链中首次出现项的索引号(索引号在 i 或其后且在 j 之前)。
Args:
x (`Union[MessageComponent, Type[MessageComponent]]`):
要查找的消息元素或消息元素类型
i: 从哪个位置开始查找
j: 查找到哪个位置结束
Returns:
int: 如果找到则返回索引号
Raises:
ValueError: 没有找到
TypeError: 类型不匹配
"""
if isinstance(x, type):
l = len(self)
if i < 0:
i += l
if i < 0:
i = 0
if j < 0:
j += l
if j > l:
j = l
for index in range(i, j):
if type(self[index]) is x:
return index
raise ValueError("消息链中不存在该类型的组件。")
if isinstance(x, MessageComponent):
return self.__root__.index(x, i, j)
raise TypeError(f"类型不匹配,当前类型:{type(x)}")
def count(self, x: typing.Union[MessageComponent, typing.Type[MessageComponent]]) -> int:
"""返回消息链中 x 出现的次数。
Args:
x (`Union[MessageComponent, Type[MessageComponent]]`):
要查找的消息元素或消息元素类型
Returns:
int: 次数
"""
if isinstance(x, type):
return sum(1 for i in self if type(i) is x)
if isinstance(x, MessageComponent):
return self.__root__.count(x)
raise TypeError(f"类型不匹配,当前类型:{type(x)}")
def extend(self, x: typing.Iterable[typing.Union[MessageComponent, str]]):
"""将另一个消息链中的元素添加到消息链末尾。
Args:
x: 另一个消息链也可为消息元素或字符串元素的序列
"""
self.__root__.extend(Plain(c) if isinstance(c, str) else c for c in x)
def append(self, x: typing.Union[MessageComponent, str]):
"""将一个消息元素或字符串元素添加到消息链末尾。
Args:
x: 消息元素或字符串元素
"""
self.__root__.append(Plain(x) if isinstance(x, str) else x)
def insert(self, i: int, x: typing.Union[MessageComponent, str]):
"""将一个消息元素或字符串添加到消息链中指定位置。
Args:
i: 插入位置
x: 消息元素或字符串元素
"""
self.__root__.insert(i, Plain(x) if isinstance(x, str) else x)
def pop(self, i: int = -1) -> MessageComponent:
"""从消息链中移除并返回指定位置的元素。
Args:
i: 移除位置默认为末尾
Returns:
MessageComponent: 移除的元素
"""
return self.__root__.pop(i)
def remove(self, x: typing.Union[MessageComponent, typing.Type[MessageComponent]]):
"""从消息链中移除指定元素或指定类型的一个元素。
Args:
x: 指定的元素或元素类型
"""
if isinstance(x, type):
self.pop(self.index(x))
if isinstance(x, MessageComponent):
self.__root__.remove(x)
def exclude(
self,
x: typing.Union[MessageComponent, typing.Type[MessageComponent]],
count: int = -1
) -> 'MessageChain':
"""返回移除指定元素或指定类型的元素后剩余的消息链。
Args:
x: 指定的元素或元素类型
count: 至多移除的数量默认为全部移除
Returns:
MessageChain: 剩余的消息链
"""
def _exclude():
nonlocal count
x_is_type = isinstance(x, type)
for c in self:
if count > 0 and ((x_is_type and type(c) is x) or c == x):
count -= 1
continue
yield c
return self.__class__(_exclude())
def reverse(self):
"""将消息链原地翻转。"""
self.__root__.reverse()
@classmethod
def join(cls, *args: typing.Iterable[typing.Union[str, MessageComponent]]):
return cls(
Plain(c) if isinstance(c, str) else c
for c in itertools.chain(*args)
)
@property
def source(self) -> typing.Optional['Source']:
"""获取消息链中的 `Source` 对象。"""
return self.get_first(Source)
@property
def message_id(self) -> int:
"""获取消息链的 message_id若无法获取返回 -1。"""
source = self.source
return source.id if source else -1
TMessage = typing.Union[MessageChain, typing.Iterable[typing.Union[MessageComponent, str]],
MessageComponent, str]
"""可以转化为 MessageChain 的类型。"""
class Source(MessageComponent):
"""源。包含消息的基本信息。"""
type: str = "Source"
"""消息组件类型。"""
id: int
"""消息的识别号用于引用回复Source 类型永远为 MessageChain 的第一个元素)。"""
time: datetime
"""消息时间。"""
class Plain(MessageComponent):
"""纯文本。"""
type: str = "Plain"
"""消息组件类型。"""
text: str
"""文字消息。"""
def __str__(self):
return self.text
def __repr__(self):
return f'Plain({self.text!r})'
class Quote(MessageComponent):
"""引用。"""
type: str = "Quote"
"""消息组件类型。"""
id: typing.Optional[int] = None
"""被引用回复的原消息的 message_id。"""
group_id: typing.Optional[int] = None
"""被引用回复的原消息所接收的群号当为好友消息时为0。"""
sender_id: typing.Optional[int] = None
"""被引用回复的原消息的发送者的QQ号。"""
target_id: typing.Optional[int] = None
"""被引用回复的原消息的接收者者的QQ号或群号"""
origin: MessageChain
"""被引用回复的原消息的消息链对象。"""
@pydantic.validator("origin", always=True, pre=True)
def origin_formater(cls, v):
return MessageChain.parse_obj(v)
class At(MessageComponent):
"""At某人。"""
type: str = "At"
"""消息组件类型。"""
target: int
"""群员 QQ 号。"""
display: typing.Optional[str] = None
"""At时显示的文字发送消息时无效自动使用群名片。"""
def __eq__(self, other):
return isinstance(other, At) and self.target == other.target
def __str__(self):
return f"@{self.display or self.target}"
def as_mirai_code(self) -> str:
return f"[mirai:at:{self.target}]"
class AtAll(MessageComponent):
"""At全体。"""
type: str = "AtAll"
"""消息组件类型。"""
def __str__(self):
return "@全体成员"
def as_mirai_code(self) -> str:
return f"[mirai:atall]"
class Image(MessageComponent):
"""图片。"""
type: str = "Image"
"""消息组件类型。"""
image_id: typing.Optional[str] = None
"""图片的 image_id群图片与好友图片格式不同。不为空时将忽略 url 属性。"""
url: typing.Optional[pydantic.HttpUrl] = None
"""图片的 URL发送时可作网络图片的链接接收时为腾讯图片服务器的链接可用于图片下载。"""
path: typing.Union[str, Path, None] = None
"""图片的路径,发送本地图片。"""
base64: typing.Optional[str] = None
"""图片的 Base64 编码。"""
def __eq__(self, other):
return isinstance(
other, Image
) and self.type == other.type and self.uuid == other.uuid
def __str__(self):
return '[图片]'
def as_mirai_code(self) -> str:
return f"[mirai:image:{self.image_id}]"
@pydantic.validator('path')
def validate_path(cls, path: typing.Union[str, Path, None]):
"""修复 path 参数的行为,使之相对于 YiriMirai 的启动路径。"""
if path:
try:
return str(Path(path).resolve(strict=True))
except FileNotFoundError:
raise ValueError(f"无效路径:{path}")
else:
return path
@property
def uuid(self):
image_id = self.image_id
if image_id[0] == '{': # 群图片
image_id = image_id[1:37]
elif image_id[0] == '/': # 好友图片
image_id = image_id[1:]
return image_id
async def download(
self,
filename: typing.Union[str, Path, None] = None,
directory: typing.Union[str, Path, None] = None,
determine_type: bool = True
):
"""下载图片到本地。
Args:
filename: 下载到本地的文件路径 `directory` 二选一
directory: 下载到本地的文件夹路径 `filename` 二选一
determine_type: 是否自动根据图片类型确定拓展名默认为 True
"""
if not self.url:
logger.warning(f'图片 `{self.uuid}` 无 url 参数,下载失败。')
return
import httpx
async with httpx.AsyncClient() as client:
response = await client.get(self.url)
response.raise_for_status()
content = response.content
if filename:
path = Path(filename)
if determine_type:
import imghdr
path = path.with_suffix(
'.' + str(imghdr.what(None, content))
)
path.parent.mkdir(parents=True, exist_ok=True)
elif directory:
import imghdr
path = Path(directory)
path.mkdir(parents=True, exist_ok=True)
path = path / f'{self.uuid}.{imghdr.what(None, content)}'
else:
raise ValueError("请指定文件路径或文件夹路径!")
import aiofiles
async with aiofiles.open(path, 'wb') as f:
await f.write(content)
return path
@classmethod
async def from_local(
cls,
filename: typing.Union[str, Path, None] = None,
content: typing.Optional[bytes] = None,
) -> "Image":
"""从本地文件路径加载图片,以 base64 的形式传递。
Args:
filename: 从本地文件路径加载图片 `content` 二选一
content: 从本地文件内容加载图片 `filename` 二选一
Returns:
Image: 图片对象
"""
if content:
pass
elif filename:
path = Path(filename)
import aiofiles
async with aiofiles.open(path, 'rb') as f:
content = await f.read()
else:
raise ValueError("请指定图片路径或图片内容!")
import base64
img = cls(base64=base64.b64encode(content).decode())
return img
@classmethod
def from_unsafe_path(cls, path: typing.Union[str, Path]) -> "Image":
"""从不安全的路径加载图片。
Args:
path: 从不安全的路径加载图片
Returns:
Image: 图片对象
"""
return cls.construct(path=str(path))
class Unknown(MessageComponent):
"""未知。"""
type: str = "Unknown"
"""消息组件类型。"""
text: str
"""文本。"""
class Voice(MessageComponent):
"""语音。"""
type: str = "Voice"
"""消息组件类型。"""
voice_id: typing.Optional[str] = None
"""语音的 voice_id不为空时将忽略 url 属性。"""
url: typing.Optional[str] = None
"""语音的 URL发送时可作网络语音的链接接收时为腾讯语音服务器的链接可用于语音下载。"""
path: typing.Optional[str] = None
"""语音的路径,发送本地语音。"""
base64: typing.Optional[str] = None
"""语音的 Base64 编码。"""
length: typing.Optional[int] = None
"""语音的长度,单位为秒。"""
@pydantic.validator('path')
def validate_path(cls, path: typing.Optional[str]):
"""修复 path 参数的行为,使之相对于 YiriMirai 的启动路径。"""
if path:
try:
return str(Path(path).resolve(strict=True))
except FileNotFoundError:
raise ValueError(f"无效路径:{path}")
else:
return path
def __str__(self):
return '[语音]'
async def download(
self,
filename: typing.Union[str, Path, None] = None,
directory: typing.Union[str, Path, None] = None
):
"""下载语音到本地。
语音采用 silk v3 格式silk 格式的编码解码请使用 [graiax-silkcoder](https://pypi.org/project/graiax-silkcoder/)
Args:
filename: 下载到本地的文件路径 `directory` 二选一
directory: 下载到本地的文件夹路径 `filename` 二选一
"""
if not self.url:
logger.warning(f'语音 `{self.voice_id}` 无 url 参数,下载失败。')
return
import httpx
async with httpx.AsyncClient() as client:
response = await client.get(self.url)
response.raise_for_status()
content = response.content
if filename:
path = Path(filename)
path.parent.mkdir(parents=True, exist_ok=True)
elif directory:
path = Path(directory)
path.mkdir(parents=True, exist_ok=True)
path = path / f'{self.voice_id}.silk'
else:
raise ValueError("请指定文件路径或文件夹路径!")
import aiofiles
async with aiofiles.open(path, 'wb') as f:
await f.write(content)
@classmethod
async def from_local(
cls,
filename: typing.Union[str, Path, None] = None,
content: typing.Optional[bytes] = None,
) -> "Voice":
"""从本地文件路径加载语音,以 base64 的形式传递。
Args:
filename: 从本地文件路径加载语音 `content` 二选一
content: 从本地文件内容加载语音 `filename` 二选一
"""
if content:
pass
if filename:
path = Path(filename)
import aiofiles
async with aiofiles.open(path, 'rb') as f:
content = await f.read()
else:
raise ValueError("请指定语音路径或语音内容!")
import base64
img = cls(base64=base64.b64encode(content).decode())
return img
class ForwardMessageNode(pydantic.BaseModel):
"""合并转发中的一条消息。"""
sender_id: typing.Optional[int] = None
"""发送人QQ号。"""
sender_name: typing.Optional[str] = None
"""显示名称。"""
message_chain: typing.Optional[MessageChain] = None
"""消息内容。"""
message_id: typing.Optional[int] = None
"""消息的 message_id可以只使用此属性从缓存中读取消息内容。"""
time: typing.Optional[datetime] = None
"""发送时间。"""
@pydantic.validator('message_chain', check_fields=False)
def _validate_message_chain(cls, value: typing.Union[MessageChain, list]):
if isinstance(value, list):
return MessageChain.parse_obj(value)
return value
@classmethod
def create(
cls, sender: typing.Union[platform_entities.Friend, platform_entities.GroupMember], message: MessageChain
) -> 'ForwardMessageNode':
"""从消息链生成转发消息。
Args:
sender: 发送人
message: 消息内容
Returns:
ForwardMessageNode: 生成的一条消息
"""
return ForwardMessageNode(
sender_id=sender.id,
sender_name=sender.get_name(),
message_chain=message
)
class Forward(MessageComponent):
"""合并转发。"""
type: str = "Forward"
"""消息组件类型。"""
node_list: typing.List[ForwardMessageNode]
"""转发消息节点列表。"""
def __init__(self, *args, **kwargs):
if len(args) == 1:
self.node_list = args[0]
super().__init__(**kwargs)
super().__init__(*args, **kwargs)
def __str__(self):
return '[聊天记录]'
class File(MessageComponent):
"""文件。"""
type: str = "File"
"""消息组件类型。"""
id: str
"""文件识别 ID。"""
name: str
"""文件名称。"""
size: int
"""文件大小。"""
def __str__(self):
return f'[文件]{self.name}'

View File

@ -3,11 +3,12 @@ from __future__ import annotations
import typing
import abc
import pydantic
import mirai
# import mirai
from . import events
from ..provider.tools import entities as tools_entities
from ..core import app
from ..platform.types import message as platform_message
def register(
@ -174,7 +175,7 @@ class EventContext:
self.__return_value__[key] = []
self.__return_value__[key].append(ret)
async def reply(self, message_chain: mirai.MessageChain):
async def reply(self, message_chain: platform_message.MessageChain):
"""回复此次消息请求
Args:
@ -190,7 +191,7 @@ class EventContext:
self,
target_type: str,
target_id: str,
message: mirai.MessageChain
message: platform_message.MessageChain
):
"""主动发送消息

View File

@ -3,10 +3,11 @@ from __future__ import annotations
import typing
import pydantic
import mirai
# import mirai
from ..core import entities as core_entities
from ..provider import entities as llm_entities
from ..platform.types import message as platform_message
class BaseEventModel(pydantic.BaseModel):
@ -31,7 +32,7 @@ class PersonMessageReceived(BaseEventModel):
sender_id: int
"""发送者ID(QQ号)"""
message_chain: mirai.MessageChain
message_chain: platform_message.MessageChain
class GroupMessageReceived(BaseEventModel):
@ -43,7 +44,7 @@ class GroupMessageReceived(BaseEventModel):
sender_id: int
message_chain: mirai.MessageChain
message_chain: platform_message.MessageChain
class PersonNormalMessageReceived(BaseEventModel):

View File

@ -4,7 +4,9 @@ import typing
import enum
import pydantic
import mirai
# import mirai
from ..platform.types import message as platform_message
class FunctionCall(pydantic.BaseModel):
@ -79,7 +81,7 @@ class Message(pydantic.BaseModel):
else:
return '未知消息'
def get_content_mirai_message_chain(self, prefix_text: str="") -> mirai.MessageChain | None:
def get_content_mirai_message_chain(self, prefix_text: str="") -> platform_message.MessageChain | None:
"""将内容转换为 Mirai MessageChain 对象
Args:
@ -89,15 +91,15 @@ class Message(pydantic.BaseModel):
if self.content is None:
return None
elif isinstance(self.content, str):
return mirai.MessageChain([mirai.Plain(prefix_text+self.content)])
return platform_message.MessageChain([platform_message.Plain(prefix_text+self.content)])
elif isinstance(self.content, list):
mc = []
for ce in self.content:
if ce.type == 'text':
mc.append(mirai.Plain(ce.text))
mc.append(platform_message.Plain(ce.text))
elif ce.type == 'image_url':
if ce.image_url.url.startswith("http"):
mc.append(mirai.Image(url=ce.image_url.url))
mc.append(platform_message.Image(url=ce.image_url.url))
else: # base64
b64_str = ce.image_url.url
@ -105,15 +107,15 @@ class Message(pydantic.BaseModel):
if b64_str.startswith("data:"):
b64_str = b64_str.split(",")[1]
mc.append(mirai.Image(base64=b64_str))
mc.append(platform_message.Image(base64=b64_str))
# 找第一个文字组件
if prefix_text:
for i, c in enumerate(mc):
if isinstance(c, mirai.Plain):
mc[i] = mirai.Plain(prefix_text+c.text)
if isinstance(c, platform_message.Plain):
mc[i] = platform_message.Plain(prefix_text+c.text)
break
else:
mc.insert(0, mirai.Plain(prefix_text))
mc.insert(0, platform_message.Plain(prefix_text))
return mirai.MessageChain(mc)
return platform_message.MessageChain(mc)