QChatGPT/pkg/openai/dprompt.py

122 lines
3.2 KiB
Python
Raw Normal View History

2023-02-19 11:46:12 +08:00
# 多情景预设值管理
2023-03-10 23:14:32 +08:00
import json
import logging
2023-02-19 11:46:12 +08:00
# Name of the scenario preset currently in use.
# Administrators switch it with the `!default <name>` command.
__current__ = "default"

# Plain-text scenario presets read from files, keyed by file name.
__prompts_from_files__ = {}

# JSON scenario presets read from files, keyed by file name.
__scenario_from_files__ = {}
def read_prompt_from_file(directory: str = "prompts"):
    """Load plain-text scenario presets from files.

    Every regular file directly inside *directory* becomes one preset: the
    file name is the key and the file's UTF-8 text is the value. The loaded
    mapping replaces ``__prompts_from_files__``.

    :param directory: folder to scan; defaults to ``"prompts"``.
    :raises FileNotFoundError: if *directory* does not exist.
    """
    global __prompts_from_files__
    import os

    loaded = {}
    # os.listdir returns entries in arbitrary order; sort so the preset order
    # (and therefore which preset a name-prefix lookup picks) is deterministic.
    for file in sorted(os.listdir(directory)):
        path = os.path.join(directory, file)
        # Skip sub-directories and other non-regular entries instead of
        # crashing on open().
        if not os.path.isfile(path):
            continue
        with open(path, encoding="utf-8") as f:
            loaded[file] = f.read()
    # Publish only after every file was read successfully.
    __prompts_from_files__ = loaded
def read_scenario_from_file(directory: str = "scenario"):
    """Load full JSON scenario presets from files.

    Every regular file directly inside *directory* (except the
    ``default-template.json`` template) is parsed as JSON and stored under its
    file name. The loaded mapping replaces ``__scenario_from_files__``.

    :param directory: folder to scan; defaults to ``"scenario"``.
    :raises FileNotFoundError: if *directory* does not exist.
    :raises json.JSONDecodeError: if a file does not contain valid JSON.
    """
    global __scenario_from_files__
    import os

    loaded = {}
    # Sorted for a deterministic preset order; os.listdir order is arbitrary.
    for file in sorted(os.listdir(directory)):
        # The template ships with the project and is not a usable scenario.
        if file == "default-template.json":
            continue
        path = os.path.join(directory, file)
        # Skip sub-directories and other non-regular entries.
        if not os.path.isfile(path):
            continue
        with open(path, encoding="utf-8") as f:
            loaded[file] = json.load(f)
    # Publish only after every file was parsed successfully.
    __scenario_from_files__ = loaded
def get_prompt_dict() -> dict:
    """Return every available scenario preset as a ``{name: prompt}`` dict.

    Combines ``config.default_prompt`` with the presets loaded from files
    (``__prompts_from_files__``). A plain string in ``config.default_prompt``
    is treated as a single preset named ``"default"``. A file preset overrides
    a config preset of the same name. A fresh dict is returned on every call.

    :raises TypeError: if ``config.default_prompt`` is neither str nor dict.
    """
    import config

    configured = config.default_prompt
    if isinstance(configured, str):
        prompts = {"default": configured}
    elif isinstance(configured, dict):
        # Copy, so merging file presets below never mutates the user's
        # config.default_prompt (stale file presets would otherwise linger).
        prompts = dict(configured)
    else:
        raise TypeError("default_prompt must be str or dict")

    # Presets read from files take precedence over same-named config entries.
    prompts.update(__prompts_from_files__)
    return prompts
def set_current(name):
    """Switch the scenario preset currently in use.

    *name* is matched case-insensitively against the keys of
    ``get_prompt_dict()``. An exact match wins; otherwise the first key that
    starts with *name* is selected, so abbreviations keep working.

    :raises KeyError: if no preset matches *name*; the current preset is left
        unchanged in that case.
    """
    global __current__
    wanted = name.lower()
    keys = list(get_prompt_dict())
    # Prefer an exact match: with presets "cat-girl" and "cat", the name "cat"
    # must select "cat" even when "cat-girl" comes first in dict order.
    matched = next((key for key in keys if key.lower() == wanted), None)
    if matched is None:
        matched = next((key for key in keys if key.lower().startswith(wanted)), None)
    if matched is None:
        raise KeyError("未找到情景预设: " + name)
    __current__ = matched
def get_current():
    """Return the name of the scenario preset currently in use."""
    # Read-only access to the module-level state; no `global` needed.
    return __current__
def set_to_default():
    """Reset the current scenario preset to the default one.

    Selects the preset named ``"default"`` when ``get_prompt_dict()`` has one,
    otherwise the first preset in that dict. An empty dict raises
    ``IndexError``.
    """
    global __current__
    presets = get_prompt_dict()
    __current__ = "default" if "default" in presets else list(presets)[0]
def get_prompt(name: str = None) -> list:
    """Build the prompt message list for a scenario preset.

    :param name: preset name, matched case-insensitively. An exact match is
        preferred; otherwise the first preset whose name starts with *name*
        is used. Defaults to the current preset (``get_current()``).
    :returns: depends on ``config.preset_mode``:

        * ``'full_scenario'``: the ``'prompt'`` entry of the matching JSON
          scenario in ``__scenario_from_files__``;
        * ``'default'``: a user message carrying the preset text from
          ``get_prompt_dict()``, followed by an assistant reply ``"好的。"``.
    :raises KeyError: if no preset matches *name*, or ``preset_mode`` is
        neither of the values above.
    """
    import config
    preset_mode = config.preset_mode

    if name is None:
        name = get_current()
    wanted = name.lower()

    def pick(keys):
        # Exact (case-insensitive) match first, then the first prefix match,
        # so "cat" selects "cat" rather than an earlier "cat-girl".
        keys = list(keys)
        for key in keys:
            if key.lower() == wanted:
                return key
        for key in keys:
            if key.lower().startswith(wanted):
                return key
        return None

    # JSON preset mode
    if preset_mode == 'full_scenario':
        key = pick(__scenario_from_files__)
        if key is not None:
            logging.debug('成功加载情景预设从JSON文件: %s', key)
            return __scenario_from_files__[key]['prompt']

    # Default preset mode
    elif preset_mode == 'default':
        default_dict = get_prompt_dict()
        key = pick(default_dict)
        if key is not None:
            return [
                {
                    "role": "user",
                    "content": default_dict[key]
                },
                {
                    "role": "assistant",
                    "content": "好的。"
                }
            ]

    raise KeyError("未找到默认情景预设: " + name)