mirror of
https://github.com/RockChinQ/QChatGPT.git
synced 2024-11-16 03:32:33 +08:00
feat: 限速算法的扩展性
This commit is contained in:
parent
b9fa11c0c3
commit
13393b6624
|
@ -1,11 +1,26 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
import abc
|
import abc
|
||||||
|
import typing
|
||||||
|
|
||||||
from ...core import app
|
from ...core import app
|
||||||
|
|
||||||
|
|
||||||
|
preregistered_algos: list[typing.Type[ReteLimitAlgo]] = []
|
||||||
|
|
||||||
|
def algo_class(name: str):
|
||||||
|
|
||||||
|
def decorator(cls: typing.Type[ReteLimitAlgo]) -> typing.Type[ReteLimitAlgo]:
|
||||||
|
cls.name = name
|
||||||
|
preregistered_algos.append(cls)
|
||||||
|
return cls
|
||||||
|
|
||||||
|
return decorator
|
||||||
|
|
||||||
|
|
||||||
class ReteLimitAlgo(metaclass=abc.ABCMeta):
|
class ReteLimitAlgo(metaclass=abc.ABCMeta):
|
||||||
|
|
||||||
|
name: str = None
|
||||||
|
|
||||||
ap: app.Application
|
ap: app.Application
|
||||||
|
|
||||||
def __init__(self, ap: app.Application):
|
def __init__(self, ap: app.Application):
|
||||||
|
|
|
@ -19,6 +19,7 @@ class SessionContainer:
|
||||||
self.records = {}
|
self.records = {}
|
||||||
|
|
||||||
|
|
||||||
|
@algo.algo_class("fixwin")
|
||||||
class FixedWindowAlgo(algo.ReteLimitAlgo):
|
class FixedWindowAlgo(algo.ReteLimitAlgo):
|
||||||
|
|
||||||
containers_lock: asyncio.Lock
|
containers_lock: asyncio.Lock
|
||||||
|
|
|
@ -16,7 +16,19 @@ class RateLimit(stage.PipelineStage):
|
||||||
algo: algo.ReteLimitAlgo
|
algo: algo.ReteLimitAlgo
|
||||||
|
|
||||||
async def initialize(self):
|
async def initialize(self):
|
||||||
self.algo = fixedwin.FixedWindowAlgo(self.ap)
|
|
||||||
|
algo_name = self.ap.pipeline_cfg.data['rate-limit']['algo']
|
||||||
|
|
||||||
|
algo_class = None
|
||||||
|
|
||||||
|
for algo_cls in algo.preregistered_algos:
|
||||||
|
if algo_cls.name == algo_name:
|
||||||
|
algo_class = algo_cls
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
raise ValueError(f'未知的限速算法: {algo_name}')
|
||||||
|
|
||||||
|
self.algo = algo_class(self.ap)
|
||||||
await self.algo.initialize()
|
await self.algo.initialize()
|
||||||
|
|
||||||
async def process(
|
async def process(
|
||||||
|
|
|
@ -22,14 +22,16 @@ class PromptManager:
|
||||||
|
|
||||||
mode_name = self.ap.provider_cfg.data['prompt-mode']
|
mode_name = self.ap.provider_cfg.data['prompt-mode']
|
||||||
|
|
||||||
|
loader_class = None
|
||||||
|
|
||||||
for loader_cls in loader.preregistered_loaders:
|
for loader_cls in loader.preregistered_loaders:
|
||||||
if loader_cls.name == mode_name:
|
if loader_cls.name == mode_name:
|
||||||
loader_cls = loader_cls
|
loader_class = loader_cls
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
raise ValueError(f'未知的 Prompt 加载器: {mode_name}')
|
raise ValueError(f'未知的 Prompt 加载器: {mode_name}')
|
||||||
|
|
||||||
self.loader_inst: loader.PromptLoader = loader_cls(self.ap)
|
self.loader_inst: loader.PromptLoader = loader_class(self.ap)
|
||||||
|
|
||||||
await self.loader_inst.initialize()
|
await self.loader_inst.initialize()
|
||||||
await self.loader_inst.load()
|
await self.loader_inst.load()
|
||||||
|
|
Loading…
Reference in New Issue
Block a user