QChatGPT/pkg/openai/dprompt.py

146 lines
4.3 KiB
Python
Raw Normal View History

2023-02-19 11:46:12 +08:00
# 多情景预设值管理
2023-03-10 23:14:32 +08:00
import json
import logging
2023-03-26 21:28:26 +08:00
import config
import os
# __current__ = "default"
# """当前默认使用的情景预设的名称
# 由管理员使用`!default <名称>`指令切换
# """
# __prompts_from_files__ = {}
# """从文件中读取的情景预设值"""
# __scenario_from_files__ = {}
__universal_first_reply__ = "ok, I'll follow your commands."
"""通用首次回复"""
class ScenarioMode:
"""情景预设模式抽象类"""
using_prompt_name = "default"
"""新session创建时使用的prompt名称"""
prompts: dict[str, list] = {}
def __init__(self):
logging.debug("prompts: {}".format(self.prompts))
def list(self) -> dict[str, list]:
"""获取所有情景预设的名称及内容"""
return self.prompts
def get_prompt(self, name: str) -> tuple[list, str]:
"""获取指定情景预设的名称及内容"""
for key in self.prompts:
if key.startswith(name):
return self.prompts[key], key
raise Exception("没有找到情景预设: {}".format(name))
def set_using_name(self, name: str) -> str:
"""设置默认情景预设"""
for key in self.prompts:
if key.startswith(name):
self.using_prompt_name = key
return key
raise Exception("没有找到情景预设: {}".format(name))
def get_full_name(self, name: str) -> str:
"""获取完整的情景预设名称"""
for key in self.prompts:
if key.startswith(name):
return key
raise Exception("没有找到情景预设: {}".format(name))
def get_using_name(self) -> str:
"""获取默认情景预设"""
return self.using_prompt_name
class NormalScenarioMode(ScenarioMode):
"""普通情景预设模式"""
def __init__(self):
global __universal_first_reply__
# 加载config中的default_prompt值
if type(config.default_prompt) == str:
self.using_prompt_name = "default"
self.prompts = {"default": [
{
"role": "user",
"content": config.default_prompt
},{
"role": "assistant",
"content": __universal_first_reply__
}
]}
elif type(config.default_prompt) == dict:
for key in config.default_prompt:
self.prompts[key] = [
{
"role": "user",
"content": config.default_prompt[key]
},{
"role": "assistant",
"content": __universal_first_reply__
}
]
2023-02-19 11:46:12 +08:00
2023-03-26 21:28:26 +08:00
# 从prompts/目录下的文件中载入
# 遍历文件
for file in os.listdir("prompts"):
with open(os.path.join("prompts", file), encoding="utf-8") as f:
self.prompts[file] = [
{
"role": "user",
"content": f.read()
},{
"role": "assistant",
"content": __universal_first_reply__
}
]
2023-03-26 21:28:26 +08:00
class FullScenarioMode(ScenarioMode):
"""完整情景预设模式"""
2023-02-19 11:46:12 +08:00
2023-03-26 21:28:26 +08:00
def __init__(self):
"""从json读取所有"""
# 遍历scenario/目录下的所有文件以文件名为键文件内容中的prompt为值
for file in os.listdir("scenario"):
if file == "default-template.json":
continue
with open(os.path.join("scenario", file), encoding="utf-8") as f:
self.prompts[file] = json.load(f)["prompt"]
2023-02-19 11:46:12 +08:00
2023-03-26 21:28:26 +08:00
super().__init__()
2023-02-19 11:46:12 +08:00
2023-03-26 21:28:26 +08:00
scenario_mode_mapping = {}
"""情景预设模式名称与对象的映射"""
2023-02-19 11:46:12 +08:00
2023-03-26 21:28:26 +08:00
def register_all():
"""注册所有情景预设模式,不使用装饰器,因为装饰器的方式不支持热重载"""
global scenario_mode_mapping
scenario_mode_mapping = {
"normal": NormalScenarioMode(),
"full_scenario": FullScenarioMode()
}
2023-02-19 11:46:12 +08:00
2023-03-26 21:28:26 +08:00
def mode_inst() -> ScenarioMode:
"""获取指定名称的情景预设模式对象"""
import config
2023-02-19 11:46:12 +08:00
2023-03-26 21:28:26 +08:00
if config.preset_mode == "default":
config.preset_mode = "normal"
2023-03-26 21:28:26 +08:00
return scenario_mode_mapping[config.preset_mode]