2023-02-19 11:46:12 +08:00
|
|
|
# 多情景预设值管理
|
|
|
|
|
|
|
|
# Name of the currently selected scenario preset; updated by set_current()
# and set_to_default().
__current__ = "default"
|
|
|
|
|
2023-02-25 17:05:39 +08:00
|
|
|
# Presets loaded from the prompts/ directory, keyed by file name with the
# file's text as value; (re)filled by read_prompt_from_file().
__prompts_from_files__ = {}
|
|
|
|
|
|
|
|
|
|
|
|
def read_prompt_from_file() -> None:
    """Load every preset file in ``prompts/`` into ``__prompts_from_files__``.

    Each regular file directly inside the ``prompts`` directory (resolved
    relative to the current working directory) is read as UTF-8 text and
    stored with its file name as the key and its content as the value.
    Entries are inserted in sorted file-name order. The previous contents of
    ``__prompts_from_files__`` are replaced only after every file was read
    successfully.

    Raises:
        FileNotFoundError: if the ``prompts`` directory does not exist.
        UnicodeDecodeError: if a file is not valid UTF-8.
    """
    global __prompts_from_files__

    import os

    prompts_dir = "prompts"
    loaded = {}
    # os.listdir() returns entries in arbitrary order; sorting makes the
    # dict's insertion order deterministic, which matters because preset
    # lookups in this module return the first key matching a prefix.
    for file in sorted(os.listdir(prompts_dir)):
        path = os.path.join(prompts_dir, file)
        # Sub-directories and other non-regular entries cannot be read as
        # text; skip them instead of failing with IsADirectoryError.
        if not os.path.isfile(path):
            continue
        with open(path, encoding="utf-8") as f:
            loaded[file] = f.read()

    # Publish in a single assignment so an error above leaves the previous
    # presets intact rather than a partially filled dict.
    __prompts_from_files__ = loaded
|
|
|
|
|
|
|
|
|
2023-02-19 11:46:12 +08:00
|
|
|
def get_prompt_dict() -> dict:
    """Return all available scenario presets as a new ``{name: prompt}`` dict.

    Presets come from ``config.default_prompt``, which may be a single
    string (stored under the key ``"default"``) or a dict of named presets.
    They are then overlaid with the presets held in
    ``__prompts_from_files__``. A file whose name equals a config key
    overrides that entry.

    Returns:
        A fresh dict. Mutating it, and the merge itself, never touch
        ``config.default_prompt``.

    Raises:
        TypeError: if ``config.default_prompt`` is neither a str nor a dict.
    """
    import config

    default_prompt = config.default_prompt
    if isinstance(default_prompt, str):
        prompts = {"default": default_prompt}
    elif isinstance(default_prompt, dict):
        # Copy before merging: updating config.default_prompt in place would
        # permanently inject file presets into the configuration, so presets
        # from deleted or renamed files would survive a reload.
        prompts = dict(default_prompt)
    else:
        raise TypeError("default_prompt must be str or dict")

    # Merge the file-based presets on top of the configured ones.
    prompts.update(__prompts_from_files__)
    return prompts
|
|
|
|
|
2023-02-19 11:46:12 +08:00
|
|
|
|
|
|
|
def set_current(name):
    """Select the active scenario preset.

    The key comparison is case-insensitive. An exact key match wins.
    Otherwise the first key of ``get_prompt_dict()`` that starts with
    ``name`` is chosen. The full matching key is stored as the current
    preset.

    Args:
        name: full name or prefix of the preset to select.

    Raises:
        KeyError: if no preset name matches ``name``.
    """
    global __current__

    wanted = name.lower()
    keys = list(get_prompt_dict())

    # Prefer an exact match so that e.g. "default" selects "default" even
    # when a key such as "default2" comes earlier in iteration order.
    for key in keys:
        if key.lower() == wanted:
            __current__ = key
            return

    # Fall back to prefix matching, which allows abbreviated names.
    for key in keys:
        if key.lower().startswith(wanted):
            __current__ = key
            return

    raise KeyError("未找到情景预设: " + name)
|
|
|
|
|
|
|
|
|
|
|
|
def get_current():
    """Return the name of the currently selected scenario preset."""
    # Reading a module-level name needs no ``global`` declaration.
    selected = __current__
    return selected
|
|
|
|
|
|
|
|
|
|
|
|
def set_to_default():
    """Reset the active scenario preset to the default one.

    Selects the ``"default"`` key when ``get_prompt_dict()`` contains it.
    Otherwise it falls back to the dict's first key. An empty preset dict
    raises IndexError.
    """
    global __current__

    presets = get_prompt_dict()
    __current__ = "default" if "default" in presets else list(presets)[0]
|
|
|
|
|
|
|
|
|
|
|
|
def get_prompt(name: str = None) -> str:
    """Return the prompt text of a scenario preset.

    Args:
        name: full name or prefix of the preset. ``None`` means the
            currently selected preset (``get_current()``). The comparison is
            case-insensitive. An exact key match is preferred; otherwise the
            first key of ``get_prompt_dict()`` starting with ``name`` is used.

    Returns:
        The prompt text stored for the matching preset.

    Raises:
        KeyError: if no preset name matches ``name``.
    """
    if name is None:
        name = get_current()

    prompts = get_prompt_dict()
    wanted = name.lower()

    # Exact match first: the current preset is usually a full key, and a
    # plain prefix scan could return a different preset whose name merely
    # starts with it (e.g. "default2" when asked for "default").
    for key, prompt in prompts.items():
        if key.lower() == wanted:
            return prompt

    # Fall back to prefix matching, which allows abbreviated names.
    for key, prompt in prompts.items():
        if key.lower().startswith(wanted):
            return prompt

    raise KeyError("未找到情景预设: " + name)
|