mirror of
https://github.com/RockChinQ/QChatGPT.git
synced 2024-11-16 03:32:33 +08:00
feat: 消息截断器
This commit is contained in:
parent
21fe5822f9
commit
c8eb2e3376
24
pkg/config/migrations/m009_msg_truncator_cfg.py
Normal file
24
pkg/config/migrations/m009_msg_truncator_cfg.py
Normal file
|
@ -0,0 +1,24 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from .. import migration
|
||||
|
||||
|
||||
@migration.migration_class("msg-truncator-cfg-migration", 9)
class MsgTruncatorConfigMigration(migration.Migration):
    """Migration that adds the ``msg-truncate`` section to the pipeline config."""

    async def need_migrate(self) -> bool:
        """Decide whether this migration has to run in the current environment.

        Returns:
            bool: True when the pipeline config has no ``msg-truncate`` key.
        """
        pipeline_data = self.ap.pipeline_cfg.data
        return not ('msg-truncate' in pipeline_data)

    async def run(self):
        """Run the migration.

        Writes the default ``msg-truncate`` settings into the pipeline config
        and then dumps the config through ``dump_config()``.
        """
        pipeline_cfg = self.ap.pipeline_cfg

        # Default: round-based truncation keeping at most 10 rounds.
        round_settings = {'max-round': 10}
        pipeline_cfg.data['msg-truncate'] = {
            'method': 'round',
            'round': round_settings,
        }

        await pipeline_cfg.dump_config()
|
|
@ -5,7 +5,7 @@ import importlib
|
|||
from .. import stage, app
|
||||
from ...config import migration
|
||||
from ...config.migrations import m001_sensitive_word_migration, m002_openai_config_migration, m003_anthropic_requester_cfg_completion, m004_moonshot_cfg_completion
|
||||
from ...config.migrations import m005_deepseek_cfg_completion, m006_vision_config, m007_qcg_center_url, m008_ad_fixwin_config_migrate
|
||||
from ...config.migrations import m005_deepseek_cfg_completion, m006_vision_config, m007_qcg_center_url, m008_ad_fixwin_config_migrate, m009_msg_truncator_cfg
|
||||
|
||||
|
||||
@stage.stage_class("MigrationStage")
|
||||
|
|
0
pkg/pipeline/msgtrun/__init__.py
Normal file
0
pkg/pipeline/msgtrun/__init__.py
Normal file
35
pkg/pipeline/msgtrun/msgtrun.py
Normal file
35
pkg/pipeline/msgtrun/msgtrun.py
Normal file
|
@ -0,0 +1,35 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from .. import stage, entities, stagemgr
|
||||
from ...core import entities as core_entities
|
||||
from . import truncator
|
||||
from .truncators import round
|
||||
|
||||
|
||||
@stage.stage_class("ConversationMessageTruncator")
class ConversationMessageTruncator(stage.PipelineStage):
    """Conversation message truncator stage.

    Truncates the conversation message chain so that it fits the platform's
    message length limits. The actual truncation strategy is delegated to a
    ``truncator.Truncator`` implementation selected by configuration.
    """

    # Truncator instance selected in initialize().
    trun: truncator.Truncator

    async def initialize(self):
        """Select, instantiate and initialize the configured truncator.

        The method name is read from ``msg-truncate.method`` in the pipeline
        config and matched against ``truncator.preregistered_truncators``.

        Raises:
            ValueError: if no registered truncator has the configured name.
        """
        use_method = self.ap.pipeline_cfg.data['msg-truncate']['method']

        for trun_cls in truncator.preregistered_truncators:
            if trun_cls.name == use_method:
                self.trun = trun_cls(self.ap)
                break
        else:
            # Loop finished without a match: the configured name is unknown.
            raise ValueError(f"未知的截断器: {use_method}")

        # Truncator exposes an async initialize() hook; run it so that
        # truncators overriding it are fully set up before process() is used.
        await self.trun.initialize()

    async def process(self, query: core_entities.Query, stage_inst_name: str) -> entities.StageProcessResult:
        """Truncate the query's messages and let the pipeline continue.

        Args:
            query: the query being processed; passed to the truncator.
            stage_inst_name: name of this stage instance (not used here).

        Returns:
            entities.StageProcessResult: a CONTINUE result carrying the query
            returned by the truncator as ``new_query``.
        """
        query = await self.trun.truncate(query)

        return entities.StageProcessResult(
            result_type=entities.ResultType.CONTINUE,
            new_query=query
        )
