diff --git a/.github/ISSUE_TEMPLATE/bug-report.yml b/.github/ISSUE_TEMPLATE/bug-report.yml index 12f9b83..a5947b0 100644 --- a/.github/ISSUE_TEMPLATE/bug-report.yml +++ b/.github/ISSUE_TEMPLATE/bug-report.yml @@ -8,10 +8,10 @@ body: label: 消息平台适配器 description: "连接QQ使用的框架" options: - - yiri-mirai(Mirai) - Nakuru(go-cqhttp) - aiocqhttp(使用 OneBot 协议接入的) - qq-botpy(QQ官方API) + - yiri-mirai(Mirai) validations: required: false - type: input diff --git a/.github/dependabot.yml b/.github/dependabot.yml index d1973eb..53c5849 100644 --- a/.github/dependabot.yml +++ b/.github/dependabot.yml @@ -10,5 +10,4 @@ updates: schedule: interval: "weekly" allow: - - dependency-name: "yiri-mirai-rc" - dependency-name: "openai" diff --git a/.github/pull_request_template.md b/.github/pull_request_template.md index 6b636c2..a4b3a1a 100644 --- a/.github/pull_request_template.md +++ b/.github/pull_request_template.md @@ -6,8 +6,11 @@ ### PR 作者完成 +*请在方括号间写`x`以打勾 + - [ ] 阅读仓库[贡献指引](https://github.com/RockChinQ/QChatGPT/blob/master/CONTRIBUTING.md)了吗? - [ ] 与项目所有者沟通过了吗? +- [ ] 我确定已自行测试所作的更改,确保功能符合预期。 ### 项目所有者完成 diff --git a/pkg/command/entities.py b/pkg/command/entities.py index 8697551..463dfe3 100644 --- a/pkg/command/entities.py +++ b/pkg/command/entities.py @@ -3,10 +3,10 @@ from __future__ import annotations import typing import pydantic -import mirai from ..core import app, entities as core_entities from . import errors, operator +from ..platform.types import message as platform_message class CommandReturn(pydantic.BaseModel): @@ -17,7 +17,7 @@ class CommandReturn(pydantic.BaseModel): """文本 """ - image: typing.Optional[mirai.Image] = None + image: typing.Optional[platform_message.Image] = None """弃用""" image_url: typing.Optional[str] = None diff --git a/pkg/core/bootutils/deps.py b/pkg/core/bootutils/deps.py index e7df3e8..c392f80 100644 --- a/pkg/core/bootutils/deps.py +++ b/pkg/core/bootutils/deps.py @@ -5,7 +5,6 @@ required_deps = { "openai": "openai", "anthropic": "anthropic", "colorlog": "colorlog", - "mirai": "yiri-mirai-rc", "aiocqhttp": "aiocqhttp", "botpy": "qq-botpy", "PIL": "pillow", diff --git a/pkg/core/entities.py b/pkg/core/entities.py index 6305a0e..67b0566 100644 --- a/pkg/core/entities.py +++ b/pkg/core/entities.py @@ -6,13 +6,15 @@ import datetime import asyncio import pydantic -import mirai from ..provider import entities as llm_entities from ..provider.modelmgr import entities from ..provider.sysprompt import entities as sysprompt_entities from ..provider.tools import entities as tools_entities from ..platform import adapter as msadapter +from ..platform.types import message as platform_message +from ..platform.types import events as platform_events +from ..platform.types import entities as platform_entities class LauncherTypes(enum.Enum): @@ -40,10 +42,10 @@ class Query(pydantic.BaseModel): sender_id: int """发送者ID,platform处理阶段设置""" - message_event: mirai.MessageEvent + message_event: platform_events.MessageEvent """事件,platform收到的原始事件""" - message_chain: mirai.MessageChain + message_chain: platform_message.MessageChain """消息链,platform收到的原始消息链""" adapter: msadapter.MessageSourceAdapter @@ -67,10 +69,10 @@ class Query(pydantic.BaseModel): use_funcs: typing.Optional[list[tools_entities.LLMFunction]] = None """使用的函数,由前置处理器阶段设置""" - resp_messages: typing.Optional[list[llm_entities.Message]] | typing.Optional[list[mirai.MessageChain]] = [] + resp_messages: typing.Optional[list[llm_entities.Message]] | typing.Optional[list[platform_message.MessageChain]] = [] """由Process阶段生成的回复消息对象列表""" - resp_message_chain: typing.Optional[list[mirai.MessageChain]] = None + resp_message_chain: typing.Optional[list[platform_message.MessageChain]] = None """回复消息链,从resp_messages包装而得""" # ======= 内部保留 ======= @@ -108,7 +110,7 @@ class Session(pydantic.BaseModel): using_conversation: typing.Optional[Conversation] = None - conversations: typing.Optional[list[Conversation]] = [] + conversations: typing.Optional[list[Conversation]] = pydantic.Field(default_factory=list) create_time: typing.Optional[datetime.datetime] = pydantic.Field(default_factory=datetime.datetime.now) diff --git a/pkg/pipeline/cntfilter/cntfilter.py b/pkg/pipeline/cntfilter/cntfilter.py index 29e66cc..f7376b6 100644 --- a/pkg/pipeline/cntfilter/cntfilter.py +++ b/pkg/pipeline/cntfilter/cntfilter.py @@ -1,9 +1,5 @@ from __future__ import annotations -import mirai -import mirai.models -import mirai.models.message - from ...core import app from .. import stage, entities, stagemgr @@ -12,6 +8,9 @@ from ...config import manager as cfg_mgr from . import filter as filter_model, entities as filter_entities from .filters import cntignore, banwords, baiduexamine from ...provider import entities as llm_entities +from ...platform.types import message as platform_message +from ...platform.types import events as platform_events +from ...platform.types import entities as platform_entities @stage.stage_class('PostContentFilterStage') @@ -89,8 +88,8 @@ class ContentFilterStage(stage.PipelineStage): elif result.level == filter_entities.ResultLevel.PASS: # 传到下一个 message = result.replacement - query.message_chain = mirai.MessageChain( - mirai.Plain(message) + query.message_chain = platform_message.MessageChain( + platform_message.Plain(message) ) return entities.StageProcessResult( @@ -148,7 +147,7 @@ class ContentFilterStage(stage.PipelineStage): contain_non_text = False - text_components = [mirai.Plain, mirai.models.message.Source] + text_components = [platform_message.Plain, platform_message.Source] for me in query.message_chain: if type(me) not in text_components: diff --git a/pkg/pipeline/controller.py b/pkg/pipeline/controller.py index 677db31..0f07e06 100644 --- a/pkg/pipeline/controller.py +++ b/pkg/pipeline/controller.py @@ -4,11 +4,11 @@ import asyncio import typing import traceback -import mirai from ..core import app, entities from . import entities as pipeline_entities from ..plugin import events +from ..platform.types import message as platform_message class Controller: @@ -73,11 +73,11 @@ class Controller: # 处理str类型 if isinstance(result.user_notice, str): - result.user_notice = mirai.MessageChain( - mirai.Plain(result.user_notice) + result.user_notice = platform_message.MessageChain( + platform_message.Plain(result.user_notice) ) elif isinstance(result.user_notice, list): - result.user_notice = mirai.MessageChain( + result.user_notice = platform_message.MessageChain( *result.user_notice ) diff --git a/pkg/pipeline/entities.py b/pkg/pipeline/entities.py index e8cfc42..cbeb3d0 100644 --- a/pkg/pipeline/entities.py +++ b/pkg/pipeline/entities.py @@ -4,8 +4,7 @@ import enum import typing import pydantic -import mirai -import mirai.models.message as mirai_message +from ..platform.types import message as platform_message from ..core import entities @@ -25,13 +24,9 @@ class StageProcessResult(pydantic.BaseModel): new_query: entities.Query - user_notice: typing.Optional[typing.Union[str, list[mirai_message.MessageComponent], mirai.MessageChain, None]] = [] + user_notice: typing.Optional[typing.Union[str, list[platform_message.MessageComponent], platform_message.MessageChain, None]] = [] """只要设置了就会发送给用户""" - # TODO delete - # admin_notice: typing.Optional[typing.Union[str, list[mirai_message.MessageComponent], mirai.MessageChain, None]] = [] - """只要设置了就会发送给管理员""" - console_notice: typing.Optional[str] = '' """只要设置了就会输出到控制台""" diff --git a/pkg/pipeline/longtext/longtext.py b/pkg/pipeline/longtext/longtext.py index 0ab34ab..ecb745d 100644 --- a/pkg/pipeline/longtext/longtext.py +++ b/pkg/pipeline/longtext/longtext.py @@ -3,7 +3,6 @@ import os import traceback from PIL import Image, ImageDraw, ImageFont -from mirai.models.message import MessageComponent, Plain, MessageChain from ...core import app from . import strategy @@ -11,6 +10,7 @@ from .strategies import image, forward from .. import stage, entities, stagemgr from ...core import entities as core_entities from ...config import manager as cfg_mgr +from ...platform.types import message as platform_message @stage.stage_class("LongTextProcessStage") @@ -63,14 +63,14 @@ class LongTextProcessStage(stage.PipelineStage): contains_non_plain = False for msg in query.resp_message_chain[-1]: - if not isinstance(msg, Plain): + if not isinstance(msg, platform_message.Plain): contains_non_plain = True break if contains_non_plain: self.ap.logger.debug("消息中包含非 Plain 组件,跳过长消息处理。") elif len(str(query.resp_message_chain[-1])) > self.ap.platform_cfg.data['long-text-process']['threshold']: - query.resp_message_chain[-1] = MessageChain(await self.strategy_impl.process(str(query.resp_message_chain[-1]), query)) + query.resp_message_chain[-1] = platform_message.MessageChain(await self.strategy_impl.process(str(query.resp_message_chain[-1]), query)) return entities.StageProcessResult( result_type=entities.ResultType.CONTINUE, diff --git a/pkg/pipeline/longtext/strategies/forward.py b/pkg/pipeline/longtext/strategies/forward.py index 4a79031..c39c920 100644 --- a/pkg/pipeline/longtext/strategies/forward.py +++ b/pkg/pipeline/longtext/strategies/forward.py @@ -2,15 +2,14 @@ from __future__ import annotations import typing -from mirai.models import MessageChain -from mirai.models.message import MessageComponent, ForwardMessageNode -from mirai.models.base import MiraiBaseModel +import pydantic from .. import strategy as strategy_model from ....core import entities as core_entities +from ....platform.types import message as platform_message -class ForwardMessageDiaplay(MiraiBaseModel): +class ForwardMessageDiaplay(pydantic.BaseModel): title: str = "群聊的聊天记录" brief: str = "[聊天记录]" source: str = "聊天记录" @@ -18,13 +17,13 @@ class ForwardMessageDiaplay(MiraiBaseModel): summary: str = "查看x条转发消息" -class Forward(MessageComponent): +class Forward(platform_message.MessageComponent): """合并转发。""" type: str = "Forward" """消息组件类型。""" display: ForwardMessageDiaplay """显示信息""" - node_list: typing.List[ForwardMessageNode] + node_list: typing.List[platform_message.ForwardMessageNode] """转发消息节点列表。""" def __init__(self, *args, **kwargs): if len(args) == 1: @@ -39,7 +38,7 @@ class Forward(MessageComponent): @strategy_model.strategy_class("forward") class ForwardComponentStrategy(strategy_model.LongTextStrategy): - async def process(self, message: str, query: core_entities.Query) -> list[MessageComponent]: + async def process(self, message: str, query: core_entities.Query) -> list[platform_message.MessageComponent]: display = ForwardMessageDiaplay( title="群聊的聊天记录", brief="[聊天记录]", @@ -49,10 +48,10 @@ class ForwardComponentStrategy(strategy_model.LongTextStrategy): ) node_list = [ - ForwardMessageNode( + platform_message.ForwardMessageNode( sender_id=query.adapter.bot_account_id, sender_name='QQ用户', - message_chain=MessageChain([message]) + message_chain=platform_message.MessageChain([message]) ) ] diff --git a/pkg/pipeline/longtext/strategies/image.py b/pkg/pipeline/longtext/strategies/image.py index f96f03c..9e32e59 100644 --- a/pkg/pipeline/longtext/strategies/image.py +++ b/pkg/pipeline/longtext/strategies/image.py @@ -8,8 +8,7 @@ import re from PIL import Image, ImageDraw, ImageFont -from mirai.models import MessageChain, Image as ImageComponent -from mirai.models.message import MessageComponent +from ....platform.types import message as platform_message from .. import strategy as strategy_model from ....core import entities as core_entities @@ -23,7 +22,7 @@ class Text2ImageStrategy(strategy_model.LongTextStrategy): async def initialize(self): self.text_render_font = ImageFont.truetype(self.ap.platform_cfg.data['long-text-process']['font-path'], 32, encoding="utf-8") - async def process(self, message: str, query: core_entities.Query) -> list[MessageComponent]: + async def process(self, message: str, query: core_entities.Query) -> list[platform_message.MessageComponent]: img_path = self.text_to_image( text_str=message, save_as='temp/{}.png'.format(int(time.time())) @@ -46,7 +45,7 @@ class Text2ImageStrategy(strategy_model.LongTextStrategy): os.remove(compressed_path) return [ - ImageComponent( + platform_message.Image( base64=b64.decode('utf-8'), ) ] diff --git a/pkg/pipeline/longtext/strategy.py b/pkg/pipeline/longtext/strategy.py index 5d7e24f..6f66bbf 100644 --- a/pkg/pipeline/longtext/strategy.py +++ b/pkg/pipeline/longtext/strategy.py @@ -2,11 +2,10 @@ from __future__ import annotations import abc import typing -import mirai -from mirai.models.message import MessageComponent from ...core import app from ...core import entities as core_entities +from ...platform.types import message as platform_message preregistered_strategies: list[typing.Type[LongTextStrategy]] = [] @@ -51,7 +50,7 @@ class LongTextStrategy(metaclass=abc.ABCMeta): pass @abc.abstractmethod - async def process(self, message: str, query: core_entities.Query) -> list[MessageComponent]: + async def process(self, message: str, query: core_entities.Query) -> list[platform_message.MessageComponent]: """处理长文本 在 platform.json 中配置 long-text-process 字段,只要 文本长度超过了 threshold 就会调用此方法 @@ -61,6 +60,6 @@ class LongTextStrategy(metaclass=abc.ABCMeta): query (core_entities.Query): 此次请求的上下文对象 Returns: - list[mirai.models.messages.MessageComponent]: 转换后的 YiriMirai 消息组件列表 + list[platform_message.MessageComponent]: 转换后的 平台 消息组件列表 """ return [] diff --git a/pkg/pipeline/pool.py b/pkg/pipeline/pool.py index ba7f999..45f16e6 100644 --- a/pkg/pipeline/pool.py +++ b/pkg/pipeline/pool.py @@ -2,10 +2,11 @@ from __future__ import annotations import asyncio -import mirai from ..core import entities from ..platform import adapter as msadapter +from ..platform.types import message as platform_message +from ..platform.types import events as platform_events class QueryPool: @@ -30,8 +31,8 @@ class QueryPool: launcher_type: entities.LauncherTypes, launcher_id: int, sender_id: int, - message_event: mirai.MessageEvent, - message_chain: mirai.MessageChain, + message_event: platform_events.MessageEvent, + message_chain: platform_message.MessageChain, adapter: msadapter.MessageSourceAdapter ) -> entities.Query: async with self.condition: diff --git a/pkg/pipeline/preproc/preproc.py b/pkg/pipeline/preproc/preproc.py index ebe4d31..3a71a84 100644 --- a/pkg/pipeline/preproc/preproc.py +++ b/pkg/pipeline/preproc/preproc.py @@ -1,11 +1,11 @@ from __future__ import annotations -import mirai from .. import stage, entities, stagemgr from ...core import entities as core_entities from ...provider import entities as llm_entities from ...plugin import events +from ...platform.types import message as platform_message @stage.stage_class("PreProcessor") @@ -55,11 +55,11 @@ class PreProcessor(stage.PipelineStage): content_list = [] for me in query.message_chain: - if isinstance(me, mirai.Plain): + if isinstance(me, platform_message.Plain): content_list.append( llm_entities.ContentElement.from_text(me.text) ) - elif isinstance(me, mirai.Image): + elif isinstance(me, platform_message.Image): if self.ap.provider_cfg.data['enable-vision'] and query.use_model.vision_supported: if me.url is not None: content_list.append( diff --git a/pkg/pipeline/process/handlers/chat.py b/pkg/pipeline/process/handlers/chat.py index cb8899b..6e192b7 100644 --- a/pkg/pipeline/process/handlers/chat.py +++ b/pkg/pipeline/process/handlers/chat.py @@ -5,7 +5,6 @@ import time import traceback import json -import mirai from .. import handler from ... import entities @@ -13,6 +12,8 @@ from ....core import entities as core_entities from ....provider import entities as llm_entities, runnermgr from ....plugin import events +from ....platform.types import message as platform_message + class ChatMessageHandler(handler.MessageHandler): @@ -40,7 +41,7 @@ class ChatMessageHandler(handler.MessageHandler): if event_ctx.is_prevented_default(): if event_ctx.event.reply is not None: - mc = mirai.MessageChain(event_ctx.event.reply) + mc = platform_message.MessageChain(event_ctx.event.reply) query.resp_messages.append(mc) diff --git a/pkg/pipeline/process/handlers/command.py b/pkg/pipeline/process/handlers/command.py index 8c8fb8b..cec64a4 100644 --- a/pkg/pipeline/process/handlers/command.py +++ b/pkg/pipeline/process/handlers/command.py @@ -1,13 +1,13 @@ from __future__ import annotations import typing -import mirai from .. import handler from ... import entities from ....core import entities as core_entities from ....provider import entities as llm_entities from ....plugin import events +from ....platform.types import message as platform_message class CommandHandler(handler.MessageHandler): @@ -46,7 +46,7 @@ class CommandHandler(handler.MessageHandler): if event_ctx.is_prevented_default(): if event_ctx.event.reply is not None: - mc = mirai.MessageChain(event_ctx.event.reply) + mc = platform_message.MessageChain(event_ctx.event.reply) query.resp_messages.append(mc) @@ -63,8 +63,8 @@ class CommandHandler(handler.MessageHandler): else: if event_ctx.event.alter is not None: - query.message_chain = mirai.MessageChain([ - mirai.Plain(event_ctx.event.alter) + query.message_chain = platform_message.MessageChain([ + platform_message.Plain(event_ctx.event.alter) ]) session = await self.ap.sess_mgr.get_session(query) diff --git a/pkg/pipeline/respback/respback.py b/pkg/pipeline/respback/respback.py index d3af14e..d3dd83f 100644 --- a/pkg/pipeline/respback/respback.py +++ b/pkg/pipeline/respback/respback.py @@ -3,7 +3,6 @@ from __future__ import annotations import random import asyncio -import mirai from ...core import app diff --git a/pkg/pipeline/resprule/entities.py b/pkg/pipeline/resprule/entities.py index ffee308..2292715 100644 --- a/pkg/pipeline/resprule/entities.py +++ b/pkg/pipeline/resprule/entities.py @@ -1,9 +1,10 @@ import pydantic -import mirai + +from ...platform.types import message as platform_message class RuleJudgeResult(pydantic.BaseModel): matching: bool = False - replacement: mirai.MessageChain = None + replacement: platform_message.MessageChain = None diff --git a/pkg/pipeline/resprule/resprule.py b/pkg/pipeline/resprule/resprule.py index b7fdb37..77858f0 100644 --- a/pkg/pipeline/resprule/resprule.py +++ b/pkg/pipeline/resprule/resprule.py @@ -1,6 +1,5 @@ from __future__ import annotations -import mirai from ...core import app from . import entities as rule_entities, rule diff --git a/pkg/pipeline/resprule/rule.py b/pkg/pipeline/resprule/rule.py index bfab415..ad69d8a 100644 --- a/pkg/pipeline/resprule/rule.py +++ b/pkg/pipeline/resprule/rule.py @@ -2,11 +2,11 @@ from __future__ import annotations import abc import typing -import mirai - from ...core import app, entities as core_entities from . import entities +from ...platform.types import message as platform_message + preregisetered_rules: list[typing.Type[GroupRespondRule]] = [] @@ -35,7 +35,7 @@ class GroupRespondRule(metaclass=abc.ABCMeta): async def match( self, message_text: str, - message_chain: mirai.MessageChain, + message_chain: platform_message.MessageChain, rule_dict: dict, query: core_entities.Query ) -> entities.RuleJudgeResult: diff --git a/pkg/pipeline/resprule/rules/atbot.py b/pkg/pipeline/resprule/rules/atbot.py index 4b39409..a0b7a7c 100644 --- a/pkg/pipeline/resprule/rules/atbot.py +++ b/pkg/pipeline/resprule/rules/atbot.py @@ -1,10 +1,10 @@ from __future__ import annotations -import mirai from .. import rule as rule_model from .. import entities from ....core import entities as core_entities +from ....platform.types import message as platform_message @rule_model.rule_class("at-bot") @@ -13,16 +13,16 @@ class AtBotRule(rule_model.GroupRespondRule): async def match( self, message_text: str, - message_chain: mirai.MessageChain, + message_chain: platform_message.MessageChain, rule_dict: dict, query: core_entities.Query ) -> entities.RuleJudgeResult: - if message_chain.has(mirai.At(query.adapter.bot_account_id)) and rule_dict['at']: - message_chain.remove(mirai.At(query.adapter.bot_account_id)) + if message_chain.has(platform_message.At(query.adapter.bot_account_id)) and rule_dict['at']: + message_chain.remove(platform_message.At(query.adapter.bot_account_id)) - if message_chain.has(mirai.At(query.adapter.bot_account_id)): # 回复消息时会at两次,检查并删除重复的 - message_chain.remove(mirai.At(query.adapter.bot_account_id)) + if message_chain.has(platform_message.At(query.adapter.bot_account_id)): # 回复消息时会at两次,检查并删除重复的 + message_chain.remove(platform_message.At(query.adapter.bot_account_id)) return entities.RuleJudgeResult( matching=True, diff --git a/pkg/pipeline/resprule/rules/prefix.py b/pkg/pipeline/resprule/rules/prefix.py index 98b5032..fb7bbcf 100644 --- a/pkg/pipeline/resprule/rules/prefix.py +++ b/pkg/pipeline/resprule/rules/prefix.py @@ -1,8 +1,8 @@ -import mirai from .. import rule as rule_model from .. import entities from ....core import entities as core_entities +from ....platform.types import message as platform_message @rule_model.rule_class("prefix") @@ -11,7 +11,7 @@ class PrefixRule(rule_model.GroupRespondRule): async def match( self, message_text: str, - message_chain: mirai.MessageChain, + message_chain: platform_message.MessageChain, rule_dict: dict, query: core_entities.Query ) -> entities.RuleJudgeResult: @@ -22,7 +22,7 @@ class PrefixRule(rule_model.GroupRespondRule): # 查找第一个plain元素 for me in message_chain: - if isinstance(me, mirai.Plain): + if isinstance(me, platform_message.Plain): me.text = me.text[len(prefix):] return entities.RuleJudgeResult( diff --git a/pkg/pipeline/resprule/rules/random.py b/pkg/pipeline/resprule/rules/random.py index 80acf6a..0178f2c 100644 --- a/pkg/pipeline/resprule/rules/random.py +++ b/pkg/pipeline/resprule/rules/random.py @@ -1,10 +1,10 @@ import random -import mirai from .. import rule as rule_model from .. import entities from ....core import entities as core_entities +from ....platform.types import message as platform_message @rule_model.rule_class("random") @@ -13,7 +13,7 @@ class RandomRespRule(rule_model.GroupRespondRule): async def match( self, message_text: str, - message_chain: mirai.MessageChain, + message_chain: platform_message.MessageChain, rule_dict: dict, query: core_entities.Query ) -> entities.RuleJudgeResult: diff --git a/pkg/pipeline/resprule/rules/regexp.py b/pkg/pipeline/resprule/rules/regexp.py index aaa4644..f5f5b3f 100644 --- a/pkg/pipeline/resprule/rules/regexp.py +++ b/pkg/pipeline/resprule/rules/regexp.py @@ -1,10 +1,10 @@ import re -import mirai from .. import rule as rule_model from .. import entities from ....core import entities as core_entities +from ....platform.types import message as platform_message @rule_model.rule_class("regexp") @@ -13,7 +13,7 @@ class RegExpRule(rule_model.GroupRespondRule): async def match( self, message_text: str, - message_chain: mirai.MessageChain, + message_chain: platform_message.MessageChain, rule_dict: dict, query: core_entities.Query ) -> entities.RuleJudgeResult: diff --git a/pkg/pipeline/wrapper/wrapper.py b/pkg/pipeline/wrapper/wrapper.py index e6ce99a..1ffb314 100644 --- a/pkg/pipeline/wrapper/wrapper.py +++ b/pkg/pipeline/wrapper/wrapper.py @@ -2,7 +2,6 @@ from __future__ import annotations import typing -import mirai from ...core import app, entities as core_entities from .. import entities @@ -10,6 +9,7 @@ from .. import stage, entities, stagemgr from ...core import entities as core_entities from ...config import manager as cfg_mgr from ...plugin import events +from ...platform.types import message as platform_message @stage.stage_class("ResponseWrapper") @@ -34,7 +34,7 @@ class ResponseWrapper(stage.PipelineStage): """ # 如果 resp_messages[-1] 已经是 MessageChain 了 - if isinstance(query.resp_messages[-1], mirai.MessageChain): + if isinstance(query.resp_messages[-1], platform_message.MessageChain): query.resp_message_chain.append(query.resp_messages[-1]) yield entities.StageProcessResult( @@ -45,19 +45,14 @@ class ResponseWrapper(stage.PipelineStage): else: if query.resp_messages[-1].role == 'command': - # query.resp_message_chain.append(mirai.MessageChain("[bot] "+query.resp_messages[-1].content)) - query.resp_message_chain.append(query.resp_messages[-1].get_content_mirai_message_chain(prefix_text='[bot] ')) + query.resp_message_chain.append(query.resp_messages[-1].get_content_platform_message_chain(prefix_text='[bot] ')) yield entities.StageProcessResult( result_type=entities.ResultType.CONTINUE, new_query=query ) elif query.resp_messages[-1].role == 'plugin': - # if not isinstance(query.resp_messages[-1].content, mirai.MessageChain): - # query.resp_message_chain.append(mirai.MessageChain(query.resp_messages[-1].content)) - # else: - # query.resp_message_chain.append(query.resp_messages[-1].content) - query.resp_message_chain.append(query.resp_messages[-1].get_content_mirai_message_chain()) + query.resp_message_chain.append(query.resp_messages[-1].get_content_platform_message_chain()) yield entities.StageProcessResult( result_type=entities.ResultType.CONTINUE, @@ -72,7 +67,7 @@ class ResponseWrapper(stage.PipelineStage): reply_text = '' if result.content: # 有内容 - reply_text = str(result.get_content_mirai_message_chain()) + reply_text = str(result.get_content_platform_message_chain()) # ============= 触发插件事件 =============== event_ctx = await self.ap.plugin_mgr.emit_event( @@ -96,11 +91,11 @@ class ResponseWrapper(stage.PipelineStage): else: if event_ctx.event.reply is not None: - query.resp_message_chain.append(mirai.MessageChain(event_ctx.event.reply)) + query.resp_message_chain.append(platform_message.MessageChain(event_ctx.event.reply)) else: - query.resp_message_chain.append(result.get_content_mirai_message_chain()) + query.resp_message_chain.append(result.get_content_platform_message_chain()) yield entities.StageProcessResult( result_type=entities.ResultType.CONTINUE, @@ -113,7 +108,7 @@ class ResponseWrapper(stage.PipelineStage): reply_text = f'调用函数 {".".join(function_names)}...' - query.resp_message_chain.append(mirai.MessageChain([mirai.Plain(reply_text)])) + query.resp_message_chain.append(platform_message.MessageChain([platform_message.Plain(reply_text)])) if self.ap.platform_cfg.data['track-function-calls']: @@ -139,11 +134,11 @@ class ResponseWrapper(stage.PipelineStage): else: if event_ctx.event.reply is not None: - query.resp_message_chain.append(mirai.MessageChain(event_ctx.event.reply)) + query.resp_message_chain.append(platform_message.MessageChain(event_ctx.event.reply)) else: - query.resp_message_chain.append(mirai.MessageChain([mirai.Plain(reply_text)])) + query.resp_message_chain.append(platform_message.MessageChain([platform_message.Plain(reply_text)])) yield entities.StageProcessResult( result_type=entities.ResultType.CONTINUE, diff --git a/pkg/platform/adapter.py b/pkg/platform/adapter.py index 4b159b7..7cf64a1 100644 --- a/pkg/platform/adapter.py +++ b/pkg/platform/adapter.py @@ -4,9 +4,10 @@ from __future__ import annotations import typing import abc -import mirai from ..core import app +from .types import message as platform_message +from .types import events as platform_events preregistered_adapters: list[typing.Type[MessageSourceAdapter]] = [] @@ -55,28 +56,28 @@ class MessageSourceAdapter(metaclass=abc.ABCMeta): self, target_type: str, target_id: str, - message: mirai.MessageChain + message: platform_message.MessageChain ): """主动发送消息 Args: target_type (str): 目标类型,`person`或`group` target_id (str): 目标ID - message (mirai.MessageChain): YiriMirai库的消息链 + message (platform.types.MessageChain): 消息链 """ raise NotImplementedError async def reply_message( self, - message_source: mirai.MessageEvent, - message: mirai.MessageChain, + message_source: platform_events.MessageEvent, + message: platform_message.MessageChain, quote_origin: bool = False ): """回复消息 Args: - message_source (mirai.MessageEvent): YiriMirai消息源事件 - message (mirai.MessageChain): YiriMirai库的消息链 + message_source (platform.types.MessageEvent): 消息源事件 + message (platform.types.MessageChain): 消息链 quote_origin (bool, optional): 是否引用原消息. Defaults to False. """ raise NotImplementedError @@ -87,27 +88,27 @@ class MessageSourceAdapter(metaclass=abc.ABCMeta): def register_listener( self, - event_type: typing.Type[mirai.Event], - callback: typing.Callable[[mirai.Event, MessageSourceAdapter], None] + event_type: typing.Type[platform_message.Event], + callback: typing.Callable[[platform_message.Event, MessageSourceAdapter], None] ): """注册事件监听器 Args: - event_type (typing.Type[mirai.Event]): YiriMirai事件类型 - callback (typing.Callable[[mirai.Event], None]): 回调函数,接收一个参数,为YiriMirai事件 + event_type (typing.Type[platform.types.Event]): 事件类型 + callback (typing.Callable[[platform.types.Event], None]): 回调函数,接收一个参数,为事件 """ raise NotImplementedError def unregister_listener( self, - event_type: typing.Type[mirai.Event], - callback: typing.Callable[[mirai.Event, MessageSourceAdapter], None] + event_type: typing.Type[platform_message.Event], + callback: typing.Callable[[platform_message.Event, MessageSourceAdapter], None] ): """注销事件监听器 Args: - event_type (typing.Type[mirai.Event]): YiriMirai事件类型 - callback (typing.Callable[[mirai.Event], None]): 回调函数,接收一个参数,为YiriMirai事件 + event_type (typing.Type[platform.types.Event]): 事件类型 + callback (typing.Callable[[platform.types.Event], None]): 回调函数,接收一个参数,为事件 """ raise NotImplementedError @@ -127,26 +128,26 @@ class MessageSourceAdapter(metaclass=abc.ABCMeta): class MessageConverter: """消息链转换器基类""" @staticmethod - def yiri2target(message_chain: mirai.MessageChain): - """将YiriMirai消息链转换为目标消息链 + def yiri2target(message_chain: platform_message.MessageChain): + """将源平台消息链转换为目标平台消息链 Args: - message_chain (mirai.MessageChain): YiriMirai消息链 + message_chain (platform.types.MessageChain): 源平台消息链 Returns: - typing.Any: 目标消息链 + typing.Any: 目标平台消息链 """ raise NotImplementedError @staticmethod - def target2yiri(message_chain: typing.Any) -> mirai.MessageChain: - """将目标消息链转换为YiriMirai消息链 + def target2yiri(message_chain: typing.Any) -> platform_message.MessageChain: + """将目标平台消息链转换为源平台消息链 Args: - message_chain (typing.Any): 目标消息链 + message_chain (typing.Any): 目标平台消息链 Returns: - mirai.MessageChain: YiriMirai消息链 + platform.types.MessageChain: 源平台消息链 """ raise NotImplementedError @@ -155,25 +156,25 @@ class EventConverter: """事件转换器基类""" @staticmethod - def yiri2target(event: typing.Type[mirai.Event]): - """将YiriMirai事件转换为目标事件 + def yiri2target(event: typing.Type[platform_message.Event]): + """将源平台事件转换为目标平台事件 Args: - event (typing.Type[mirai.Event]): YiriMirai事件 + event (typing.Type[platform.types.Event]): 源平台事件 Returns: - typing.Any: 目标事件 + typing.Any: 目标平台事件 """ raise NotImplementedError @staticmethod - def target2yiri(event: typing.Any) -> mirai.Event: - """将目标事件的调用参数转换为YiriMirai的事件参数对象 + def target2yiri(event: typing.Any) -> platform_message.Event: + """将目标平台事件的调用参数转换为源平台的事件参数对象 Args: - event (typing.Any): 目标事件 + event (typing.Any): 目标平台事件 Returns: - typing.Type[mirai.Event]: YiriMirai事件 + typing.Type[platform.types.Event]: 源平台事件 """ raise NotImplementedError diff --git a/pkg/platform/manager.py b/pkg/platform/manager.py index b969642..aed8def 100644 --- a/pkg/platform/manager.py +++ b/pkg/platform/manager.py @@ -2,17 +2,24 @@ from __future__ import annotations import json import os +import sys import logging import asyncio import traceback -from mirai import At, GroupMessage, MessageEvent, StrangerMessage, \ - FriendMessage, Image, MessageChain, Plain -import mirai +# FriendMessage, Image, MessageChain, Plain from ..platform import adapter as msadapter from ..core import app, entities as core_entities from ..plugin import events +from .types import message as platform_message +from .types import events as platform_events +from .types import entities as platform_entities + +# 处理 3.4 移除了 YiriMirai 之后,插件的兼容性问题 +from . import types as mirai +sys.modules['mirai'] = mirai + # 控制QQ消息输入输出的类 class PlatformManager: @@ -32,7 +39,7 @@ class PlatformManager: from .sources import yirimirai, nakuru, aiocqhttp, qqbotpy - async def on_friend_message(event: FriendMessage, adapter: msadapter.MessageSourceAdapter): + async def on_friend_message(event: platform_events.FriendMessage, adapter: msadapter.MessageSourceAdapter): event_ctx = await self.ap.plugin_mgr.emit_event( event=events.PersonMessageReceived( @@ -55,7 +62,7 @@ class PlatformManager: adapter=adapter ) - async def on_stranger_message(event: StrangerMessage, adapter: msadapter.MessageSourceAdapter): + async def on_stranger_message(event: platform_events.StrangerMessage, adapter: msadapter.MessageSourceAdapter): event_ctx = await self.ap.plugin_mgr.emit_event( event=events.PersonMessageReceived( @@ -78,7 +85,7 @@ class PlatformManager: adapter=adapter ) - async def on_group_message(event: GroupMessage, adapter: msadapter.MessageSourceAdapter): + async def on_group_message(event: platform_events.GroupMessage, adapter: msadapter.MessageSourceAdapter): event_ctx = await self.ap.plugin_mgr.emit_event( event=events.GroupMessageReceived( @@ -127,16 +134,16 @@ class PlatformManager: if adapter_name == 'yiri-mirai': adapter_inst.register_listener( - StrangerMessage, + platform_events.StrangerMessage, on_stranger_message ) adapter_inst.register_listener( - FriendMessage, + platform_events.FriendMessage, on_friend_message ) adapter_inst.register_listener( - GroupMessage, + platform_events.GroupMessage, on_group_message ) @@ -146,13 +153,13 @@ class PlatformManager: if len(self.adapters) == 0: self.ap.logger.warning('未运行平台适配器,请根据文档配置并启用平台适配器。') - async def send(self, event: mirai.MessageEvent, msg: mirai.MessageChain, adapter: msadapter.MessageSourceAdapter): + async def send(self, event: platform_events.MessageEvent, msg: platform_message.MessageChain, adapter: msadapter.MessageSourceAdapter): - if self.ap.platform_cfg.data['at-sender'] and isinstance(event, GroupMessage): + if self.ap.platform_cfg.data['at-sender'] and isinstance(event, platform_events.GroupMessage): msg.insert( 0, - At( + platform_message.At( event.sender.id ) ) diff --git a/pkg/platform/sources/aiocqhttp.py b/pkg/platform/sources/aiocqhttp.py index bd1b067..25d197e 100644 --- a/pkg/platform/sources/aiocqhttp.py +++ b/pkg/platform/sources/aiocqhttp.py @@ -5,31 +5,32 @@ import traceback import time import datetime -import mirai -import mirai.models.message as yiri_message import aiocqhttp from .. import adapter from ...pipeline.longtext.strategies import forward from ...core import app +from ..types import message as platform_message +from ..types import events as platform_events +from ..types import entities as platform_entities class AiocqhttpMessageConverter(adapter.MessageConverter): @staticmethod - def yiri2target(message_chain: mirai.MessageChain) -> typing.Tuple[list, int, datetime.datetime]: + def yiri2target(message_chain: platform_message.MessageChain) -> typing.Tuple[list, int, datetime.datetime]: msg_list = aiocqhttp.Message() msg_id = 0 msg_time = None for msg in message_chain: - if type(msg) is mirai.Plain: + if type(msg) is platform_message.Plain: msg_list.append(aiocqhttp.MessageSegment.text(msg.text)) - elif type(msg) is yiri_message.Source: + elif type(msg) is platform_message.Source: msg_id = msg.id msg_time = msg.time - elif type(msg) is mirai.Image: + elif type(msg) is platform_message.Image: arg = '' if msg.base64: arg = msg.base64 @@ -40,13 +41,11 @@ class AiocqhttpMessageConverter(adapter.MessageConverter): elif msg.path: arg = msg.path msg_list.append(aiocqhttp.MessageSegment.image(arg)) - elif type(msg) is mirai.At: + elif type(msg) is platform_message.At: msg_list.append(aiocqhttp.MessageSegment.at(msg.target)) - elif type(msg) is mirai.AtAll: + elif type(msg) is platform_message.AtAll: msg_list.append(aiocqhttp.MessageSegment.at("all")) - elif type(msg) is mirai.Face: - msg_list.append(aiocqhttp.MessageSegment.face(msg.face_id)) - elif type(msg) is mirai.Voice: + elif type(msg) is platform_message.Voice: arg = '' if msg.base64: arg = msg.base64 @@ -74,25 +73,25 @@ class AiocqhttpMessageConverter(adapter.MessageConverter): yiri_msg_list = [] yiri_msg_list.append( - yiri_message.Source(id=message_id, time=datetime.datetime.now()) + platform_message.Source(id=message_id, time=datetime.datetime.now()) ) for msg in message: if msg.type == "at": if msg.data["qq"] == "all": - yiri_msg_list.append(yiri_message.AtAll()) + yiri_msg_list.append(platform_message.AtAll()) else: yiri_msg_list.append( - yiri_message.At( + platform_message.At( target=msg.data["qq"], ) ) elif msg.type == "text": - yiri_msg_list.append(yiri_message.Plain(text=msg.data["text"])) + yiri_msg_list.append(platform_message.Plain(text=msg.data["text"])) elif msg.type == "image": - yiri_msg_list.append(yiri_message.Image(url=msg.data["url"])) + yiri_msg_list.append(platform_message.Image(url=msg.data["url"])) - chain = mirai.MessageChain(yiri_msg_list) + chain = platform_message.MessageChain(yiri_msg_list) return chain @@ -100,11 +99,11 @@ class AiocqhttpMessageConverter(adapter.MessageConverter): class AiocqhttpEventConverter(adapter.EventConverter): @staticmethod - def yiri2target(event: mirai.Event, bot_account_id: int): + def yiri2target(event: platform_events.Event, bot_account_id: int): msg, msg_id, msg_time = AiocqhttpMessageConverter.yiri2target(event.message_chain) - if type(event) is mirai.GroupMessage: + if type(event) is platform_events.GroupMessage: role = "member" if event.sender.permission == "ADMINISTRATOR": @@ -140,7 +139,7 @@ class AiocqhttpEventConverter(adapter.EventConverter): } return aiocqhttp.Event.from_payload(payload) - elif type(event) is mirai.FriendMessage: + elif type(event) is platform_events.FriendMessage: payload = { "post_type": "message", @@ -178,15 +177,15 @@ class AiocqhttpEventConverter(adapter.EventConverter): permission = "ADMINISTRATOR" elif event.sender["role"] == "owner": permission = "OWNER" - converted_event = mirai.GroupMessage( - sender=mirai.models.entities.GroupMember( + converted_event = platform_events.GroupMessage( + sender=platform_entities.GroupMember( id=event.sender["user_id"], # message_seq 放哪? member_name=event.sender["nickname"], permission=permission, - group=mirai.models.entities.Group( + group=platform_entities.Group( id=event.group_id, name=event.sender["nickname"], - permission=mirai.models.entities.Permission.Member, + permission=platform_entities.Permission.Member, ), special_title=event.sender["title"] if "title" in event.sender else "", join_timestamp=0, @@ -198,8 +197,8 @@ class AiocqhttpEventConverter(adapter.EventConverter): ) return converted_event elif event.message_type == "private": - return mirai.FriendMessage( - sender=mirai.models.entities.Friend( + return platform_events.FriendMessage( + sender=platform_entities.Friend( id=event.sender["user_id"], nickname=event.sender["nickname"], remark="", @@ -241,7 +240,7 @@ class AiocqhttpAdapter(adapter.MessageSourceAdapter): self.bot = aiocqhttp.CQHttp() async def send_message( - self, target_type: str, target_id: str, message: mirai.MessageChain + self, target_type: str, target_id: str, message: platform_message.MessageChain ): aiocq_msg = AiocqhttpMessageConverter.yiri2target(message)[0] @@ -252,8 +251,8 @@ class AiocqhttpAdapter(adapter.MessageSourceAdapter): async def reply_message( self, - message_source: mirai.MessageEvent, - message: mirai.MessageChain, + message_source: platform_events.MessageEvent, + message: platform_message.MessageChain, quote_origin: bool = False, ): aiocq_event = AiocqhttpEventConverter.yiri2target(message_source, self.bot_account_id) @@ -271,8 +270,8 @@ class AiocqhttpAdapter(adapter.MessageSourceAdapter): def register_listener( self, - event_type: typing.Type[mirai.Event], - callback: typing.Callable[[mirai.Event, adapter.MessageSourceAdapter], None], + event_type: typing.Type[platform_events.Event], + callback: typing.Callable[[platform_events.Event, adapter.MessageSourceAdapter], None], ): async def on_message(event: aiocqhttp.Event): self.bot_account_id = event.self_id @@ -281,15 +280,15 @@ class AiocqhttpAdapter(adapter.MessageSourceAdapter): except: traceback.print_exc() - if event_type == mirai.GroupMessage: + if event_type == platform_events.GroupMessage: self.bot.on_message("group")(on_message) - elif event_type == mirai.FriendMessage: + elif event_type == platform_events.FriendMessage: self.bot.on_message("private")(on_message) def unregister_listener( self, - event_type: typing.Type[mirai.Event], - callback: typing.Callable[[mirai.Event, adapter.MessageSourceAdapter], None], + event_type: typing.Type[platform_events.Event], + callback: typing.Callable[[platform_events.Event, adapter.MessageSourceAdapter], None], ): return super().unregister_listener(event_type, callback) diff --git a/pkg/platform/sources/nakuru.py b/pkg/platform/sources/nakuru.py index 94c2981..2fbe8be 100644 --- a/pkg/platform/sources/nakuru.py +++ b/pkg/platform/sources/nakuru.py @@ -6,26 +6,28 @@ import typing import traceback import logging -import mirai import nakuru import nakuru.entities.components as nkc from .. import adapter as adapter_model from ...pipeline.longtext.strategies import forward +from ...platform.types import message as platform_message +from ...platform.types import entities as platform_entities +from ...platform.types import events as platform_events class NakuruProjectMessageConverter(adapter_model.MessageConverter): """消息转换器""" @staticmethod - def yiri2target(message_chain: mirai.MessageChain) -> list: + def yiri2target(message_chain: platform_message.MessageChain) -> list: msg_list = [] - if type(message_chain) is mirai.MessageChain: + if type(message_chain) is platform_message.MessageChain: msg_list = message_chain.__root__ elif type(message_chain) is list: msg_list = message_chain elif type(message_chain) is str: - msg_list = [mirai.Plain(message_chain)] + msg_list = [platform_message.Plain(message_chain)] else: raise Exception("Unknown message type: " + str(message_chain) + str(type(message_chain))) @@ -33,22 +35,20 @@ class NakuruProjectMessageConverter(adapter_model.MessageConverter): # 遍历并转换 for component in msg_list: - if type(component) is mirai.Plain: + if type(component) is platform_message.Plain: nakuru_msg_list.append(nkc.Plain(component.text, False)) - elif type(component) is mirai.Image: + elif type(component) is platform_message.Image: if component.url is not None: nakuru_msg_list.append(nkc.Image.fromURL(component.url)) elif component.base64 is not None: nakuru_msg_list.append(nkc.Image.fromBase64(component.base64)) elif component.path is not None: nakuru_msg_list.append(nkc.Image.fromFileSystem(component.path)) - elif type(component) is mirai.Face: - nakuru_msg_list.append(nkc.Face(id=component.face_id)) - elif type(component) is mirai.At: + elif type(component) is platform_message.At: nakuru_msg_list.append(nkc.At(qq=component.target)) - elif type(component) is mirai.AtAll: + elif type(component) is platform_message.AtAll: nakuru_msg_list.append(nkc.AtAll()) - elif type(component) is mirai.Voice: + elif type(component) is platform_message.Voice: if component.url is not None: nakuru_msg_list.append(nkc.Record.fromURL(component.url)) elif component.path is not None: @@ -80,49 +80,47 @@ class NakuruProjectMessageConverter(adapter_model.MessageConverter): return nakuru_msg_list @staticmethod - def target2yiri(message_chain: typing.Any, message_id: int = -1) -> mirai.MessageChain: + def target2yiri(message_chain: typing.Any, message_id: int = -1) -> platform_message.MessageChain: """将Yiri的消息链转换为YiriMirai的消息链""" assert type(message_chain) is list yiri_msg_list = [] import datetime # 添加Source组件以标记message_id等信息 - yiri_msg_list.append(mirai.models.message.Source(id=message_id, time=datetime.datetime.now())) + yiri_msg_list.append(platform_message.Source(id=message_id, time=datetime.datetime.now())) for component in message_chain: if type(component) is nkc.Plain: - yiri_msg_list.append(mirai.Plain(text=component.text)) + yiri_msg_list.append(platform_message.Plain(text=component.text)) elif type(component) is nkc.Image: - yiri_msg_list.append(mirai.Image(url=component.url)) - elif type(component) is nkc.Face: - yiri_msg_list.append(mirai.Face(face_id=component.id)) + yiri_msg_list.append(platform_message.Image(url=component.url)) elif type(component) is nkc.At: - yiri_msg_list.append(mirai.At(target=component.qq)) + yiri_msg_list.append(platform_message.At(target=component.qq)) elif type(component) is nkc.AtAll: - yiri_msg_list.append(mirai.AtAll()) + yiri_msg_list.append(platform_message.AtAll()) else: pass # logging.debug("转换后的消息链: " + str(yiri_msg_list)) - chain = mirai.MessageChain(yiri_msg_list) + chain = platform_message.MessageChain(yiri_msg_list) return chain class NakuruProjectEventConverter(adapter_model.EventConverter): """事件转换器""" @staticmethod - def yiri2target(event: typing.Type[mirai.Event]): - if event is mirai.GroupMessage: + def yiri2target(event: typing.Type[platform_events.Event]): + if event is platform_events.GroupMessage: return nakuru.GroupMessage - elif event is mirai.FriendMessage: + elif event is platform_events.FriendMessage: return nakuru.FriendMessage else: raise Exception("未支持转换的事件类型: " + str(event)) @staticmethod - def target2yiri(event: typing.Any) -> mirai.Event: + def target2yiri(event: typing.Any) -> platform_events.Event: yiri_chain = NakuruProjectMessageConverter.target2yiri(event.message, event.message_id) if type(event) is nakuru.FriendMessage: # 私聊消息事件 - return mirai.FriendMessage( - sender=mirai.models.entities.Friend( + return platform_events.FriendMessage( + sender=platform_entities.Friend( id=event.sender.user_id, nickname=event.sender.nickname, remark=event.sender.nickname @@ -138,16 +136,15 @@ class NakuruProjectEventConverter(adapter_model.EventConverter): elif event.sender.role == "owner": permission = "OWNER" - import mirai.models.entities as entities - return mirai.GroupMessage( - sender=mirai.models.entities.GroupMember( + return platform_events.GroupMessage( + sender=platform_entities.GroupMember( id=event.sender.user_id, member_name=event.sender.nickname, permission=permission, - group=mirai.models.entities.Group( + group=platform_entities.Group( id=event.group_id, name=event.sender.nickname, - permission=entities.Permission.Member + permission=platform_entities.Permission.Member ), special_title=event.sender.title, join_timestamp=0, @@ -189,7 +186,7 @@ class NakuruProjectAdapter(adapter_model.MessageSourceAdapter): self, target_type: str, target_id: str, - message: typing.Union[mirai.MessageChain, list], + message: typing.Union[platform_message.MessageChain, list], converted: bool = False ): task = None @@ -222,8 +219,8 @@ class NakuruProjectAdapter(adapter_model.MessageSourceAdapter): async def reply_message( self, - message_source: mirai.MessageEvent, - message: mirai.MessageChain, + message_source: platform_events.MessageEvent, + message: platform_message.MessageChain, quote_origin: bool = False ): message = self.message_converter.yiri2target(message) @@ -233,14 +230,14 @@ class NakuruProjectAdapter(adapter_model.MessageSourceAdapter): id=message_source.message_chain.message_id, ) ) - if type(message_source) is mirai.GroupMessage: + if type(message_source) is platform_events.GroupMessage: await self.send_message( "group", message_source.sender.group.id, message, converted=True ) - elif type(message_source) is mirai.FriendMessage: + elif type(message_source) is platform_events.FriendMessage: await self.send_message( "person", message_source.sender.id, @@ -258,8 +255,8 @@ class NakuruProjectAdapter(adapter_model.MessageSourceAdapter): def register_listener( self, - event_type: typing.Type[mirai.Event], - callback: typing.Callable[[mirai.Event, adapter_model.MessageSourceAdapter], None] + event_type: typing.Type[platform_events.Event], + callback: typing.Callable[[platform_events.Event, adapter_model.MessageSourceAdapter], None] ): try: @@ -286,8 +283,8 @@ class NakuruProjectAdapter(adapter_model.MessageSourceAdapter): def unregister_listener( self, - event_type: typing.Type[mirai.Event], - callback: typing.Callable[[mirai.Event, adapter_model.MessageSourceAdapter], None] + event_type: typing.Type[platform_events.Event], + callback: typing.Callable[[platform_events.Event, adapter_model.MessageSourceAdapter], None] ): nakuru_event_name = self.event_converter.yiri2target(event_type).__name__ diff --git a/pkg/platform/sources/qqbotpy.py b/pkg/platform/sources/qqbotpy.py index a79013c..cbc86f4 100644 --- a/pkg/platform/sources/qqbotpy.py +++ b/pkg/platform/sources/qqbotpy.py @@ -6,7 +6,6 @@ import datetime import re import traceback -import mirai import botpy import botpy.message as botpy_message import botpy.types.message as botpy_message_type @@ -17,17 +16,21 @@ from .. import adapter as adapter_model from ...pipeline.longtext.strategies import forward from ...core import app from ...config import manager as cfg_mgr +from ...platform.types import entities as platform_entities +from ...platform.types import events as platform_events +from ...platform.types import message as platform_message -class OfficialGroupMessage(mirai.GroupMessage): + +class OfficialGroupMessage(platform_events.GroupMessage): pass -class OfficialFriendMessage(mirai.FriendMessage): +class OfficialFriendMessage(platform_events.FriendMessage): pass event_handler_mapping = { - mirai.GroupMessage: ["on_at_message_create", "on_group_at_message_create"], - mirai.FriendMessage: ["on_direct_message_create", "on_c2c_message_create"], + platform_events.GroupMessage: ["on_at_message_create", "on_group_at_message_create"], + platform_events.FriendMessage: ["on_direct_message_create", "on_c2c_message_create"], } @@ -123,16 +126,16 @@ class OfficialMessageConverter(adapter_model.MessageConverter): """QQ 官方消息转换器""" @staticmethod - def yiri2target(message_chain: mirai.MessageChain): + def yiri2target(message_chain: platform_message.MessageChain): """将 YiriMirai 的消息链转换为 QQ 官方消息""" msg_list = [] - if type(message_chain) is mirai.MessageChain: + if type(message_chain) is platform_message.MessageChain: msg_list = message_chain.__root__ elif type(message_chain) is list: msg_list = message_chain elif type(message_chain) is str: - msg_list = [mirai.Plain(text=message_chain)] + msg_list = [platform_message.Plain(text=message_chain)] else: raise Exception( "Unknown message type: " + str(message_chain) + str(type(message_chain)) @@ -153,22 +156,22 @@ class OfficialMessageConverter(adapter_model.MessageConverter): # 遍历并转换 for component in msg_list: - if type(component) is mirai.Plain: + if type(component) is platform_message.Plain: offcial_messages.append({"type": "text", "content": component.text}) - elif type(component) is mirai.Image: + elif type(component) is platform_message.Image: if component.url is not None: offcial_messages.append({"type": "image", "content": component.url}) elif component.path is not None: offcial_messages.append( {"type": "file_image", "content": component.path} ) - elif type(component) is mirai.At: + elif type(component) is platform_message.At: offcial_messages.append({"type": "at", "content": ""}) - elif type(component) is mirai.AtAll: + elif type(component) is platform_message.AtAll: print( "上层组件要求发送 AtAll 消息,但 QQ 官方 API 不支持此消息类型,忽略此消息。" ) - elif type(component) is mirai.Voice: + elif type(component) is platform_message.Voice: print( "上层组件要求发送 Voice 消息,但 QQ 官方 API 不支持此消息类型,忽略此消息。" ) @@ -197,29 +200,29 @@ class OfficialMessageConverter(adapter_model.MessageConverter): message: typing.Union[botpy_message.Message, botpy_message.DirectMessage, botpy_message.GroupMessage, botpy_message.C2CMessage], message_id: str = None, bot_account_id: int = 0, - ) -> mirai.MessageChain: + ) -> platform_message.MessageChain: yiri_msg_list = [] # 存id yiri_msg_list.append( - mirai.models.message.Source( + platform_message.Source( id=save_msg_id(message_id), time=datetime.datetime.now() ) ) if type(message) not in [botpy_message.DirectMessage, botpy_message.C2CMessage]: - yiri_msg_list.append(mirai.At(target=bot_account_id)) + yiri_msg_list.append(platform_message.At(target=bot_account_id)) if hasattr(message, "mentions"): for mention in message.mentions: if mention.bot: continue - yiri_msg_list.append(mirai.At(target=mention.id)) + yiri_msg_list.append(platform_message.At(target=mention.id)) for attachment in message.attachments: if attachment.content_type.startswith("image"): - yiri_msg_list.append(mirai.Image(url=attachment.url)) + yiri_msg_list.append(platform_message.Image(url=attachment.url)) else: logging.warning( "不支持的附件类型:" + attachment.content_type + ",忽略此附件。" @@ -227,9 +230,9 @@ class OfficialMessageConverter(adapter_model.MessageConverter): content = re.sub(r"<@!\d+>", "", str(message.content)) if content.strip() != "": - yiri_msg_list.append(mirai.Plain(text=content)) + yiri_msg_list.append(platform_message.Plain(text=content)) - chain = mirai.MessageChain(yiri_msg_list) + chain = platform_message.MessageChain(yiri_msg_list) return chain @@ -244,10 +247,10 @@ class OfficialEventConverter(adapter_model.EventConverter): self.member_openid_mapping = member_openid_mapping self.group_openid_mapping = group_openid_mapping - def yiri2target(self, event: typing.Type[mirai.Event]): - if event == mirai.GroupMessage: + def yiri2target(self, event: typing.Type[platform_events.Event]): + if event == platform_events.GroupMessage: return botpy_message.Message - elif event == mirai.FriendMessage: + elif event == platform_events.FriendMessage: return botpy_message.DirectMessage else: raise Exception( @@ -257,8 +260,7 @@ class OfficialEventConverter(adapter_model.EventConverter): def target2yiri( self, event: typing.Union[botpy_message.Message, botpy_message.DirectMessage, botpy_message.GroupMessage, botpy_message.C2CMessage], - ) -> mirai.Event: - import mirai.models.entities as mirai_entities + ) -> platform_events.Event: if type(event) == botpy_message.Message: # 频道内,转群聊事件 permission = "MEMBER" @@ -268,15 +270,15 @@ class OfficialEventConverter(adapter_model.EventConverter): elif "4" in event.member.roles: permission = "OWNER" - return mirai.GroupMessage( - sender=mirai_entities.GroupMember( + return platform_events.GroupMessage( + sender=platform_entities.GroupMember( id=event.author.id, member_name=event.author.username, permission=permission, - group=mirai_entities.Group( + group=platform_entities.Group( id=event.channel_id, name=event.author.username, - permission=mirai_entities.Permission.Member, + permission=platform_entities.Permission.Member, ), special_title="", join_timestamp=int( @@ -297,8 +299,8 @@ class OfficialEventConverter(adapter_model.EventConverter): ), ) elif type(event) == botpy_message.DirectMessage: # 频道私聊,转私聊事件 - return mirai.FriendMessage( - sender=mirai_entities.Friend( + return platform_events.FriendMessage( + sender=platform_entities.Friend( id=event.guild_id, nickname=event.author.username, remark=event.author.username, @@ -317,14 +319,14 @@ class OfficialEventConverter(adapter_model.EventConverter): replacing_member_id = self.member_openid_mapping.save_openid(event.author.member_openid) return OfficialGroupMessage( - sender=mirai_entities.GroupMember( + sender=platform_entities.GroupMember( id=replacing_member_id, member_name=replacing_member_id, permission="MEMBER", - group=mirai_entities.Group( + group=platform_entities.Group( id=self.group_openid_mapping.save_openid(event.group_openid), name=replacing_member_id, - permission=mirai_entities.Permission.Member, + permission=platform_entities.Permission.Member, ), special_title="", join_timestamp=int(0), @@ -345,7 +347,7 @@ class OfficialEventConverter(adapter_model.EventConverter): user_id_alter = self.member_openid_mapping.save_openid(event.author.user_openid) # 实测这里的user_openid与group的member_openid是一样的 return OfficialFriendMessage( - sender=mirai_entities.Friend( + sender=platform_entities.Friend( id=user_id_alter, nickname=user_id_alter, remark=user_id_alter, @@ -410,7 +412,7 @@ class OfficialAdapter(adapter_model.MessageSourceAdapter): self.bot = botpy.Client(intents=intents) async def send_message( - self, target_type: str, target_id: str, message: mirai.MessageChain + self, target_type: str, target_id: str, message: platform_message.MessageChain ): message_list = self.message_converter.yiri2target(message) @@ -437,8 +439,8 @@ class OfficialAdapter(adapter_model.MessageSourceAdapter): async def reply_message( self, - message_source: mirai.MessageEvent, - message: mirai.MessageChain, + message_source: platform_events.MessageEvent, + message: platform_message.MessageChain, quote_origin: bool = False, ): @@ -463,13 +465,13 @@ class OfficialAdapter(adapter_model.MessageSourceAdapter): ] ) - if type(message_source) == mirai.GroupMessage: + if type(message_source) == platform_events.GroupMessage: args["channel_id"] = str(message_source.sender.group.id) args["msg_id"] = cached_message_ids[ str(message_source.message_chain.message_id) ] await self.bot.api.post_message(**args) - elif type(message_source) == mirai.FriendMessage: + elif type(message_source) == platform_events.FriendMessage: args["guild_id"] = str(message_source.sender.id) args["msg_id"] = cached_message_ids[ str(message_source.message_chain.message_id) @@ -534,9 +536,9 @@ class OfficialAdapter(adapter_model.MessageSourceAdapter): def register_listener( self, - event_type: typing.Type[mirai.Event], + event_type: typing.Type[platform_events.Event], callback: typing.Callable[ - [mirai.Event, adapter_model.MessageSourceAdapter], None + [platform_events.Event, adapter_model.MessageSourceAdapter], None ], ): @@ -560,9 +562,9 @@ class OfficialAdapter(adapter_model.MessageSourceAdapter): def unregister_listener( self, - event_type: typing.Type[mirai.Event], + event_type: typing.Type[platform_events.Event], callback: typing.Callable[ - [mirai.Event, adapter_model.MessageSourceAdapter], None + [platform_events.Event, adapter_model.MessageSourceAdapter], None ], ): delattr(self.bot, event_handler_mapping[event_type]) diff --git a/pkg/platform/sources/yirimirai.py b/pkg/platform/sources/yirimirai.py index 7768dcf..aa0823f 100644 --- a/pkg/platform/sources/yirimirai.py +++ b/pkg/platform/sources/yirimirai.py @@ -1,124 +1,121 @@ -import asyncio -import typing - -import mirai -import mirai.models.bus -from mirai.bot import MiraiRunner - -from .. import adapter as adapter_model -from ...core import app +# import asyncio +# import typing -@adapter_model.adapter_class("yiri-mirai") -class YiriMiraiAdapter(adapter_model.MessageSourceAdapter): - """YiriMirai适配器""" - bot: mirai.Mirai +# from .. import adapter as adapter_model +# from ...core import app - def __init__(self, config: dict, ap: app.Application): - """初始化YiriMirai的对象""" - self.ap = ap - self.config = config - if 'adapter' not in config or \ - config['adapter'] == 'WebSocketAdapter': - self.bot = mirai.Mirai( - qq=config['qq'], - adapter=mirai.WebSocketAdapter( - host=config['host'], - port=config['port'], - verify_key=config['verifyKey'] - ) - ) - elif config['adapter'] == 'HTTPAdapter': - self.bot = mirai.Mirai( - qq=config['qq'], - adapter=mirai.HTTPAdapter( - host=config['host'], - port=config['port'], - verify_key=config['verifyKey'] - ) - ) - else: - raise Exception('Unknown adapter for YiriMirai: ' + config['adapter']) + +# @adapter_model.adapter_class("yiri-mirai") +# class YiriMiraiAdapter(adapter_model.MessageSourceAdapter): +# """YiriMirai适配器""" +# bot: mirai.Mirai + +# def __init__(self, config: dict, ap: app.Application): +# """初始化YiriMirai的对象""" +# self.ap = ap +# self.config = config +# if 'adapter' not in config or \ +# config['adapter'] == 'WebSocketAdapter': +# self.bot = mirai.Mirai( +# qq=config['qq'], +# adapter=mirai.WebSocketAdapter( +# host=config['host'], +# port=config['port'], +# verify_key=config['verifyKey'] +# ) +# ) +# elif config['adapter'] == 'HTTPAdapter': +# self.bot = mirai.Mirai( +# qq=config['qq'], +# adapter=mirai.HTTPAdapter( +# host=config['host'], +# port=config['port'], +# verify_key=config['verifyKey'] +# ) +# ) +# else: +# raise Exception('Unknown adapter for YiriMirai: ' + config['adapter']) - async def send_message( - self, - target_type: str, - target_id: str, - message: mirai.MessageChain - ): - """发送消息 +# async def send_message( +# self, +# target_type: str, +# target_id: str, +# message: mirai.MessageChain +# ): +# """发送消息 - Args: - target_type (str): 目标类型,`person`或`group` - target_id (str): 目标ID - message (mirai.MessageChain): YiriMirai库的消息链 - """ - task = None - if target_type == 'person': - task = self.bot.send_friend_message(int(target_id), message) - elif target_type == 'group': - task = self.bot.send_group_message(int(target_id), message) - else: - raise Exception('Unknown target type: ' + target_type) +# Args: +# target_type (str): 目标类型,`person`或`group` +# target_id (str): 目标ID +# message (mirai.MessageChain): YiriMirai库的消息链 +# """ +# task = None +# if target_type == 'person': +# task = self.bot.send_friend_message(int(target_id), message) +# elif target_type == 'group': +# task = self.bot.send_group_message(int(target_id), message) +# else: +# raise Exception('Unknown target type: ' + target_type) - await task +# await task - async def reply_message( - self, - message_source: mirai.MessageEvent, - message: mirai.MessageChain, - quote_origin: bool = False - ): - """回复消息 +# async def reply_message( +# self, +# message_source: mirai.MessageEvent, +# message: mirai.MessageChain, +# quote_origin: bool = False +# ): +# """回复消息 - Args: - message_source (mirai.MessageEvent): YiriMirai消息源事件 - message (mirai.MessageChain): YiriMirai库的消息链 - quote_origin (bool, optional): 是否引用原消息. Defaults to False. - """ - await self.bot.send(message_source, message, quote_origin) +# Args: +# message_source (mirai.MessageEvent): YiriMirai消息源事件 +# message (mirai.MessageChain): YiriMirai库的消息链 +# quote_origin (bool, optional): 是否引用原消息. Defaults to False. +# """ +# await self.bot.send(message_source, message, quote_origin) - async def is_muted(self, group_id: int) -> bool: - result = await self.bot.member_info(target=group_id, member_id=self.bot.qq).get() - if result.mute_time_remaining > 0: - return True - return False +# async def is_muted(self, group_id: int) -> bool: +# result = await self.bot.member_info(target=group_id, member_id=self.bot.qq).get() +# if result.mute_time_remaining > 0: +# return True +# return False - def register_listener( - self, - event_type: typing.Type[mirai.Event], - callback: typing.Callable[[mirai.Event, adapter_model.MessageSourceAdapter], None] - ): - """注册事件监听器 +# def register_listener( +# self, +# event_type: typing.Type[mirai.Event], +# callback: typing.Callable[[mirai.Event, adapter_model.MessageSourceAdapter], None] +# ): +# """注册事件监听器 - Args: - event_type (typing.Type[mirai.Event]): YiriMirai事件类型 - callback (typing.Callable[[mirai.Event], None]): 回调函数,接收一个参数,为YiriMirai事件 - """ - async def wrapper(event: mirai.Event): - await callback(event, self) - self.bot.on(event_type)(wrapper) +# Args: +# event_type (typing.Type[mirai.Event]): YiriMirai事件类型 +# callback (typing.Callable[[mirai.Event], None]): 回调函数,接收一个参数,为YiriMirai事件 +# """ +# async def wrapper(event: mirai.Event): +# await callback(event, self) +# self.bot.on(event_type)(wrapper) - def unregister_listener( - self, - event_type: typing.Type[mirai.Event], - callback: typing.Callable[[mirai.Event, adapter_model.MessageSourceAdapter], None] - ): - """注销事件监听器 +# def unregister_listener( +# self, +# event_type: typing.Type[mirai.Event], +# callback: typing.Callable[[mirai.Event, adapter_model.MessageSourceAdapter], None] +# ): +# """注销事件监听器 - Args: - event_type (typing.Type[mirai.Event]): YiriMirai事件类型 - callback (typing.Callable[[mirai.Event], None]): 回调函数,接收一个参数,为YiriMirai事件 - """ - assert isinstance(self.bot, mirai.Mirai) - bus = self.bot.bus - assert isinstance(bus, mirai.models.bus.ModelEventBus) +# Args: +# event_type (typing.Type[mirai.Event]): YiriMirai事件类型 +# callback (typing.Callable[[mirai.Event], None]): 回调函数,接收一个参数,为YiriMirai事件 +# """ +# assert isinstance(self.bot, mirai.Mirai) +# bus = self.bot.bus +# assert isinstance(bus, mirai.models.bus.ModelEventBus) - bus.unsubscribe(event_type, callback) +# bus.unsubscribe(event_type, callback) - async def run_async(self): - self.bot_account_id = self.bot.qq - return await MiraiRunner(self.bot)._run() +# async def run_async(self): +# self.bot_account_id = self.bot.qq +# return await MiraiRunner(self.bot)._run() - async def kill(self) -> bool: - return False +# async def kill(self) -> bool: +# return False diff --git a/pkg/platform/types/__init__.py b/pkg/platform/types/__init__.py new file mode 100644 index 0000000..998b0fb --- /dev/null +++ b/pkg/platform/types/__init__.py @@ -0,0 +1,3 @@ +from .entities import * +from .events import * +from .message import * diff --git a/pkg/platform/types/base.py b/pkg/platform/types/base.py new file mode 100644 index 0000000..d3c0be4 --- /dev/null +++ b/pkg/platform/types/base.py @@ -0,0 +1,105 @@ + +from typing import Dict, List, Type + +import pydantic.main as pdm +from pydantic import BaseModel + + +class PlatformMetaclass(pdm.ModelMetaclass): + """此类是平台中使用的 pydantic 模型的元类的基类。""" + + +def to_camel(name: str) -> str: + """将下划线命名风格转换为小驼峰命名。""" + if name[:2] == '__': # 不处理双下划线开头的特殊命名。 + return name + name_parts = name.split('_') + return ''.join(name_parts[:1] + [x.title() for x in name_parts[1:]]) + + +class PlatformBaseModel(BaseModel, metaclass=PlatformMetaclass): + """模型基类。 + + 启用了三项配置: + 1. 允许解析时传入额外的值,并将额外值保存在模型中。 + 2. 允许通过别名访问字段。 + 3. 自动生成小驼峰风格的别名。 + """ + def __init__(self, *args, **kwargs): + """""" + super().__init__(*args, **kwargs) + + def __repr__(self) -> str: + return self.__class__.__name__ + '(' + ', '.join( + (f'{k}={repr(v)}' for k, v in self.__dict__.items() if v) + ) + ')' + + class Config: + extra = 'allow' + allow_population_by_field_name = True + alias_generator = to_camel + + +class PlatformIndexedMetaclass(PlatformMetaclass): + """可以通过子类名获取子类的类的元类。""" + __indexedbases__: List[Type['PlatformIndexedModel']] = [] + __indexedmodel__ = None + + def __new__(cls, name, bases, attrs, **kwargs): + new_cls = super().__new__(cls, name, bases, attrs, **kwargs) + # 第一类:PlatformIndexedModel + if name == 'PlatformIndexedModel': + cls.__indexedmodel__ = new_cls + new_cls.__indexes__ = {} + return new_cls + # 第二类:PlatformIndexedModel 的直接子类,这些是可以通过子类名获取子类的类。 + if cls.__indexedmodel__ in bases: + cls.__indexedbases__.append(new_cls) + new_cls.__indexes__ = {} + return new_cls + # 第三类:PlatformIndexedModel 的直接子类的子类,这些添加到直接子类的索引中。 + for base in cls.__indexedbases__: + if issubclass(new_cls, base): + base.__indexes__[name] = new_cls + return new_cls + + def __getitem__(cls, name): + return cls.get_subtype(name) + + +class PlatformIndexedModel(PlatformBaseModel, metaclass=PlatformIndexedMetaclass): + """可以通过子类名获取子类的类。""" + __indexes__: Dict[str, Type['PlatformIndexedModel']] + + @classmethod + def get_subtype(cls, name: str) -> Type['PlatformIndexedModel']: + """根据类名称,获取相应的子类类型。 + + Args: + name: 类名称。 + + Returns: + Type['PlatformIndexedModel']: 子类类型。 + """ + try: + type_ = cls.__indexes__.get(name) + if not (type_ and issubclass(type_, cls)): + raise ValueError(f'`{name}` 不是 `{cls.__name__}` 的子类!') + return type_ + except AttributeError as e: + raise ValueError(f'`{name}` 不是 `{cls.__name__}` 的子类!') from None + + @classmethod + def parse_subtype(cls, obj: dict) -> 'PlatformIndexedModel': + """通过字典,构造对应的模型对象。 + + Args: + obj: 一个字典,包含了模型对象的属性。 + + Returns: + PlatformIndexedModel: 构造的对象。 + """ + if cls in PlatformIndexedModel.__subclasses__(): + ModelType = cls.get_subtype(obj['type']) + return ModelType.parse_obj(obj) + return super().parse_obj(obj) diff --git a/pkg/platform/types/entities.py b/pkg/platform/types/entities.py new file mode 100644 index 0000000..8077bd1 --- /dev/null +++ b/pkg/platform/types/entities.py @@ -0,0 +1,143 @@ +# -*- coding: utf-8 -*- +""" +此模块提供实体和配置项模型。 +""" +import abc +from datetime import datetime +from enum import Enum +import typing + +import pydantic + + +class Entity(pydantic.BaseModel): + """实体,表示一个用户或群。""" + id: int + """QQ 号或群号。""" + @abc.abstractmethod + def get_avatar_url(self) -> str: + """头像图片链接。""" + + @abc.abstractmethod + def get_name(self) -> str: + """名称。""" + + +class Friend(Entity): + """好友。""" + id: int + """QQ 号。""" + nickname: typing.Optional[str] + """昵称。""" + remark: typing.Optional[str] + """备注。""" + def get_avatar_url(self) -> str: + return f'http://q4.qlogo.cn/g?b=qq&nk={self.id}&s=140' + + def get_name(self) -> str: + return self.nickname or self.remark or '' + + +class Permission(str, Enum): + """群成员身份权限。""" + Member = "MEMBER" + """成员。""" + Administrator = "ADMINISTRATOR" + """管理员。""" + Owner = "OWNER" + """群主。""" + def __repr__(self) -> str: + return repr(self.value) + + +class Group(Entity): + """群。""" + id: int + """群号。""" + name: str + """群名称。""" + permission: Permission + """Bot 在群中的权限。""" + def get_avatar_url(self) -> str: + return f'https://p.qlogo.cn/gh/{self.id}/{self.id}/' + + def get_name(self) -> str: + return self.name + + +class GroupMember(Entity): + """群成员。""" + id: int + """QQ 号。""" + member_name: str + """群成员名称。""" + permission: Permission + """Bot 在群中的权限。""" + group: Group + """群。""" + special_title: str = '' + """群头衔。""" + join_timestamp: datetime = datetime.utcfromtimestamp(0) + """加入群的时间。""" + last_speak_timestamp: datetime = datetime.utcfromtimestamp(0) + """最后一次发言的时间。""" + mute_time_remaining: int = 0 + """禁言剩余时间。""" + def get_avatar_url(self) -> str: + return f'http://q4.qlogo.cn/g?b=qq&nk={self.id}&s=140' + + def get_name(self) -> str: + return self.member_name + + +class Client(Entity): + """来自其他客户端的用户。""" + id: int + """识别 id。""" + platform: str + """来源平台。""" + def get_avatar_url(self) -> str: + raise NotImplementedError + + def get_name(self) -> str: + return self.platform + + +class Subject(pydantic.BaseModel): + """另一种实体类型表示。""" + id: int + """QQ 号或群号。""" + kind: typing.Literal['Friend', 'Group', 'Stranger'] + """类型。""" + + +class Config(pydantic.BaseModel): + """配置项类型。""" + def modify(self, **kwargs) -> 'Config': + """修改部分设置。""" + for k, v in kwargs.items(): + if k in self.__fields__: + setattr(self, k, v) + else: + raise ValueError(f'未知配置项: {k}') + return self + + +class GroupConfigModel(Config): + """群配置。""" + name: str + """群名称。""" + confess_talk: bool + """是否允许坦白说。""" + allow_member_invite: bool + """是否允许成员邀请好友入群。""" + auto_approve: bool + """是否开启自动审批入群。""" + anonymous_chat: bool + """是否开启匿名聊天。""" + announcement: str = '' + """群公告。""" + + +class MemberInfoModel(Config, GroupMember): + """群成员信息。""" diff --git a/pkg/platform/types/events.py b/pkg/platform/types/events.py new file mode 100644 index 0000000..1b008cf --- /dev/null +++ b/pkg/platform/types/events.py @@ -0,0 +1,124 @@ +# -*- coding: utf-8 -*- +""" +此模块提供事件模型。 +""" +from datetime import datetime +from enum import Enum +import typing + +import pydantic + +from . import entities as platform_entities +from . import message as platform_message + + +class Event(pydantic.BaseModel): + """事件基类。 + + Args: + type: 事件名。 + """ + type: str + """事件名。""" + def __repr__(self): + return self.__class__.__name__ + '(' + ', '.join( + ( + f'{k}={repr(v)}' + for k, v in self.__dict__.items() if k != 'type' and v + ) + ) + ')' + + @classmethod + def parse_subtype(cls, obj: dict) -> 'Event': + try: + return typing.cast(Event, super().parse_subtype(obj)) + except ValueError: + return Event(type=obj['type']) + + @classmethod + def get_subtype(cls, name: str) -> typing.Type['Event']: + try: + return typing.cast(typing.Type[Event], super().get_subtype(name)) + except ValueError: + return Event + + +############################### +# Bot Event +class BotEvent(Event): + """Bot 自身事件。 + + Args: + type: 事件名。 + qq: Bot 的 QQ 号。 + """ + type: str + """事件名。""" + qq: int + """Bot 的 QQ 号。""" + + +############################### +# Message Event +class MessageEvent(Event): + """消息事件。 + + Args: + type: 事件名。 + message_chain: 消息内容。 + """ + type: str + """事件名。""" + message_chain: platform_message.MessageChain + """消息内容。""" + + +class FriendMessage(MessageEvent): + """好友消息。 + + Args: + type: 事件名。 + sender: 发送消息的好友。 + message_chain: 消息内容。 + """ + type: str = 'FriendMessage' + """事件名。""" + sender: platform_entities.Friend + """发送消息的好友。""" + message_chain: platform_message.MessageChain + """消息内容。""" + + +class GroupMessage(MessageEvent): + """群消息。 + + Args: + type: 事件名。 + sender: 发送消息的群成员。 + message_chain: 消息内容。 + """ + type: str = 'GroupMessage' + """事件名。""" + sender: platform_entities.GroupMember + """发送消息的群成员。""" + message_chain: platform_message.MessageChain + """消息内容。""" + @property + def group(self) -> platform_entities.Group: + return self.sender.group + + +class StrangerMessage(MessageEvent): + """陌生人消息。 + + Args: + type: 事件名。 + sender: 发送消息的人。 + message_chain: 消息内容。 + """ + type: str = 'StrangerMessage' + """事件名。""" + sender: platform_entities.Friend + """发送消息的人。""" + message_chain: platform_message.MessageChain + """消息内容。""" diff --git a/pkg/platform/types/message.py b/pkg/platform/types/message.py new file mode 100644 index 0000000..e790f88 --- /dev/null +++ b/pkg/platform/types/message.py @@ -0,0 +1,817 @@ +import itertools +import logging +from datetime import datetime +from enum import Enum +from pathlib import Path +import typing + +import pydantic +import pydantic.main + +from . import entities as platform_entities +from .base import PlatformBaseModel, PlatformIndexedMetaclass, PlatformIndexedModel + + +logger = logging.getLogger(__name__) + + +class MessageComponentMetaclass(PlatformIndexedMetaclass): + """消息组件元类。""" + __message_component__ = None + + def __new__(cls, name, bases, attrs, **kwargs): + new_cls = super().__new__(cls, name, bases, attrs, **kwargs) + if name == 'MessageComponent': + cls.__message_component__ = new_cls + + if not cls.__message_component__: + return new_cls + + for base in bases: + if issubclass(base, cls.__message_component__): + # 获取字段名 + if hasattr(new_cls, '__fields__'): + # 忽略 type 字段 + new_cls.__parameter_names__ = list(new_cls.__fields__)[1:] + else: + new_cls.__parameter_names__ = [] + break + + return new_cls + + +class MessageComponent(PlatformIndexedModel, metaclass=MessageComponentMetaclass): + """消息组件。""" + type: str + """消息组件类型。""" + def __str__(self): + return '' + + def __repr__(self): + return self.__class__.__name__ + '(' + ', '.join( + ( + f'{k}={repr(v)}' + for k, v in self.__dict__.items() if k != 'type' and v + ) + ) + ')' + + def __init__(self, *args, **kwargs): + # 解析参数列表,将位置参数转化为具名参数 + parameter_names = self.__parameter_names__ + if len(args) > len(parameter_names): + raise TypeError( + f'`{self.type}`需要{len(parameter_names)}个参数,但传入了{len(args)}个。' + ) + for name, value in zip(parameter_names, args): + if name in kwargs: + raise TypeError(f'在 `{self.type}` 中,具名参数 `{name}` 与位置参数重复。') + kwargs[name] = value + + super().__init__(**kwargs) + + +TMessageComponent = typing.TypeVar('TMessageComponent', bound=MessageComponent) + + +class MessageChain(PlatformBaseModel): + """消息链。 + + 一个构造消息链的例子: + ```py + message_chain = MessageChain([ + AtAll(), + Plain("Hello World!"), + ]) + ``` + + `Plain` 可以省略。 + ```py + message_chain = MessageChain([ + AtAll(), + "Hello World!", + ]) + ``` + + 在调用 API 时,参数中需要 MessageChain 的,也可以使用 `List[MessageComponent]` 代替。 + 例如,以下两种写法是等价的: + ```py + await bot.send_friend_message(12345678, [ + Plain("Hello World!") + ]) + ``` + ```py + await bot.send_friend_message(12345678, MessageChain([ + Plain("Hello World!") + ])) + ``` + + 可以使用 `in` 运算检查消息链中: + 1. 是否有某个消息组件。 + 2. 是否有某个类型的消息组件。 + + ```py + if AtAll in message_chain: + print('AtAll') + + if At(bot.qq) in message_chain: + print('At Me') + ``` + + 消息链对索引操作进行了增强。以消息组件类型为索引,获取消息链中的全部该类型的消息组件。 + ```py + plain_list = message_chain[Plain] + '[Plain("Hello World!")]' + ``` + + 可以用加号连接两个消息链。 + ```py + MessageChain(['Hello World!']) + MessageChain(['Goodbye World!']) + # 返回 MessageChain([Plain("Hello World!"), Plain("Goodbye World!")]) + ``` + + """ + __root__: typing.List[MessageComponent] + + @staticmethod + def _parse_message_chain(msg_chain: typing.Iterable): + result = [] + for msg in msg_chain: + if isinstance(msg, dict): + result.append(MessageComponent.parse_subtype(msg)) + elif isinstance(msg, MessageComponent): + result.append(msg) + elif isinstance(msg, str): + result.append(Plain(msg)) + else: + raise TypeError( + f"消息链中元素需为 dict 或 str 或 MessageComponent,当前类型:{type(msg)}" + ) + return result + + @pydantic.validator('__root__', always=True, pre=True) + def _parse_component(cls, msg_chain): + if isinstance(msg_chain, (str, MessageComponent)): + msg_chain = [msg_chain] + if not msg_chain: + msg_chain = [] + return cls._parse_message_chain(msg_chain) + + @classmethod + def parse_obj(cls, msg_chain: typing.Iterable): + """通过列表形式的消息链,构造对应的 `MessageChain` 对象。 + + Args: + msg_chain: 列表形式的消息链。 + """ + result = cls._parse_message_chain(msg_chain) + return cls(__root__=result) + + def __init__(self, __root__: typing.Iterable[MessageComponent] = None): + super().__init__(__root__=__root__) + + def __str__(self): + return "".join(str(component) for component in self.__root__) + + def __repr__(self): + return f'{self.__class__.__name__}({self.__root__!r})' + + def __iter__(self): + yield from self.__root__ + + def get_first(self, + t: typing.Type[TMessageComponent]) -> typing.Optional[TMessageComponent]: + """获取消息链中第一个符合类型的消息组件。""" + for component in self: + if isinstance(component, t): + return component + return None + + @typing.overload + def __getitem__(self, index: int) -> MessageComponent: + ... + + @typing.overload + def __getitem__(self, index: slice) -> typing.List[MessageComponent]: + ... + + @typing.overload + def __getitem__(self, + index: typing.Type[TMessageComponent]) -> typing.List[TMessageComponent]: + ... + + @typing.overload + def __getitem__( + self, index: typing.Tuple[typing.Type[TMessageComponent], int] + ) -> typing.List[TMessageComponent]: + ... + + def __getitem__( + self, index: typing.Union[int, slice, typing.Type[TMessageComponent], + typing.Tuple[typing.Type[TMessageComponent], int]] + ) -> typing.Union[MessageComponent, typing.List[MessageComponent], + typing.List[TMessageComponent]]: + return self.get(index) + + def __setitem__( + self, key: typing.Union[int, slice], + value: typing.Union[MessageComponent, str, typing.Iterable[typing.Union[MessageComponent, + str]]] + ): + if isinstance(value, str): + value = Plain(value) + if isinstance(value, typing.Iterable): + value = (Plain(c) if isinstance(c, str) else c for c in value) + self.__root__[key] = value # type: ignore + + def __delitem__(self, key: typing.Union[int, slice]): + del self.__root__[key] + + def __reversed__(self) -> typing.Iterable[MessageComponent]: + return reversed(self.__root__) + + def has( + self, sub: typing.Union[MessageComponent, typing.Type[MessageComponent], + 'MessageChain', str] + ) -> bool: + """判断消息链中: + 1. 是否有某个消息组件。 + 2. 是否有某个类型的消息组件。 + + Args: + sub (`Union[MessageComponent, Type[MessageComponent], 'MessageChain', str]`): + 若为 `MessageComponent`,则判断该组件是否在消息链中。 + 若为 `Type[MessageComponent]`,则判断该组件类型是否在消息链中。 + + Returns: + bool: 是否找到。 + """ + if isinstance(sub, type): # 检测消息链中是否有某种类型的对象 + for i in self: + if type(i) is sub: + return True + return False + if isinstance(sub, MessageComponent): # 检查消息链中是否有某个组件 + for i in self: + if i == sub: + return True + return False + raise TypeError(f"类型不匹配,当前类型:{type(sub)}") + + def __contains__(self, sub) -> bool: + return self.has(sub) + + def __ge__(self, other): + return other in self + + def __len__(self) -> int: + return len(self.__root__) + + def __add__( + self, other: typing.Union['MessageChain', MessageComponent, str] + ) -> 'MessageChain': + if isinstance(other, MessageChain): + return self.__class__(self.__root__ + other.__root__) + if isinstance(other, str): + return self.__class__(self.__root__ + [Plain(other)]) + if isinstance(other, MessageComponent): + return self.__class__(self.__root__ + [other]) + return NotImplemented + + def __radd__(self, other: typing.Union[MessageComponent, str]) -> 'MessageChain': + if isinstance(other, MessageComponent): + return self.__class__([other] + self.__root__) + if isinstance(other, str): + return self.__class__( + [typing.cast(MessageComponent, Plain(other))] + self.__root__ + ) + return NotImplemented + + def __mul__(self, other: int): + if isinstance(other, int): + return self.__class__(self.__root__ * other) + return NotImplemented + + def __rmul__(self, other: int): + return self.__mul__(other) + + def __iadd__(self, other: typing.Iterable[typing.Union[MessageComponent, str]]): + self.extend(other) + + def __imul__(self, other: int): + if isinstance(other, int): + self.__root__ *= other + return NotImplemented + + def index( + self, + x: typing.Union[MessageComponent, typing.Type[MessageComponent]], + i: int = 0, + j: int = -1 + ) -> int: + """返回 x 在消息链中首次出现项的索引号(索引号在 i 或其后且在 j 之前)。 + + Args: + x (`Union[MessageComponent, Type[MessageComponent]]`): + 要查找的消息元素或消息元素类型。 + i: 从哪个位置开始查找。 + j: 查找到哪个位置结束。 + + Returns: + int: 如果找到,则返回索引号。 + + Raises: + ValueError: 没有找到。 + TypeError: 类型不匹配。 + """ + if isinstance(x, type): + l = len(self) + if i < 0: + i += l + if i < 0: + i = 0 + if j < 0: + j += l + if j > l: + j = l + for index in range(i, j): + if type(self[index]) is x: + return index + raise ValueError("消息链中不存在该类型的组件。") + if isinstance(x, MessageComponent): + return self.__root__.index(x, i, j) + raise TypeError(f"类型不匹配,当前类型:{type(x)}") + + def count(self, x: typing.Union[MessageComponent, typing.Type[MessageComponent]]) -> int: + """返回消息链中 x 出现的次数。 + + Args: + x (`Union[MessageComponent, Type[MessageComponent]]`): + 要查找的消息元素或消息元素类型。 + + Returns: + int: 次数。 + """ + if isinstance(x, type): + return sum(1 for i in self if type(i) is x) + if isinstance(x, MessageComponent): + return self.__root__.count(x) + raise TypeError(f"类型不匹配,当前类型:{type(x)}") + + def extend(self, x: typing.Iterable[typing.Union[MessageComponent, str]]): + """将另一个消息链中的元素添加到消息链末尾。 + + Args: + x: 另一个消息链,也可为消息元素或字符串元素的序列。 + """ + self.__root__.extend(Plain(c) if isinstance(c, str) else c for c in x) + + def append(self, x: typing.Union[MessageComponent, str]): + """将一个消息元素或字符串元素添加到消息链末尾。 + + Args: + x: 消息元素或字符串元素。 + """ + self.__root__.append(Plain(x) if isinstance(x, str) else x) + + def insert(self, i: int, x: typing.Union[MessageComponent, str]): + """将一个消息元素或字符串添加到消息链中指定位置。 + + Args: + i: 插入位置。 + x: 消息元素或字符串元素。 + """ + self.__root__.insert(i, Plain(x) if isinstance(x, str) else x) + + def pop(self, i: int = -1) -> MessageComponent: + """从消息链中移除并返回指定位置的元素。 + + Args: + i: 移除位置。默认为末尾。 + + Returns: + MessageComponent: 移除的元素。 + """ + return self.__root__.pop(i) + + def remove(self, x: typing.Union[MessageComponent, typing.Type[MessageComponent]]): + """从消息链中移除指定元素或指定类型的一个元素。 + + Args: + x: 指定的元素或元素类型。 + """ + if isinstance(x, type): + self.pop(self.index(x)) + if isinstance(x, MessageComponent): + self.__root__.remove(x) + + def exclude( + self, + x: typing.Union[MessageComponent, typing.Type[MessageComponent]], + count: int = -1 + ) -> 'MessageChain': + """返回移除指定元素或指定类型的元素后剩余的消息链。 + + Args: + x: 指定的元素或元素类型。 + count: 至多移除的数量。默认为全部移除。 + + Returns: + MessageChain: 剩余的消息链。 + """ + def _exclude(): + nonlocal count + x_is_type = isinstance(x, type) + for c in self: + if count > 0 and ((x_is_type and type(c) is x) or c == x): + count -= 1 + continue + yield c + + return self.__class__(_exclude()) + + def reverse(self): + """将消息链原地翻转。""" + self.__root__.reverse() + + @classmethod + def join(cls, *args: typing.Iterable[typing.Union[str, MessageComponent]]): + return cls( + Plain(c) if isinstance(c, str) else c + for c in itertools.chain(*args) + ) + + @property + def source(self) -> typing.Optional['Source']: + """获取消息链中的 `Source` 对象。""" + return self.get_first(Source) + + @property + def message_id(self) -> int: + """获取消息链的 message_id,若无法获取,返回 -1。""" + source = self.source + return source.id if source else -1 + + +TMessage = typing.Union[MessageChain, typing.Iterable[typing.Union[MessageComponent, str]], + MessageComponent, str] +"""可以转化为 MessageChain 的类型。""" + + +class Source(MessageComponent): + """源。包含消息的基本信息。""" + type: str = "Source" + """消息组件类型。""" + id: int + """消息的识别号,用于引用回复(Source 类型永远为 MessageChain 的第一个元素)。""" + time: datetime + """消息时间。""" + + +class Plain(MessageComponent): + """纯文本。""" + type: str = "Plain" + """消息组件类型。""" + text: str + """文字消息。""" + def __str__(self): + return self.text + + def __repr__(self): + return f'Plain({self.text!r})' + + +class Quote(MessageComponent): + """引用。""" + type: str = "Quote" + """消息组件类型。""" + id: typing.Optional[int] = None + """被引用回复的原消息的 message_id。""" + group_id: typing.Optional[int] = None + """被引用回复的原消息所接收的群号,当为好友消息时为0。""" + sender_id: typing.Optional[int] = None + """被引用回复的原消息的发送者的QQ号。""" + target_id: typing.Optional[int] = None + """被引用回复的原消息的接收者者的QQ号(或群号)。""" + origin: MessageChain + """被引用回复的原消息的消息链对象。""" + + @pydantic.validator("origin", always=True, pre=True) + def origin_formater(cls, v): + return MessageChain.parse_obj(v) + + +class At(MessageComponent): + """At某人。""" + type: str = "At" + """消息组件类型。""" + target: int + """群员 QQ 号。""" + display: typing.Optional[str] = None + """At时显示的文字,发送消息时无效,自动使用群名片。""" + def __eq__(self, other): + return isinstance(other, At) and self.target == other.target + + def __str__(self): + return f"@{self.display or self.target}" + + +class AtAll(MessageComponent): + """At全体。""" + type: str = "AtAll" + """消息组件类型。""" + def __str__(self): + return "@全体成员" + + +class Image(MessageComponent): + """图片。""" + type: str = "Image" + """消息组件类型。""" + image_id: typing.Optional[str] = None + """图片的 image_id,群图片与好友图片格式不同。不为空时将忽略 url 属性。""" + url: typing.Optional[pydantic.HttpUrl] = None + """图片的 URL,发送时可作网络图片的链接;接收时为腾讯图片服务器的链接,可用于图片下载。""" + path: typing.Union[str, Path, None] = None + """图片的路径,发送本地图片。""" + base64: typing.Optional[str] = None + """图片的 Base64 编码。""" + def __eq__(self, other): + return isinstance( + other, Image + ) and self.type == other.type and self.uuid == other.uuid + + def __str__(self): + return '[图片]' + + @pydantic.validator('path') + def validate_path(cls, path: typing.Union[str, Path, None]): + """修复 path 参数的行为,使之相对于 QChatGPT 的启动路径。""" + if path: + try: + return str(Path(path).resolve(strict=True)) + except FileNotFoundError: + raise ValueError(f"无效路径:{path}") + else: + return path + + @property + def uuid(self): + image_id = self.image_id + if image_id[0] == '{': # 群图片 + image_id = image_id[1:37] + elif image_id[0] == '/': # 好友图片 + image_id = image_id[1:] + return image_id + + async def download( + self, + filename: typing.Union[str, Path, None] = None, + directory: typing.Union[str, Path, None] = None, + determine_type: bool = True + ): + """下载图片到本地。 + + Args: + filename: 下载到本地的文件路径。与 `directory` 二选一。 + directory: 下载到本地的文件夹路径。与 `filename` 二选一。 + determine_type: 是否自动根据图片类型确定拓展名,默认为 True。 + """ + if not self.url: + logger.warning(f'图片 `{self.uuid}` 无 url 参数,下载失败。') + return + + import httpx + async with httpx.AsyncClient() as client: + response = await client.get(self.url) + response.raise_for_status() + content = response.content + + if filename: + path = Path(filename) + if determine_type: + import imghdr + path = path.with_suffix( + '.' + str(imghdr.what(None, content)) + ) + path.parent.mkdir(parents=True, exist_ok=True) + elif directory: + import imghdr + path = Path(directory) + path.mkdir(parents=True, exist_ok=True) + path = path / f'{self.uuid}.{imghdr.what(None, content)}' + else: + raise ValueError("请指定文件路径或文件夹路径!") + + import aiofiles + async with aiofiles.open(path, 'wb') as f: + await f.write(content) + + return path + + @classmethod + async def from_local( + cls, + filename: typing.Union[str, Path, None] = None, + content: typing.Optional[bytes] = None, + ) -> "Image": + """从本地文件路径加载图片,以 base64 的形式传递。 + + Args: + filename: 从本地文件路径加载图片,与 `content` 二选一。 + content: 从本地文件内容加载图片,与 `filename` 二选一。 + + Returns: + Image: 图片对象。 + """ + if content: + pass + elif filename: + path = Path(filename) + import aiofiles + async with aiofiles.open(path, 'rb') as f: + content = await f.read() + else: + raise ValueError("请指定图片路径或图片内容!") + import base64 + img = cls(base64=base64.b64encode(content).decode()) + return img + + @classmethod + def from_unsafe_path(cls, path: typing.Union[str, Path]) -> "Image": + """从不安全的路径加载图片。 + + Args: + path: 从不安全的路径加载图片。 + + Returns: + Image: 图片对象。 + """ + return cls.construct(path=str(path)) + + +class Unknown(MessageComponent): + """未知。""" + type: str = "Unknown" + """消息组件类型。""" + text: str + """文本。""" + + +class Voice(MessageComponent): + """语音。""" + type: str = "Voice" + """消息组件类型。""" + voice_id: typing.Optional[str] = None + """语音的 voice_id,不为空时将忽略 url 属性。""" + url: typing.Optional[str] = None + """语音的 URL,发送时可作网络语音的链接;接收时为腾讯语音服务器的链接,可用于语音下载。""" + path: typing.Optional[str] = None + """语音的路径,发送本地语音。""" + base64: typing.Optional[str] = None + """语音的 Base64 编码。""" + length: typing.Optional[int] = None + """语音的长度,单位为秒。""" + @pydantic.validator('path') + def validate_path(cls, path: typing.Optional[str]): + """修复 path 参数的行为,使之相对于 QChatGPT 的启动路径。""" + if path: + try: + return str(Path(path).resolve(strict=True)) + except FileNotFoundError: + raise ValueError(f"无效路径:{path}") + else: + return path + + def __str__(self): + return '[语音]' + + async def download( + self, + filename: typing.Union[str, Path, None] = None, + directory: typing.Union[str, Path, None] = None + ): + """下载语音到本地。 + + 语音采用 silk v3 格式,silk 格式的编码解码请使用 [graiax-silkcoder](https://pypi.org/project/graiax-silkcoder/)。 + + Args: + filename: 下载到本地的文件路径。与 `directory` 二选一。 + directory: 下载到本地的文件夹路径。与 `filename` 二选一。 + """ + if not self.url: + logger.warning(f'语音 `{self.voice_id}` 无 url 参数,下载失败。') + return + + import httpx + async with httpx.AsyncClient() as client: + response = await client.get(self.url) + response.raise_for_status() + content = response.content + + if filename: + path = Path(filename) + path.parent.mkdir(parents=True, exist_ok=True) + elif directory: + path = Path(directory) + path.mkdir(parents=True, exist_ok=True) + path = path / f'{self.voice_id}.silk' + else: + raise ValueError("请指定文件路径或文件夹路径!") + + import aiofiles + async with aiofiles.open(path, 'wb') as f: + await f.write(content) + + @classmethod + async def from_local( + cls, + filename: typing.Union[str, Path, None] = None, + content: typing.Optional[bytes] = None, + ) -> "Voice": + """从本地文件路径加载语音,以 base64 的形式传递。 + + Args: + filename: 从本地文件路径加载语音,与 `content` 二选一。 + content: 从本地文件内容加载语音,与 `filename` 二选一。 + """ + if content: + pass + if filename: + path = Path(filename) + import aiofiles + async with aiofiles.open(path, 'rb') as f: + content = await f.read() + else: + raise ValueError("请指定语音路径或语音内容!") + import base64 + img = cls(base64=base64.b64encode(content).decode()) + return img + + +class ForwardMessageNode(pydantic.BaseModel): + """合并转发中的一条消息。""" + sender_id: typing.Optional[int] = None + """发送人QQ号。""" + sender_name: typing.Optional[str] = None + """显示名称。""" + message_chain: typing.Optional[MessageChain] = None + """消息内容。""" + message_id: typing.Optional[int] = None + """消息的 message_id,可以只使用此属性,从缓存中读取消息内容。""" + time: typing.Optional[datetime] = None + """发送时间。""" + @pydantic.validator('message_chain', check_fields=False) + def _validate_message_chain(cls, value: typing.Union[MessageChain, list]): + if isinstance(value, list): + return MessageChain.parse_obj(value) + return value + + @classmethod + def create( + cls, sender: typing.Union[platform_entities.Friend, platform_entities.GroupMember], message: MessageChain + ) -> 'ForwardMessageNode': + """从消息链生成转发消息。 + + Args: + sender: 发送人。 + message: 消息内容。 + + Returns: + ForwardMessageNode: 生成的一条消息。 + """ + return ForwardMessageNode( + sender_id=sender.id, + sender_name=sender.get_name(), + message_chain=message + ) + + +class Forward(MessageComponent): + """合并转发。""" + type: str = "Forward" + """消息组件类型。""" + node_list: typing.List[ForwardMessageNode] + """转发消息节点列表。""" + def __init__(self, *args, **kwargs): + if len(args) == 1: + self.node_list = args[0] + super().__init__(**kwargs) + super().__init__(*args, **kwargs) + + def __str__(self): + return '[聊天记录]' + + +class File(MessageComponent): + """文件。""" + type: str = "File" + """消息组件类型。""" + id: str + """文件识别 ID。""" + name: str + """文件名称。""" + size: int + """文件大小。""" + def __str__(self): + return f'[文件]{self.name}' + diff --git a/pkg/plugin/context.py b/pkg/plugin/context.py index 42cb6be..f6cc176 100644 --- a/pkg/plugin/context.py +++ b/pkg/plugin/context.py @@ -3,11 +3,11 @@ from __future__ import annotations import typing import abc import pydantic -import mirai from . import events from ..provider.tools import entities as tools_entities from ..core import app +from ..platform.types import message as platform_message def register( @@ -174,11 +174,11 @@ class EventContext: self.__return_value__[key] = [] self.__return_value__[key].append(ret) - async def reply(self, message_chain: mirai.MessageChain): + async def reply(self, message_chain: platform_message.MessageChain): """回复此次消息请求 Args: - message_chain (mirai.MessageChain): YiriMirai库的消息链,若用户使用的不是 YiriMirai 适配器,程序也能自动转换为目标消息链 + message_chain (platform.types.MessageChain): 源平台的消息链,若用户使用的不是源平台适配器,程序也能自动转换为目标平台消息链 """ await self.host.ap.platform_mgr.send( event=self.event.query.message_event, @@ -190,14 +190,14 @@ class EventContext: self, target_type: str, target_id: str, - message: mirai.MessageChain + message: platform_message.MessageChain ): """主动发送消息 Args: target_type (str): 目标类型,`person`或`group` target_id (str): 目标ID - message (mirai.MessageChain): YiriMirai库的消息链,若用户使用的不是 YiriMirai 适配器,程序也能自动转换为目标消息链 + message (platform.types.MessageChain): 源平台的消息链,若用户使用的不是源平台适配器,程序也能自动转换为目标平台消息链 """ await self.event.query.adapter.send_message( target_type=target_type, diff --git a/pkg/plugin/events.py b/pkg/plugin/events.py index b591976..013dd11 100644 --- a/pkg/plugin/events.py +++ b/pkg/plugin/events.py @@ -3,10 +3,10 @@ from __future__ import annotations import typing import pydantic -import mirai from ..core import entities as core_entities from ..provider import entities as llm_entities +from ..platform.types import message as platform_message class BaseEventModel(pydantic.BaseModel): @@ -31,7 +31,7 @@ class PersonMessageReceived(BaseEventModel): sender_id: int """发送者ID(QQ号)""" - message_chain: mirai.MessageChain + message_chain: platform_message.MessageChain class GroupMessageReceived(BaseEventModel): @@ -43,7 +43,7 @@ class GroupMessageReceived(BaseEventModel): sender_id: int - message_chain: mirai.MessageChain + message_chain: platform_message.MessageChain class PersonNormalMessageReceived(BaseEventModel): diff --git a/pkg/plugin/manager.py b/pkg/plugin/manager.py index 4e0784a..b1241ea 100644 --- a/pkg/plugin/manager.py +++ b/pkg/plugin/manager.py @@ -48,6 +48,8 @@ class PluginManager: # 按优先级倒序 self.plugins.sort(key=lambda x: x.priority, reverse=True) + self.ap.logger.debug(f'优先级排序后的插件列表 {self.plugins}') + async def initialize_plugins(self): for plugin in self.plugins: try: diff --git a/pkg/plugin/setting.py b/pkg/plugin/setting.py index 7e715af..bd50603 100644 --- a/pkg/plugin/setting.py +++ b/pkg/plugin/setting.py @@ -45,6 +45,7 @@ class SettingManager: for plugin_container in plugin_containers: if plugin_container.plugin_name == value['name']: plugin_container.set_from_setting_dict(value) + break self.settings.data = { 'plugins': [ diff --git a/pkg/provider/entities.py b/pkg/provider/entities.py index 5000072..803613a 100644 --- a/pkg/provider/entities.py +++ b/pkg/provider/entities.py @@ -4,7 +4,8 @@ import typing import enum import pydantic -import mirai + +from ..platform.types import message as platform_message class FunctionCall(pydantic.BaseModel): @@ -73,14 +74,14 @@ class Message(pydantic.BaseModel): def readable_str(self) -> str: if self.content is not None: - return str(self.role) + ": " + str(self.get_content_mirai_message_chain()) + return str(self.role) + ": " + str(self.get_content_platform_message_chain()) elif self.tool_calls is not None: return f'调用工具: {self.tool_calls[0].id}' else: return '未知消息' - def get_content_mirai_message_chain(self, prefix_text: str="") -> mirai.MessageChain | None: - """将内容转换为 Mirai MessageChain 对象 + def get_content_platform_message_chain(self, prefix_text: str="") -> platform_message.MessageChain | None: + """将内容转换为平台消息 MessageChain 对象 Args: prefix_text (str): 首个文字组件的前缀文本 @@ -89,15 +90,15 @@ class Message(pydantic.BaseModel): if self.content is None: return None elif isinstance(self.content, str): - return mirai.MessageChain([mirai.Plain(prefix_text+self.content)]) + return platform_message.MessageChain([platform_message.Plain(prefix_text+self.content)]) elif isinstance(self.content, list): mc = [] for ce in self.content: if ce.type == 'text': - mc.append(mirai.Plain(ce.text)) + mc.append(platform_message.Plain(ce.text)) elif ce.type == 'image_url': if ce.image_url.url.startswith("http"): - mc.append(mirai.Image(url=ce.image_url.url)) + mc.append(platform_message.Image(url=ce.image_url.url)) else: # base64 b64_str = ce.image_url.url @@ -105,15 +106,15 @@ class Message(pydantic.BaseModel): if b64_str.startswith("data:"): b64_str = b64_str.split(",")[1] - mc.append(mirai.Image(base64=b64_str)) - + mc.append(platform_message.Image(base64=b64_str)) + # 找第一个文字组件 if prefix_text: for i, c in enumerate(mc): - if isinstance(c, mirai.Plain): - mc[i] = mirai.Plain(prefix_text+c.text) + if isinstance(c, platform_message.Plain): + mc[i] = platform_message.Plain(prefix_text+c.text) break else: - mc.insert(0, mirai.Plain(prefix_text)) + mc.insert(0, platform_message.Plain(prefix_text)) - return mirai.MessageChain(mc) + return platform_message.MessageChain(mc) diff --git a/requirements.txt b/requirements.txt index 1b554d2..7bf257e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,7 +2,6 @@ requests openai>1.0.0 anthropic colorlog~=6.6.0 -yiri-mirai-rc aiocqhttp qq-botpy nakuru-project-idk