perf: resp_message_chain 改为 list 类型 (#770)

This commit is contained in:
RockChinQ 2024-05-14 23:08:49 +08:00
parent 269e561497
commit 8807f02f36
6 changed files with 14 additions and 20 deletions

View File

@ -70,7 +70,7 @@ class Query(pydantic.BaseModel):
resp_messages: typing.Optional[list[llm_entities.Message]] = []
"""由Process阶段生成的回复消息对象列表"""
resp_message_chain: typing.Optional[mirai.MessageChain] = None
resp_message_chain: typing.Optional[list[mirai.MessageChain]] = None
"""回复消息链从resp_messages包装而得"""
class Config:

View File

@ -62,15 +62,15 @@ class LongTextProcessStage(stage.PipelineStage):
# 检查是否包含非 Plain 组件
contains_non_plain = False
for msg in query.resp_message_chain:
for msg in query.resp_message_chain[-1]:
if not isinstance(msg, Plain):
contains_non_plain = True
break
if contains_non_plain:
self.ap.logger.debug("消息中包含非 Plain 组件,跳过长消息处理。")
elif len(str(query.resp_message_chain)) > self.ap.platform_cfg.data['long-text-process']['threshold']:
query.resp_message_chain = MessageChain(await self.strategy_impl.process(str(query.resp_message_chain), query))
elif len(str(query.resp_message_chain[-1])) > self.ap.platform_cfg.data['long-text-process']['threshold']:
query.resp_message_chain[-1] = MessageChain(await self.strategy_impl.process(str(query.resp_message_chain[-1]), query))
return entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE,

View File

@ -43,7 +43,7 @@ class QueryPool:
message_event=message_event,
message_chain=message_chain,
resp_messages=[],
resp_message_chain=None,
resp_message_chain=[],
adapter=adapter
)
self.queries.append(query)

View File

@ -80,9 +80,6 @@ class CommandHandler(handler.MessageHandler):
session=session
):
if ret.error is not None:
# query.resp_message_chain = mirai.MessageChain([
# mirai.Plain(str(ret.error))
# ])
query.resp_messages.append(
llm_entities.Message(
role='command',
@ -97,9 +94,6 @@ class CommandHandler(handler.MessageHandler):
new_query=query
)
elif ret.text is not None:
# query.resp_message_chain = mirai.MessageChain([
# mirai.Plain(ret.text)
# ])
query.resp_messages.append(
llm_entities.Message(
role='command',

View File

@ -31,7 +31,7 @@ class SendResponseBackStage(stage.PipelineStage):
await self.ap.platform_mgr.send(
query.message_event,
query.resp_message_chain,
query.resp_message_chain[-1],
adapter=query.adapter
)

View File

@ -34,7 +34,7 @@ class ResponseWrapper(stage.PipelineStage):
"""
if query.resp_messages[-1].role == 'command':
query.resp_message_chain = mirai.MessageChain("[bot] "+query.resp_messages[-1].content)
query.resp_message_chain.append(mirai.MessageChain("[bot] "+query.resp_messages[-1].content))
yield entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE,
@ -42,9 +42,9 @@ class ResponseWrapper(stage.PipelineStage):
)
elif query.resp_messages[-1].role == 'plugin':
if not isinstance(query.resp_messages[-1].content, mirai.MessageChain):
query.resp_message_chain = mirai.MessageChain(query.resp_messages[-1].content)
query.resp_message_chain.append(mirai.MessageChain(query.resp_messages[-1].content))
else:
query.resp_message_chain = query.resp_messages[-1].content
query.resp_message_chain.append(query.resp_messages[-1].content)
yield entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE,
@ -83,11 +83,11 @@ class ResponseWrapper(stage.PipelineStage):
else:
if event_ctx.event.reply is not None:
query.resp_message_chain = mirai.MessageChain(event_ctx.event.reply)
query.resp_message_chain.append(mirai.MessageChain(event_ctx.event.reply))
else:
query.resp_message_chain = mirai.MessageChain([mirai.Plain(reply_text)])
query.resp_message_chain.append(mirai.MessageChain([mirai.Plain(reply_text)]))
yield entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE,
@ -100,7 +100,7 @@ class ResponseWrapper(stage.PipelineStage):
reply_text = f'调用函数 {".".join(function_names)}...'
query.resp_message_chain = mirai.MessageChain([mirai.Plain(reply_text)])
query.resp_message_chain.append(mirai.MessageChain([mirai.Plain(reply_text)]))
if self.ap.platform_cfg.data['track-function-calls']:
@ -126,11 +126,11 @@ class ResponseWrapper(stage.PipelineStage):
else:
if event_ctx.event.reply is not None:
query.resp_message_chain = mirai.MessageChain(event_ctx.event.reply)
query.resp_message_chain.append(mirai.MessageChain(event_ctx.event.reply))
else:
query.resp_message_chain = mirai.MessageChain([mirai.Plain(reply_text)])
query.resp_message_chain.append(mirai.MessageChain([mirai.Plain(reply_text)]))
yield entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE,