From 8b003739f133f1d05aa42edd15296b98c1c44426 Mon Sep 17 00:00:00 2001 From: RockChinQ <1010553892@qq.com> Date: Sun, 31 Mar 2024 14:38:15 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20message.content=20=E6=94=AF=E6=8C=81=20?= =?UTF-8?q?mirai.MessageChain=20=E5=AF=B9=E8=B1=A1=20(#741)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .gitignore | 2 +- pkg/pipeline/cntfilter/cntfilter.py | 16 ++++++++++++---- pkg/pipeline/longtext/longtext.py | 13 ++++++++++++- pkg/pipeline/process/handlers/chat.py | 2 +- pkg/pipeline/wrapper/wrapper.py | 5 ++++- pkg/provider/entities.py | 6 ++++-- 6 files changed, 34 insertions(+), 10 deletions(-) diff --git a/.gitignore b/.gitignore index 2af326f..56282ca 100644 --- a/.gitignore +++ b/.gitignore @@ -34,4 +34,4 @@ bard.json res/instance_id.json .DS_Store /data -botpy.log \ No newline at end of file +botpy.log* \ No newline at end of file diff --git a/pkg/pipeline/cntfilter/cntfilter.py b/pkg/pipeline/cntfilter/cntfilter.py index 1b726d9..21b6c25 100644 --- a/pkg/pipeline/cntfilter/cntfilter.py +++ b/pkg/pipeline/cntfilter/cntfilter.py @@ -135,9 +135,17 @@ class ContentFilterStage(stage.PipelineStage): query ) elif stage_inst_name == 'PostContentFilterStage': - return await self._post_process( - query.resp_messages[-1].content, - query - ) + # 仅处理 query.resp_messages[-1].content 是 str 的情况 + if isinstance(query.resp_messages[-1].content, str): + return await self._post_process( + query.resp_messages[-1].content, + query + ) + else: + self.ap.logger.debug(f"resp_messages[-1] 不是 str 类型,跳过内容过滤器检查。") + return entities.StageProcessResult( + result_type=entities.ResultType.CONTINUE, + new_query=query + ) else: raise ValueError(f'未知的 stage_inst_name: {stage_inst_name}') diff --git a/pkg/pipeline/longtext/longtext.py b/pkg/pipeline/longtext/longtext.py index 2095845..28c2814 100644 --- a/pkg/pipeline/longtext/longtext.py +++ b/pkg/pipeline/longtext/longtext.py @@ -56,8 +56,19 @@ class LongTextProcessStage(stage.PipelineStage): await self.strategy_impl.initialize() async def process(self, query: core_entities.Query, stage_inst_name: str) -> entities.StageProcessResult: - if len(str(query.resp_message_chain)) > self.ap.platform_cfg.data['long-text-process']['threshold']: + # 检查是否包含非 Plain 组件 + contains_non_plain = False + + for msg in query.resp_message_chain: + if not isinstance(msg, Plain): + contains_non_plain = True + break + + if contains_non_plain: + self.ap.logger.debug("消息中包含非 Plain 组件,跳过长消息处理。") + elif len(str(query.resp_message_chain)) > self.ap.platform_cfg.data['long-text-process']['threshold']: query.resp_message_chain = MessageChain(await self.strategy_impl.process(str(query.resp_message_chain), query)) + return entities.StageProcessResult( result_type=entities.ResultType.CONTINUE, new_query=query diff --git a/pkg/pipeline/process/handlers/chat.py b/pkg/pipeline/process/handlers/chat.py index 58b7a83..f38ee34 100644 --- a/pkg/pipeline/process/handlers/chat.py +++ b/pkg/pipeline/process/handlers/chat.py @@ -44,7 +44,7 @@ class ChatMessageHandler(handler.MessageHandler): query.resp_messages.append( llm_entities.Message( role='plugin', - content=str(mc), + content=mc, ) ) diff --git a/pkg/pipeline/wrapper/wrapper.py b/pkg/pipeline/wrapper/wrapper.py index 80277a0..a500d7c 100644 --- a/pkg/pipeline/wrapper/wrapper.py +++ b/pkg/pipeline/wrapper/wrapper.py @@ -34,7 +34,10 @@ class ResponseWrapper(stage.PipelineStage): new_query=query ) elif query.resp_messages[-1].role == 'plugin': - query.resp_message_chain = mirai.MessageChain(query.resp_messages[-1].content) + if not isinstance(query.resp_messages[-1].content, mirai.MessageChain): + query.resp_message_chain = mirai.MessageChain(query.resp_messages[-1].content) + else: + query.resp_message_chain = query.resp_messages[-1].content yield entities.StageProcessResult( result_type=entities.ResultType.CONTINUE, diff --git a/pkg/provider/entities.py b/pkg/provider/entities.py index 8c0c76b..a30d4e3 100644 --- a/pkg/provider/entities.py +++ b/pkg/provider/entities.py @@ -4,6 +4,8 @@ import typing import enum import pydantic +import mirai + class FunctionCall(pydantic.BaseModel): name: str @@ -28,7 +30,7 @@ class Message(pydantic.BaseModel): name: typing.Optional[str] = None """名称,仅函数调用返回时设置""" - content: typing.Optional[str] = None + content: typing.Optional[str] | typing.Optional[mirai.MessageChain] = None """内容""" function_call: typing.Optional[FunctionCall] = None @@ -41,7 +43,7 @@ class Message(pydantic.BaseModel): def readable_str(self) -> str: if self.content is not None: - return self.content + return str(self.content) elif self.function_call is not None: return f'{self.function_call.name}({self.function_call.arguments})' elif self.tool_calls is not None: