feat: 限速算法的扩展性

This commit is contained in:
Junyan Qin 2024-03-12 16:31:54 +00:00
parent b9fa11c0c3
commit 13393b6624
4 changed files with 33 additions and 3 deletions

View File

@ -1,11 +1,26 @@
from __future__ import annotations
import abc
import typing
from ...core import app
preregistered_algos: list[typing.Type[ReteLimitAlgo]] = []
def algo_class(name: str):
def decorator(cls: typing.Type[ReteLimitAlgo]) -> typing.Type[ReteLimitAlgo]:
cls.name = name
preregistered_algos.append(cls)
return cls
return decorator
class ReteLimitAlgo(metaclass=abc.ABCMeta):
name: str = None
ap: app.Application
def __init__(self, ap: app.Application):

View File

@ -19,6 +19,7 @@ class SessionContainer:
self.records = {}
@algo.algo_class("fixwin")
class FixedWindowAlgo(algo.ReteLimitAlgo):
containers_lock: asyncio.Lock

View File

@ -16,7 +16,19 @@ class RateLimit(stage.PipelineStage):
algo: algo.ReteLimitAlgo
async def initialize(self):
self.algo = fixedwin.FixedWindowAlgo(self.ap)
algo_name = self.ap.pipeline_cfg.data['rate-limit']['algo']
algo_class = None
for algo_cls in algo.preregistered_algos:
if algo_cls.name == algo_name:
algo_class = algo_cls
break
else:
raise ValueError(f'未知的限速算法: {algo_name}')
self.algo = algo_class(self.ap)
await self.algo.initialize()
async def process(

View File

@ -22,14 +22,16 @@ class PromptManager:
mode_name = self.ap.provider_cfg.data['prompt-mode']
loader_class = None
for loader_cls in loader.preregistered_loaders:
if loader_cls.name == mode_name:
loader_cls = loader_cls
loader_class = loader_cls
break
else:
raise ValueError(f'未知的 Prompt 加载器: {mode_name}')
self.loader_inst: loader.PromptLoader = loader_cls(self.ap)
self.loader_inst: loader.PromptLoader = loader_class(self.ap)
await self.loader_inst.initialize()
await self.loader_inst.load()