QChatGPT/pkg/provider/sysprompt/sysprompt.py


from __future__ import annotations

from ...core import app
from . import loader

# Imported for their side effects: loading these modules presumably registers
# the 'single' and 'scenario' loader classes into loader.preregistered_loaders.
from .loaders import single, scenario


class PromptManager:
    """Prompt manager"""

    ap: app.Application

    loader_inst: loader.PromptLoader

    default_prompt: str = 'default'

    def __init__(self, ap: app.Application):
        self.ap = ap

    async def initialize(self):
        mode_name = self.ap.provider_cfg.data['prompt-mode']

        # Pick the loader class whose registered name matches the configured mode.
        loader_class = None

        for loader_cls in loader.preregistered_loaders:
            if loader_cls.name == mode_name:
                loader_class = loader_cls
                break
        else:
            raise ValueError(f'Unknown prompt loader: {mode_name}')

        self.loader_inst: loader.PromptLoader = loader_class(self.ap)
        await self.loader_inst.initialize()
        await self.loader_inst.load()

    def get_all_prompts(self) -> list[loader.entities.Prompt]:
        """Get all prompts."""
        return self.loader_inst.get_prompts()

    async def get_prompt(self, name: str) -> loader.entities.Prompt:
        """Get a prompt by its exact name.

        Returns None if no prompt with this name exists.
        """
        for prompt in self.get_all_prompts():
            if prompt.name == name:
                return prompt

    async def get_prompt_by_prefix(self, prefix: str) -> loader.entities.Prompt:
        """Get the first prompt whose name starts with the given prefix.

        Returns None if no prompt name matches the prefix.
        """
        for prompt in self.get_all_prompts():
            if prompt.name.startswith(prefix):
                return prompt
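
# Minimal usage sketch (not part of this module), assuming `ap` is an already
# initialized app.Application whose provider_cfg.data['prompt-mode'] names one
# of the preregistered loaders; the prompt names below are hypothetical:
#
#     mgr = PromptManager(ap)
#     await mgr.initialize()                  # select and load the configured loader
#     prompt = await mgr.get_prompt(mgr.default_prompt)
#     if prompt is None:
#         prompt = await mgr.get_prompt_by_prefix('def')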