feat: message.content 支持 mirai.MessageChain 对象 (#741)

This commit is contained in:
RockChinQ 2024-03-31 14:38:15 +08:00
parent 2e9229a6ad
commit 8b003739f1
6 changed files with 34 additions and 10 deletions

2
.gitignore vendored
View File

@ -34,4 +34,4 @@ bard.json
res/instance_id.json
.DS_Store
/data
botpy.log
botpy.log*

View File

@ -135,9 +135,17 @@ class ContentFilterStage(stage.PipelineStage):
query
)
elif stage_inst_name == 'PostContentFilterStage':
# 仅处理 query.resp_messages[-1].content 是 str 的情况
if isinstance(query.resp_messages[-1].content, str):
return await self._post_process(
query.resp_messages[-1].content,
query
)
else:
self.ap.logger.debug(f"resp_messages[-1] 不是 str 类型,跳过内容过滤器检查。")
return entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE,
new_query=query
)
else:
raise ValueError(f'未知的 stage_inst_name: {stage_inst_name}')

View File

@ -56,8 +56,19 @@ class LongTextProcessStage(stage.PipelineStage):
await self.strategy_impl.initialize()
async def process(self, query: core_entities.Query, stage_inst_name: str) -> entities.StageProcessResult:
if len(str(query.resp_message_chain)) > self.ap.platform_cfg.data['long-text-process']['threshold']:
# 检查是否包含非 Plain 组件
contains_non_plain = False
for msg in query.resp_message_chain:
if not isinstance(msg, Plain):
contains_non_plain = True
break
if contains_non_plain:
self.ap.logger.debug("消息中包含非 Plain 组件,跳过长消息处理。")
elif len(str(query.resp_message_chain)) > self.ap.platform_cfg.data['long-text-process']['threshold']:
query.resp_message_chain = MessageChain(await self.strategy_impl.process(str(query.resp_message_chain), query))
return entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE,
new_query=query

View File

@ -44,7 +44,7 @@ class ChatMessageHandler(handler.MessageHandler):
query.resp_messages.append(
llm_entities.Message(
role='plugin',
content=str(mc),
content=mc,
)
)

View File

@ -34,7 +34,10 @@ class ResponseWrapper(stage.PipelineStage):
new_query=query
)
elif query.resp_messages[-1].role == 'plugin':
if not isinstance(query.resp_messages[-1].content, mirai.MessageChain):
query.resp_message_chain = mirai.MessageChain(query.resp_messages[-1].content)
else:
query.resp_message_chain = query.resp_messages[-1].content
yield entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE,

View File

@ -4,6 +4,8 @@ import typing
import enum
import pydantic
import mirai
class FunctionCall(pydantic.BaseModel):
name: str
@ -28,7 +30,7 @@ class Message(pydantic.BaseModel):
name: typing.Optional[str] = None
"""名称,仅函数调用返回时设置"""
content: typing.Optional[str] = None
content: typing.Optional[str] | typing.Optional[mirai.MessageChain] = None
"""内容"""
function_call: typing.Optional[FunctionCall] = None
@ -41,7 +43,7 @@ class Message(pydantic.BaseModel):
def readable_str(self) -> str:
if self.content is not None:
return self.content
return str(self.content)
elif self.function_call is not None:
return f'{self.function_call.name}({self.function_call.arguments})'
elif self.tool_calls is not None: