mirror of
https://github.com/RockChinQ/QChatGPT.git
synced 2024-11-16 03:32:33 +08:00
feat: message.content 支持 mirai.MessageChain 对象 (#741)
This commit is contained in:
parent
2e9229a6ad
commit
8b003739f1
2
.gitignore
vendored
2
.gitignore
vendored
|
@ -34,4 +34,4 @@ bard.json
|
||||||
res/instance_id.json
|
res/instance_id.json
|
||||||
.DS_Store
|
.DS_Store
|
||||||
/data
|
/data
|
||||||
botpy.log
|
botpy.log*
|
|
@ -135,9 +135,17 @@ class ContentFilterStage(stage.PipelineStage):
|
||||||
query
|
query
|
||||||
)
|
)
|
||||||
elif stage_inst_name == 'PostContentFilterStage':
|
elif stage_inst_name == 'PostContentFilterStage':
|
||||||
|
# 仅处理 query.resp_messages[-1].content 是 str 的情况
|
||||||
|
if isinstance(query.resp_messages[-1].content, str):
|
||||||
return await self._post_process(
|
return await self._post_process(
|
||||||
query.resp_messages[-1].content,
|
query.resp_messages[-1].content,
|
||||||
query
|
query
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
self.ap.logger.debug(f"resp_messages[-1] 不是 str 类型,跳过内容过滤器检查。")
|
||||||
|
return entities.StageProcessResult(
|
||||||
|
result_type=entities.ResultType.CONTINUE,
|
||||||
|
new_query=query
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f'未知的 stage_inst_name: {stage_inst_name}')
|
raise ValueError(f'未知的 stage_inst_name: {stage_inst_name}')
|
||||||
|
|
|
@ -56,8 +56,19 @@ class LongTextProcessStage(stage.PipelineStage):
|
||||||
await self.strategy_impl.initialize()
|
await self.strategy_impl.initialize()
|
||||||
|
|
||||||
async def process(self, query: core_entities.Query, stage_inst_name: str) -> entities.StageProcessResult:
|
async def process(self, query: core_entities.Query, stage_inst_name: str) -> entities.StageProcessResult:
|
||||||
if len(str(query.resp_message_chain)) > self.ap.platform_cfg.data['long-text-process']['threshold']:
|
# 检查是否包含非 Plain 组件
|
||||||
|
contains_non_plain = False
|
||||||
|
|
||||||
|
for msg in query.resp_message_chain:
|
||||||
|
if not isinstance(msg, Plain):
|
||||||
|
contains_non_plain = True
|
||||||
|
break
|
||||||
|
|
||||||
|
if contains_non_plain:
|
||||||
|
self.ap.logger.debug("消息中包含非 Plain 组件,跳过长消息处理。")
|
||||||
|
elif len(str(query.resp_message_chain)) > self.ap.platform_cfg.data['long-text-process']['threshold']:
|
||||||
query.resp_message_chain = MessageChain(await self.strategy_impl.process(str(query.resp_message_chain), query))
|
query.resp_message_chain = MessageChain(await self.strategy_impl.process(str(query.resp_message_chain), query))
|
||||||
|
|
||||||
return entities.StageProcessResult(
|
return entities.StageProcessResult(
|
||||||
result_type=entities.ResultType.CONTINUE,
|
result_type=entities.ResultType.CONTINUE,
|
||||||
new_query=query
|
new_query=query
|
||||||
|
|
|
@ -44,7 +44,7 @@ class ChatMessageHandler(handler.MessageHandler):
|
||||||
query.resp_messages.append(
|
query.resp_messages.append(
|
||||||
llm_entities.Message(
|
llm_entities.Message(
|
||||||
role='plugin',
|
role='plugin',
|
||||||
content=str(mc),
|
content=mc,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -34,7 +34,10 @@ class ResponseWrapper(stage.PipelineStage):
|
||||||
new_query=query
|
new_query=query
|
||||||
)
|
)
|
||||||
elif query.resp_messages[-1].role == 'plugin':
|
elif query.resp_messages[-1].role == 'plugin':
|
||||||
|
if not isinstance(query.resp_messages[-1].content, mirai.MessageChain):
|
||||||
query.resp_message_chain = mirai.MessageChain(query.resp_messages[-1].content)
|
query.resp_message_chain = mirai.MessageChain(query.resp_messages[-1].content)
|
||||||
|
else:
|
||||||
|
query.resp_message_chain = query.resp_messages[-1].content
|
||||||
|
|
||||||
yield entities.StageProcessResult(
|
yield entities.StageProcessResult(
|
||||||
result_type=entities.ResultType.CONTINUE,
|
result_type=entities.ResultType.CONTINUE,
|
||||||
|
|
|
@ -4,6 +4,8 @@ import typing
|
||||||
import enum
|
import enum
|
||||||
import pydantic
|
import pydantic
|
||||||
|
|
||||||
|
import mirai
|
||||||
|
|
||||||
|
|
||||||
class FunctionCall(pydantic.BaseModel):
|
class FunctionCall(pydantic.BaseModel):
|
||||||
name: str
|
name: str
|
||||||
|
@ -28,7 +30,7 @@ class Message(pydantic.BaseModel):
|
||||||
name: typing.Optional[str] = None
|
name: typing.Optional[str] = None
|
||||||
"""名称,仅函数调用返回时设置"""
|
"""名称,仅函数调用返回时设置"""
|
||||||
|
|
||||||
content: typing.Optional[str] = None
|
content: typing.Optional[str] | typing.Optional[mirai.MessageChain] = None
|
||||||
"""内容"""
|
"""内容"""
|
||||||
|
|
||||||
function_call: typing.Optional[FunctionCall] = None
|
function_call: typing.Optional[FunctionCall] = None
|
||||||
|
@ -41,7 +43,7 @@ class Message(pydantic.BaseModel):
|
||||||
|
|
||||||
def readable_str(self) -> str:
|
def readable_str(self) -> str:
|
||||||
if self.content is not None:
|
if self.content is not None:
|
||||||
return self.content
|
return str(self.content)
|
||||||
elif self.function_call is not None:
|
elif self.function_call is not None:
|
||||||
return f'{self.function_call.name}({self.function_call.arguments})'
|
return f'{self.function_call.name}({self.function_call.arguments})'
|
||||||
elif self.tool_calls is not None:
|
elif self.tool_calls is not None:
|
||||||
|
|
Loading…
Reference in New Issue
Block a user