QChatGPT/pkg/openai/dprompt.py
2024-01-12 16:48:47 +08:00

135 lines
4.0 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# 多情景预设值管理
import json
import logging
import os
from ..utils import context
# __current__ = "default"
# """当前默认使用的情景预设的名称
# 由管理员使用`!default <名称>`命令切换
# """
# __prompts_from_files__ = {}
# """从文件中读取的情景预设值"""
# __scenario_from_files__ = {}
class ScenarioMode:
"""情景预设模式抽象类"""
using_prompt_name = "default"
"""新session创建时使用的prompt名称"""
prompts: dict[str, list] = {}
def __init__(self):
logging.debug("prompts: {}".format(self.prompts))
def list(self) -> dict[str, list]:
"""获取所有情景预设的名称及内容"""
return self.prompts
def get_prompt(self, name: str) -> tuple[list, str]:
"""获取指定情景预设的名称及内容"""
for key in self.prompts:
if key.startswith(name):
return self.prompts[key], key
raise Exception("没有找到情景预设: {}".format(name))
def set_using_name(self, name: str) -> str:
"""设置默认情景预设"""
for key in self.prompts:
if key.startswith(name):
self.using_prompt_name = key
return key
raise Exception("没有找到情景预设: {}".format(name))
def get_full_name(self, name: str) -> str:
"""获取完整的情景预设名称"""
for key in self.prompts:
if key.startswith(name):
return key
raise Exception("没有找到情景预设: {}".format(name))
def get_using_name(self) -> str:
"""获取默认情景预设"""
return self.using_prompt_name
class NormalScenarioMode(ScenarioMode):
"""普通情景预设模式"""
def __init__(self):
config = context.get_config_manager().data
# 加载config中的default_prompt值
if type(config['default_prompt']) == str:
self.using_prompt_name = "default"
self.prompts = {"default": [
{
"role": "system",
"content": config['default_prompt']
}
]}
elif type(config['default_prompt']) == dict:
for key in config['default_prompt']:
self.prompts[key] = [
{
"role": "system",
"content": config['default_prompt'][key]
}
]
# 从prompts/目录下的文件中载入
# 遍历文件
for file in os.listdir("prompts"):
with open(os.path.join("prompts", file), encoding="utf-8") as f:
self.prompts[file] = [
{
"role": "system",
"content": f.read()
}
]
class FullScenarioMode(ScenarioMode):
"""完整情景预设模式"""
def __init__(self):
"""从json读取所有"""
# 遍历scenario/目录下的所有文件以文件名为键文件内容中的prompt为值
for file in os.listdir("scenario"):
if file == "default-template.json":
continue
with open(os.path.join("scenario", file), encoding="utf-8") as f:
self.prompts[file] = json.load(f)["prompt"]
super().__init__()
scenario_mode_mapping = {}
"""情景预设模式名称与对象的映射"""
def register_all():
"""注册所有情景预设模式,不使用装饰器,因为装饰器的方式不支持热重载"""
global scenario_mode_mapping
scenario_mode_mapping = {
"normal": NormalScenarioMode(),
"full_scenario": FullScenarioMode()
}
def mode_inst() -> ScenarioMode:
"""获取指定名称的情景预设模式对象"""
config = context.get_config_manager().data
if config['preset_mode'] == "default":
config['preset_mode'] = "normal"
return scenario_mode_mapping[config['preset_mode']]