QChatGPT/pkg/pipeline/stagemgr.py

71 lines
1.6 KiB
Python
Raw Normal View History

from __future__ import annotations
import pydantic
from ..core import app
from . import stage
from .resprule import resprule
from .bansess import bansess
from .cntfilter import cntfilter
from .process import process
from .longtext import longtext
from .respback import respback
from .wrapper import wrapper
2024-02-01 16:35:00 +08:00
from .preproc import preproc
2024-02-01 18:38:20 +08:00
from .ratelimit import ratelimit
# Execution order of the pipeline stages, by registered stage name.
# StageManager sorts its stage containers by each name's position here.
stage_order = [
    # Checks and filtering applied before the message is processed
    "GroupRespondRuleCheckStage",
    "BanSessionCheckStage",
    "PreContentFilterStage",
    "PreProcessor",
    # Processing, wrapped by the Require/Release rate-limit occupancy pair
    "RequireRateLimitOccupancy",
    "MessageProcessor",
    "ReleaseRateLimitOccupancy",
    # Post-processing and delivery of the response
    "PostContentFilterStage",
    "ResponseWrapper",
    "LongTextProcessStage",
    "SendResponseBackStage",
]
class StageInstContainer:
    """Holder that pairs a pipeline stage instance with its registered name."""

    # Registered name of the stage; StageManager orders containers by it.
    inst_name: str

    # The constructed stage object itself.
    inst: stage.PipelineStage

    def __init__(self, inst_name: str, inst: stage.PipelineStage):
        self.inst = inst
        self.inst_name = inst_name
class StageManager:
    """Creates, initializes and orders the pipeline stage instances."""

    # Application object passed to every stage constructor.
    ap: app.Application

    # One container per registered stage; sorted by stage_order after initialize().
    stage_containers: list[StageInstContainer]

    def __init__(self, ap: app.Application):
        self.ap = ap
        self.stage_containers = []

    async def initialize(self):
        """Instantiate, initialize and order every registered stage.

        A StageInstContainer is created for each entry of
        ``stage._stage_classes`` (each class is constructed with ``self.ap``).
        Each stage's ``initialize()`` coroutine is then awaited in
        registration order. Finally, the containers are sorted by their
        position in the module-level ``stage_order`` list.

        Raises:
            ValueError: if a registered stage name is not listed in
                ``stage_order``. Its place in the pipeline would then be
                undefined.
        """
        for name, cls in stage._stage_classes.items():
            self.stage_containers.append(
                StageInstContainer(inst_name=name, inst=cls(self.ap))
            )

        # Pipeline position of each stage name, computed once instead of a
        # linear stage_order.index() scan for every sort key.
        position = {name: idx for idx, name in enumerate(stage_order)}

        # Fail fast with a descriptive message before running any stage setup.
        # Otherwise stage_order.index() would raise an opaque
        # "'X' is not in list" only after every stage had been initialized.
        unknown = [
            container.inst_name
            for container in self.stage_containers
            if container.inst_name not in position
        ]
        if unknown:
            raise ValueError(
                "stages not listed in stage_order: " + ", ".join(unknown)
            )

        for container in self.stage_containers:
            await container.inst.initialize()

        # Order containers by pipeline position (see stage_order).
        self.stage_containers.sort(key=lambda c: position[c.inst_name])