|
56
pkg/pipeline/msgtrun/truncator.py
Normal file
56
pkg/pipeline/msgtrun/truncator.py
Normal file
|
@ -0,0 +1,56 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import typing
|
||||
import abc
|
||||
|
||||
from ...core import entities as core_entities, app
|
||||
|
||||
|
||||
# Registry of truncator classes; the truncator_class decorator appends to it
# and the truncation stage scans it to find the truncator matching a name.
preregistered_truncators: list[typing.Type[Truncator]] = []
|
||||
|
||||
|
||||
def truncator_class(
    name: str
) -> typing.Callable[[typing.Type[Truncator]], typing.Type[Truncator]]:
    """Class decorator factory that registers a truncator under a name.

    Args:
        name (str): name of the truncator.

    Returns:
        typing.Callable[[typing.Type[Truncator]], typing.Type[Truncator]]:
            a decorator that sets ``name`` on the class, appends the class to
            ``preregistered_truncators`` and returns the class unchanged.
    """

    def register(truncator_cls: typing.Type[Truncator]) -> typing.Type[Truncator]:
        # Only Truncator subclasses may be registered.
        assert issubclass(truncator_cls, Truncator)

        preregistered_truncators.append(truncator_cls)
        truncator_cls.name = name
        return truncator_cls

    return register
|
||||
|
||||
|
||||
class Truncator(abc.ABC):
    """Abstract base class for message truncators."""

    # Name of the truncator; assigned by the truncator_class decorator.
    name: str

    # Application object handed to the constructor.
    ap: app.Application

    def __init__(self, ap: app.Application):
        """Store the application object on the instance."""
        self.ap = ap

    async def initialize(self):
        """Initialization hook; the base implementation does nothing."""
        return None

    @abc.abstractmethod
    async def truncate(self, query: core_entities.Query) -> core_entities.Query:
        """Truncate the query.

        Usually only ``query.messages`` needs to be manipulated;
        ``query.prompt`` and ``query.user_message`` may be manipulated as well.
        Do not touch any other field.
        """
|
0
pkg/pipeline/msgtrun/truncators/__init__.py
Normal file
0
pkg/pipeline/msgtrun/truncators/__init__.py
Normal file
32
pkg/pipeline/msgtrun/truncators/round.py
Normal file
32
pkg/pipeline/msgtrun/truncators/round.py
Normal file
|
@ -0,0 +1,32 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from .. import truncator
|
||||
from ....core import entities as core_entities
|
||||
|
||||
|
||||
@truncator.truncator_class("round")
class RoundTruncator(truncator.Truncator):
    """Truncator that limits the history by number of conversation rounds."""

    async def truncate(self, query: core_entities.Query) -> core_entities.Query:
        """Keep only the most recent rounds of ``query.messages``.

        The history is walked from the newest message to the oldest. Every
        message whose ``role`` is ``'user'`` counts as one round; once the
        number of counted rounds reaches ``msg-truncate.round.max-round`` from
        the pipeline config, all older messages are dropped.

        Returns:
            The same query object, with ``messages`` replaced by the kept
            messages in their original order.
        """
        round_limit = self.ap.pipeline_cfg.data['msg-truncate']['round']['max-round']

        kept = []
        rounds_seen = 0

        # Walk backwards: newest message first.
        for message in reversed(query.messages):
            if rounds_seen >= round_limit:
                break

            kept.append(message)
            if message.role == 'user':
                rounds_seen += 1

        # Messages were collected newest-first; restore chronological order.
        kept.reverse()
        query.messages = kept

        return query
|
|
@ -13,6 +13,7 @@ from .respback import respback
|
|||
from .wrapper import wrapper
|
||||
from .preproc import preproc
|
||||
from .ratelimit import ratelimit
|
||||
from .msgtrun import msgtrun
|
||||
|
||||
|
||||
# 请求处理阶段顺序
|
||||
|
@ -21,6 +22,7 @@ stage_order = [
|
|||
"BanSessionCheckStage", # 封禁会话检查
|
||||
"PreContentFilterStage", # 内容过滤前置阶段
|
||||
"PreProcessor", # 预处理器
|
||||
"ConversationMessageTruncator", # 会话消息截断器
|
||||
"RequireRateLimitOccupancy", # 请求速率限制占用
|
||||
"MessageProcessor", # 处理器
|
||||
"ReleaseRateLimitOccupancy", # 释放速率限制占用
|
||||
|
|
|
@ -34,5 +34,11 @@
|
|||
"limit": 60
|
||||
}
|
||||
}
|
||||
},
|
||||
"msg-truncate": {
|
||||
"method": "round",
|
||||
"round": {
|
||||
"max-round": 10
|
||||
}
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user