mirror of
https://github.com/RockChinQ/QChatGPT.git
synced 2024-11-16 19:57:04 +08:00
135 lines
4.0 KiB
Python
135 lines
4.0 KiB
Python
# 多情景预设值管理
|
||
import json
|
||
import logging
|
||
import os
|
||
|
||
from ..utils import context
|
||
|
||
# __current__ = "default"
|
||
# """当前默认使用的情景预设的名称
|
||
|
||
# 由管理员使用`!default <名称>`命令切换
|
||
# """
|
||
|
||
# __prompts_from_files__ = {}
|
||
# """从文件中读取的情景预设值"""
|
||
|
||
# __scenario_from_files__ = {}
|
||
|
||
|
||
class ScenarioMode:
|
||
"""情景预设模式抽象类"""
|
||
|
||
using_prompt_name = "default"
|
||
"""新session创建时使用的prompt名称"""
|
||
|
||
prompts: dict[str, list] = {}
|
||
|
||
def __init__(self):
|
||
logging.debug("prompts: {}".format(self.prompts))
|
||
|
||
def list(self) -> dict[str, list]:
|
||
"""获取所有情景预设的名称及内容"""
|
||
return self.prompts
|
||
|
||
def get_prompt(self, name: str) -> tuple[list, str]:
|
||
"""获取指定情景预设的名称及内容"""
|
||
for key in self.prompts:
|
||
if key.startswith(name):
|
||
return self.prompts[key], key
|
||
raise Exception("没有找到情景预设: {}".format(name))
|
||
|
||
def set_using_name(self, name: str) -> str:
|
||
"""设置默认情景预设"""
|
||
for key in self.prompts:
|
||
if key.startswith(name):
|
||
self.using_prompt_name = key
|
||
return key
|
||
raise Exception("没有找到情景预设: {}".format(name))
|
||
|
||
def get_full_name(self, name: str) -> str:
|
||
"""获取完整的情景预设名称"""
|
||
for key in self.prompts:
|
||
if key.startswith(name):
|
||
return key
|
||
raise Exception("没有找到情景预设: {}".format(name))
|
||
|
||
def get_using_name(self) -> str:
|
||
"""获取默认情景预设"""
|
||
return self.using_prompt_name
|
||
|
||
|
||
class NormalScenarioMode(ScenarioMode):
|
||
"""普通情景预设模式"""
|
||
|
||
def __init__(self):
|
||
config = context.get_config_manager().data
|
||
|
||
# 加载config中的default_prompt值
|
||
if type(config['default_prompt']) == str:
|
||
self.using_prompt_name = "default"
|
||
self.prompts = {"default": [
|
||
{
|
||
"role": "system",
|
||
"content": config['default_prompt']
|
||
}
|
||
]}
|
||
|
||
elif type(config['default_prompt']) == dict:
|
||
for key in config['default_prompt']:
|
||
self.prompts[key] = [
|
||
{
|
||
"role": "system",
|
||
"content": config['default_prompt'][key]
|
||
}
|
||
]
|
||
|
||
# 从prompts/目录下的文件中载入
|
||
# 遍历文件
|
||
for file in os.listdir("prompts"):
|
||
with open(os.path.join("prompts", file), encoding="utf-8") as f:
|
||
self.prompts[file] = [
|
||
{
|
||
"role": "system",
|
||
"content": f.read()
|
||
}
|
||
]
|
||
|
||
|
||
class FullScenarioMode(ScenarioMode):
|
||
"""完整情景预设模式"""
|
||
|
||
def __init__(self):
|
||
"""从json读取所有"""
|
||
# 遍历scenario/目录下的所有文件,以文件名为键,文件内容中的prompt为值
|
||
for file in os.listdir("scenario"):
|
||
if file == "default-template.json":
|
||
continue
|
||
with open(os.path.join("scenario", file), encoding="utf-8") as f:
|
||
self.prompts[file] = json.load(f)["prompt"]
|
||
|
||
super().__init__()
|
||
|
||
|
||
scenario_mode_mapping = {}
|
||
"""情景预设模式名称与对象的映射"""
|
||
|
||
|
||
def register_all():
|
||
"""注册所有情景预设模式,不使用装饰器,因为装饰器的方式不支持热重载"""
|
||
global scenario_mode_mapping
|
||
scenario_mode_mapping = {
|
||
"normal": NormalScenarioMode(),
|
||
"full_scenario": FullScenarioMode()
|
||
}
|
||
|
||
|
||
def mode_inst() -> ScenarioMode:
|
||
"""获取指定名称的情景预设模式对象"""
|
||
config = context.get_config_manager().data
|
||
|
||
if config['preset_mode'] == "default":
|
||
config['preset_mode'] = "normal"
|
||
|
||
return scenario_mode_mapping[config['preset_mode']]
|