2024-01-27 00:06:38 +08:00
|
|
|
from __future__ import annotations
|
|
|
|
|
|
|
|
from ...core import app
|
|
|
|
from . import loader
|
|
|
|
from .loaders import single, scenario
|
|
|
|
|
|
|
|
|
|
|
|
class PromptManager:
    """Prompt manager.

    Selects a prompt loader according to the ``prompt-mode`` provider
    setting, initializes it, loads prompts through it, and offers lookup
    helpers over the loaded prompts.
    """

    # Owning application; ``initialize()`` reads ``ap.provider_cfg`` from it.
    ap: app.Application

    # Active loader instance; only assigned by ``initialize()``.
    loader_inst: loader.PromptLoader

    # NOTE(review): presumably the name of the fallback prompt; it is not
    # read anywhere in this class — confirm against callers.
    default_prompt: str = 'default'

    def __init__(self, ap: app.Application):
        self.ap = ap

    async def initialize(self):
        """Select, construct, initialize and load the configured prompt loader.

        The loader class is the first entry of ``loader.preregistered_loaders``
        whose ``name`` equals ``provider_cfg.data['prompt-mode']``.

        Raises:
            ValueError: if no preregistered loader has the configured name.
        """
        mode_name = self.ap.provider_cfg.data['prompt-mode']

        loader_class = next(
            (
                loader_cls
                for loader_cls in loader.preregistered_loaders
                if loader_cls.name == mode_name
            ),
            None,
        )
        if loader_class is None:
            raise ValueError(f'未知的 Prompt 加载器: {mode_name}')

        self.loader_inst: loader.PromptLoader = loader_class(self.ap)

        # Two-phase startup: let the loader prepare itself, then load prompts.
        await self.loader_inst.initialize()
        await self.loader_inst.load()

    def get_all_prompts(self) -> list[loader.entities.Prompt]:
        """Return every prompt held by the active loader.

        Must be called after ``initialize()``; before that, ``loader_inst``
        has not been assigned on the instance.
        """
        return self.loader_inst.get_prompts()

    async def get_prompt(self, name: str) -> loader.entities.Prompt | None:
        """Look up a prompt by its exact name.

        Args:
            name: the prompt name to match exactly.

        Returns:
            The first prompt whose ``name`` equals ``name``, or ``None`` if
            no prompt matches.
        """
        return next(
            (prompt for prompt in self.get_all_prompts() if prompt.name == name),
            None,
        )

    async def get_prompt_by_prefix(self, prefix: str) -> loader.entities.Prompt | None:
        """Look up a prompt whose name starts with ``prefix``.

        Args:
            prefix: leading substring to match against prompt names. An empty
                prefix matches the first prompt.

        Returns:
            The first prompt (in loader order) whose ``name`` starts with
            ``prefix``, or ``None`` if no prompt matches.
        """
        return next(
            (
                prompt
                for prompt in self.get_all_prompts()
                if prompt.name.startswith(prefix)
            ),
            None,
        )
|