refactor: 配置文件均改为json

This commit is contained in:
RockChinQ 2024-02-06 21:26:03 +08:00
parent f340a44abf
commit c853bba4ba
48 changed files with 355 additions and 285 deletions

1
.gitignore vendored
View File

@ -33,3 +33,4 @@ bard.json
!/docker-compose.yaml
res/instance_id.json
.DS_Store
/data

View File

@ -12,8 +12,7 @@ class V2MainDataAPI(apigroup.APIGroup):
super().__init__(prefix+"/main", ap)
async def do(self, *args, **kwargs):
config = self.ap.cfg_mgr.data
if not config['report_usage']:
if not self.ap.system_cfg.data['report-usage']:
return None
return await super().do(*args, **kwargs)

View File

@ -12,8 +12,7 @@ class V2PluginDataAPI(apigroup.APIGroup):
super().__init__(prefix+"/plugin", ap)
async def do(self, *args, **kwargs):
config = self.ap.cfg_mgr.data
if not config['report_usage']:
if not self.ap.system_cfg.data['report-usage']:
return None
return await super().do(*args, **kwargs)

View File

@ -12,8 +12,7 @@ class V2UsageDataAPI(apigroup.APIGroup):
super().__init__(prefix+"/usage", ap)
async def do(self, *args, **kwargs):
config = self.ap.cfg_mgr.data
if not config['report_usage']:
if not self.ap.system_cfg.data['report-usage']:
return None
return await super().do(*args, **kwargs)

View File

@ -6,7 +6,7 @@ from ..core import app, entities as core_entities
from ..provider import entities as llm_entities
from . import entities, operator, errors
from .operators import func, plugin, default, reset, list as list_cmd, last, next, delc, resend, prompt, cfg, cmd, help, version, update
from .operators import func, plugin, default, reset, list as list_cmd, last, next, delc, resend, prompt, cmd, help, version, update
class CommandManager:
@ -85,9 +85,11 @@ class CommandManager:
"""
privilege = 1
if query.sender_id == self.ap.cfg_mgr.data['admin_qq'] \
or query.sender_id in self.ap.cfg_mgr['admin_qq']:
if f'{query.launcher_type.value}_{query.launcher_id}' in self.ap.system_cfg.data['admin-sessions']:
privilege = 2
print(f'privilege: {privilege}')
ctx = entities.ExecuteContext(
query=query,

View File

@ -24,6 +24,7 @@ def operator_class(
cls.help = help
cls.usage = usage
cls.parent_class = parent_class
cls.lowest_privilege = privilege
preregistered_operators.append(cls)

View File

@ -1,98 +0,0 @@
from __future__ import annotations
import typing
import json
from .. import operator, entities, cmdmgr, errors
@operator.operator_class(
name="cfg",
help="配置项管理",
usage='!cfg <配置项> [配置值]\n!cfg all',
privilege=2
)
class CfgOperator(operator.CommandOperator):
async def execute(
self,
context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
"""执行
"""
reply = ''
params = context.crt_params
cfg_mgr = self.ap.cfg_mgr
false = False
true = True
reply_str = ""
if len(params) == 0:
yield entities.CommandReturn(error=errors.ParamNotEnoughError('请提供配置项名称'))
else:
cfg_name = params[0]
if cfg_name == 'all':
reply_str = "[bot]所有配置项:\n\n"
for cfg in cfg_mgr.data.keys():
if not cfg.startswith('__') and not cfg == 'logging':
# 根据配置项类型进行格式化如果是字典则转换为json并格式化
if isinstance(cfg_mgr.data[cfg], str):
reply_str += "{}: \"{}\"\n".format(cfg, cfg_mgr.data[cfg])
elif isinstance(cfg_mgr.data[cfg], dict):
# 不进行unicode转义并格式化
reply_str += "{}: {}\n".format(cfg,
json.dumps(cfg_mgr.data[cfg],
ensure_ascii=False, indent=4))
else:
reply_str += "{}: {}\n".format(cfg, cfg_mgr.data[cfg])
yield entities.CommandReturn(text=reply_str)
else:
cfg_entry_path = cfg_name.split('.')
try:
if len(params) == 1: # 未指定配置值,返回配置项值
cfg_entry = cfg_mgr.data[cfg_entry_path[0]]
if len(cfg_entry_path) > 1:
for i in range(1, len(cfg_entry_path)):
cfg_entry = cfg_entry[cfg_entry_path[i]]
if isinstance(cfg_entry, str):
reply_str = "[bot]配置项{}: \"{}\"\n".format(cfg_name, cfg_entry)
elif isinstance(cfg_entry, dict):
reply_str = "[bot]配置项{}: {}\n".format(cfg_name,
json.dumps(cfg_entry,
ensure_ascii=False, indent=4))
else:
reply_str = "[bot]配置项{}: {}\n".format(cfg_name, cfg_entry)
yield entities.CommandReturn(text=reply_str)
else:
cfg_value = " ".join(params[1:])
cfg_value = eval(cfg_value)
cfg_entry = cfg_mgr.data[cfg_entry_path[0]]
if len(cfg_entry_path) > 1:
for i in range(1, len(cfg_entry_path) - 1):
cfg_entry = cfg_entry[cfg_entry_path[i]]
if isinstance(cfg_entry[cfg_entry_path[-1]], type(cfg_value)):
cfg_entry[cfg_entry_path[-1]] = cfg_value
yield entities.CommandReturn(text="配置项{}修改成功".format(cfg_name))
else:
# reply = ["[bot]err:配置项{}类型不匹配".format(cfg_name)]
yield entities.CommandReturn(error=errors.CommandOperationError("配置项{}类型不匹配".format(cfg_name)))
else:
cfg_mgr.data[cfg_entry_path[0]] = cfg_value
# reply = ["[bot]配置项{}修改成功".format(cfg_name)]
yield entities.CommandReturn(text="配置项{}修改成功".format(cfg_name))
except KeyError:
# reply = ["[bot]err:未找到配置项 {}".format(cfg_name)]
yield entities.CommandReturn(error=errors.CommandOperationError("未找到配置项 {}".format(cfg_name)))
except NameError:
# reply = ["[bot]err:值{}不合法(字符串需要使用双引号包裹)".format(cfg_value)]
yield entities.CommandReturn(error=errors.CommandOperationError("{}不合法(字符串需要使用双引号包裹)".format(cfg_value)))
except ValueError:
# reply = ["[bot]err:未找到配置项 {}".format(cfg_name)]
yield entities.CommandReturn(error=errors.CommandOperationError("未找到配置项 {}".format(cfg_name)))

View File

@ -16,7 +16,7 @@ class HelpOperator(operator.CommandOperator):
self,
context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
help = self.ap.tips_mgr.data['help_message']
help = self.ap.system_cfg.data['help-message']
help += '\n发送命令 !cmd 可查看命令列表'

View File

@ -4,6 +4,9 @@ from . import model as file_model
from .impls import pymodule, json as json_file
managers: ConfigManager = []
class ConfigManager:
"""配置文件管理器"""

View File

@ -31,9 +31,15 @@ class Application:
tool_mgr: llm_tool_mgr.ToolManager = None
cfg_mgr: config_mgr.ConfigManager = None
command_cfg: config_mgr.ConfigManager = None
tips_mgr: config_mgr.ConfigManager = None
pipeline_cfg: config_mgr.ConfigManager = None
platform_cfg: config_mgr.ConfigManager = None
provider_cfg: config_mgr.ConfigManager = None
system_cfg: config_mgr.ConfigManager = None
ctr_mgr: center_mgr.V2CenterAPI = None

View File

@ -32,7 +32,7 @@ async def make_app() -> app.Application:
generated_files = await files.generate_files()
if generated_files:
print("以下文件不存在,已自动生成,请修改配置文件后重启:")
print("以下文件不存在,已自动生成,请按需修改配置文件后重启:")
for file in generated_files:
print("-", file)
@ -52,31 +52,23 @@ async def make_app() -> app.Application:
# 生成标识符
identifier.init()
cfg_mgr = await config.load_python_module_config("config.py", "config-template.py")
cfg = cfg_mgr.data
# ========== 加载配置文件 ==========
# 检查是否携带了 --override 或 -r 参数
if "--override" in sys.argv or "-r" in sys.argv:
use_override = True
command_cfg = await config.load_json_config("data/config/command.json", "templates/command.json")
pipeline_cfg = await config.load_json_config("data/config/pipeline.json", "templates/pipeline.json")
platform_cfg = await config.load_json_config("data/config/platform.json", "templates/platform.json")
provider_cfg = await config.load_json_config("data/config/provider.json", "templates/provider.json")
system_cfg = await config.load_json_config("data/config/system.json", "templates/system.json")
if use_override:
overrided = await config.override_config_manager(cfg_mgr)
if overrided:
qcg_logger.info("以下配置项已使用 override.json 覆盖:" + ",".join(overrided))
tips_mgr = await config.load_python_module_config(
"tips.py", "tips-custom-template.py"
)
# 检查管理员QQ号
if cfg_mgr.data["admin_qq"] == 0:
qcg_logger.warning("未设置管理员QQ号将无法使用管理员命令请在 config.py 中修改 admin_qq")
# 构建组建实例
# ========== 构建应用实例 ==========
ap = app.Application()
ap.logger = qcg_logger
ap.cfg_mgr = cfg_mgr
ap.tips_mgr = tips_mgr
ap.command_cfg = command_cfg
ap.pipeline_cfg = pipeline_cfg
ap.platform_cfg = platform_cfg
ap.provider_cfg = provider_cfg
ap.system_cfg = system_cfg
proxy_mgr = proxy.ProxyManager(ap)
await proxy_mgr.initialize()
@ -95,8 +87,8 @@ async def make_app() -> app.Application:
"platform": sys.platform,
},
runtime_info={
"admin_id": "{}".format(cfg["admin_qq"]),
"msg_source": cfg["msg_source_adapter"],
"admin_id": "{}".format(system_cfg.data["admin-sessions"]),
"msg_source": platform_cfg.data["platform-adapter"],
},
)
ap.ctr_mgr = center_v2_api

View File

@ -6,19 +6,25 @@ import sys
required_files = {
"config.py": "config-template.py",
"banlist.py": "banlist-template.py",
"tips.py": "tips-custom-template.py",
"sensitive.json": "res/templates/sensitive-template.json",
"scenario/default.json": "scenario/default-template.json",
"cmdpriv.json": "res/templates/cmdpriv-template.json",
"plugins/__init__.py": "templates/__init__.py",
"plugins/plugins.json": "templates/plugin-settings.json",
"data/config/command.json": "templates/command.json",
"data/config/pipeline.json": "templates/pipeline.json",
"data/config/platform.json": "templates/platform.json",
"data/config/provider.json": "templates/provider.json",
"data/config/system.json": "templates/system.json",
"data/config/sensitive-words.json": "templates/sensitive-words.json",
"data/scenario/default.json": "templates/scenario-template.json",
}
required_paths = [
"plugins",
"prompts",
"temp",
"logs"
"data",
"data/prompts",
"data/scenario",
"data/logs",
"data/config",
"plugins"
]
async def generate_files() -> list[str]:

View File

@ -21,7 +21,7 @@ async def init_logging() -> logging.Logger:
if "DEBUG" in os.environ and os.environ["DEBUG"] in ["true", "1"]:
level = logging.DEBUG
log_file_name = "logs/qcg-%s.log" % time.strftime(
log_file_name = "data/logs/qcg-%s.log" % time.strftime(
"%Y-%m-%d-%H-%M-%S", time.localtime()
)

View File

@ -8,8 +8,6 @@ from . import app, entities
from ..pipeline import entities as pipeline_entities
from ..plugin import events
DEFAULT_QUERY_CONCURRENCY = 10
class Controller:
"""总控制器
@ -21,7 +19,7 @@ class Controller:
def __init__(self, ap: app.Application):
self.ap = ap
self.semaphore = asyncio.Semaphore(DEFAULT_QUERY_CONCURRENCY)
self.semaphore = asyncio.Semaphore(self.ap.system_cfg.data['pipeline-concurrency'])
async def consumer(self):
"""事件处理循环
@ -150,9 +148,9 @@ class Controller:
try:
await self._execute_from_stage(0, query)
except Exception as e:
self.ap.logger.error(f"处理请求时出错 {query}: {e}")
# self.ap.logger.debug(f"处理请求时出错 {query}: {e}", exc_info=True)
traceback.print_exc()
self.ap.logger.error(f"处理请求时出错 query_id={query.query_id}: {e}")
self.ap.logger.debug(f"Traceback: {traceback.format_exc()}")
# traceback.print_exc()
finally:
self.ap.logger.debug(f"Query {query} processed")

View File

@ -23,54 +23,28 @@ class BanSessionCheckStage(stage.PipelineStage):
stage_inst_name: str
) -> entities.StageProcessResult:
if not self.banlist_mgr.data['enable']:
return entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE,
new_query=query
)
found = False
mode = self.ap.pipeline_cfg.data['access-control']['mode']
sess_list = self.ap.pipeline_cfg.data['access-control'][mode]
if (query.launcher_type == 'group' and 'group_*' in sess_list) \
or (query.launcher_type == 'person' and 'person_*' in sess_list):
found = True
else:
for sess in sess_list:
if sess == f"{query.launcher_type}_{query.launcher_id}":
found = True
break
result = False
if query.launcher_type == 'group':
if not self.banlist_mgr.data['enable_group']: # 未启用群聊响应
result = True
# 检查是否显式声明发起人QQ要被person忽略
elif query.sender_id in self.banlist_mgr.data['person']:
result = True
else:
for group_rule in self.banlist_mgr.data['group']:
if type(group_rule) == int:
if group_rule == query.launcher_id:
result = True
elif type(group_rule) == str:
if group_rule.startswith('!'):
reg_str = group_rule[1:]
if re.match(reg_str, str(query.launcher_id)):
result = False
break
else:
if re.match(group_rule, str(query.launcher_id)):
result = True
elif query.launcher_type == 'person':
if not self.banlist_mgr.data['enable_private']:
result = True
else:
for person_rule in self.banlist_mgr.data['person']:
if type(person_rule) == int:
if person_rule == query.launcher_id:
result = True
elif type(person_rule) == str:
if person_rule.startswith('!'):
reg_str = person_rule[1:]
if re.match(reg_str, str(query.launcher_id)):
result = False
break
else:
if re.match(person_rule, str(query.launcher_id)):
result = True
if mode == 'blacklist':
result = found
return entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE if not result else entities.ResultType.INTERRUPT,
new_query=query,
debug_notice=f'根据禁用列表忽略消息: {query.launcher_type}_{query.launcher_id}' if result else ''
debug_notice=f'根据访问控制忽略消息: {query.launcher_type}_{query.launcher_id}' if result else ''
)

View File

@ -24,10 +24,10 @@ class ContentFilterStage(stage.PipelineStage):
async def initialize(self):
self.filter_chain.append(cntignore.ContentIgnore(self.ap))
if self.ap.cfg_mgr.data['sensitive_word_filter']:
if self.ap.pipeline_cfg.data['check-sensitive-words']:
self.filter_chain.append(banwords.BanWordFilter(self.ap))
if self.ap.cfg_mgr.data['baidu_check']:
if self.ap.pipeline_cfg.data['baidu-cloud-examine']['enable']:
self.filter_chain.append(baiduexamine.BaiduCloudExamine(self.ap))
for filter in self.filter_chain:
@ -41,7 +41,7 @@ class ContentFilterStage(stage.PipelineStage):
"""请求llm前处理消息
只要有一个不通过就不放行只放行 PASS 的消息
"""
if not self.ap.cfg_mgr.data['income_msg_check']:
if not self.ap.pipeline_cfg.data['income-msg-check']:
return entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE,
new_query=query

View File

@ -19,8 +19,8 @@ class BaiduCloudExamine(filter_model.ContentFilter):
BAIDU_EXAMINE_TOKEN_URL,
params={
"grant_type": "client_credentials",
"client_id": self.ap.cfg_mgr.data['baidu_api_key'],
"client_secret": self.ap.cfg_mgr.data['baidu_secret_key']
"client_id": self.ap.pipeline_cfg.data['baidu-cloud-examine']['api-key'],
"client_secret": self.ap.pipeline_cfg.data['baidu-cloud-examine']['api-secret']
}
) as resp:
return (await resp.json())['access_token']
@ -56,6 +56,6 @@ class BaiduCloudExamine(filter_model.ContentFilter):
return entities.FilterResult(
level=entities.ResultLevel.BLOCK,
replacement=message,
user_notice=self.ap.cfg_mgr.data['inappropriate_message_tips'],
user_notice="消息中存在不合适的内容, 请修改",
console_notice=f"百度云判定结果:{conclusion}"
)
)

View File

@ -13,8 +13,8 @@ class BanWordFilter(filter_model.ContentFilter):
async def initialize(self):
self.sensitive = await cfg_mgr.load_json_config(
"sensitive.json",
"res/templates/sensitive-template.json"
"data/config/sensitive-words.json",
"templates/sensitive-words.json"
)
async def process(self, message: str) -> entities.FilterResult:
@ -39,6 +39,6 @@ class BanWordFilter(filter_model.ContentFilter):
return entities.FilterResult(
level=entities.ResultLevel.MASKED if found else entities.ResultLevel.PASS,
replacement=message,
user_notice='[bot] 消息中存在不合适的内容, 请更换措辞' if found else '',
user_notice='消息中存在不合适的内容, 请修改' if found else '',
console_notice=''
)

View File

@ -15,8 +15,8 @@ class ContentIgnore(filter_model.ContentFilter):
]
async def process(self, message: str) -> entities.FilterResult:
if 'prefix' in self.ap.cfg_mgr.data['ignore_rules']:
for rule in self.ap.cfg_mgr.data['ignore_rules']['prefix']:
if 'prefix' in self.ap.pipeline_cfg.data['ignore-rules']:
for rule in self.ap.pipeline_cfg.data['ignore-rules']['prefix']:
if message.startswith(rule):
return entities.FilterResult(
level=entities.ResultLevel.BLOCK,
@ -25,8 +25,8 @@ class ContentIgnore(filter_model.ContentFilter):
console_notice='根据 ignore_rules 中的 prefix 规则,忽略消息'
)
if 'regexp' in self.ap.cfg_mgr.data['ignore_rules']:
for rule in self.ap.cfg_mgr.data['ignore_rules']['regexp']:
if 'regexp' in self.ap.pipeline_cfg.data['ignore-rules']:
for rule in self.ap.pipeline_cfg.data['ignore-rules']['regexp']:
if re.search(rule, message):
return entities.FilterResult(
level=entities.ResultLevel.BLOCK,

View File

@ -19,9 +19,9 @@ class LongTextProcessStage(stage.PipelineStage):
strategy_impl: strategy.LongTextStrategy
async def initialize(self):
config = self.ap.cfg_mgr.data
if self.ap.cfg_mgr.data['blob_message_strategy'] == 'image':
use_font = config['font_path']
config = self.ap.platform_cfg.data['long-text-process']
if config['strategy'] == 'image':
use_font = config['font-path']
try:
# 检查是否存在
if not os.path.exists(use_font):
@ -33,23 +33,25 @@ class LongTextProcessStage(stage.PipelineStage):
config['blob_message_strategy'] = "forward"
else:
self.ap.logger.info("使用Windows自带字体" + use_font)
self.ap.cfg_mgr.data['font_path'] = use_font
config['font-path'] = use_font
else:
self.ap.logger.warn("未找到字体文件且无法使用系统自带字体更换为转发消息组件以发送长消息您可以在config.py中调整相关设置。")
self.ap.cfg_mgr.data['blob_message_strategy'] = "forward"
self.ap.platform_cfg.data['long-text-process']['strategy'] = "forward"
except:
traceback.print_exc()
self.ap.logger.error("加载字体文件失败({})更换为转发消息组件以发送长消息您可以在config.py中调整相关设置。".format(use_font))
self.ap.cfg_mgr.data['blob_message_strategy'] = "forward"
self.ap.platform_cfg.data['long-text-process']['strategy'] = "forward"
if self.ap.cfg_mgr.data['blob_message_strategy'] == 'image':
if config['strategy'] == 'image':
self.strategy_impl = image.Text2ImageStrategy(self.ap)
elif self.ap.cfg_mgr.data['blob_message_strategy'] == 'forward':
elif config['strategy'] == 'forward':
self.strategy_impl = forward.ForwardComponentStrategy(self.ap)
await self.strategy_impl.initialize()
async def process(self, query: core_entities.Query, stage_inst_name: str) -> entities.StageProcessResult:
if len(str(query.resp_message_chain)) > self.ap.cfg_mgr.data['blob_message_threshold']:
if len(str(query.resp_message_chain)) > self.ap.platform_cfg.data['long-text-process']['threshold']:
query.resp_message_chain = MessageChain(await self.strategy_impl.process(str(query.resp_message_chain)))
return entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE,

View File

@ -19,7 +19,7 @@ class Text2ImageStrategy(strategy_model.LongTextStrategy):
text_render_font: ImageFont.FreeTypeFont
async def initialize(self):
self.text_render_font = ImageFont.truetype(self.ap.cfg_mgr.data['font_path'], 32, encoding="utf-8")
self.text_render_font = ImageFont.truetype(self.ap.platform_cfg.data['long-text-process']['font-path'], 32, encoding="utf-8")
async def process(self, message: str) -> list[MessageComponent]:
img_path = self.text_to_image(

View File

@ -52,7 +52,7 @@ class PreProcessor(stage.PipelineStage):
query.messages = event_ctx.event.prompt
# 根据模型max_tokens剪裁
max_tokens = min(query.use_model.max_tokens, self.ap.cfg_mgr.data['prompt_submit_length'])
max_tokens = min(query.use_model.max_tokens, self.ap.pipeline_cfg.data['submit-messages-tokens'])
test_messages = query.prompt.messages + query.messages + [query.user_message]
@ -63,7 +63,7 @@ class PreProcessor(stage.PipelineStage):
result_type=entities.ResultType.INTERRUPT,
new_query=query,
user_notice='输入内容过长,请减少情景预设或者输入内容长度',
console_notice='输入内容过长,请减少情景预设或者输入内容长度,或者增大配置文件中的 prompt_submit_length但不能超过所用模型最大tokens数'
console_notice='输入内容过长,请减少情景预设或者输入内容长度,或者增大配置文件中的 submit-messages-tokens但不能超过所用模型最大tokens数'
)
query.messages.pop(0) # pop第一个肯定是role=user的

View File

@ -54,6 +54,12 @@ class ChatMessageHandler(handler.MessageHandler):
)
else:
if not self.ap.provider_cfg.data['enable-chat']:
yield entities.StageProcessResult(
result_type=entities.ResultType.INTERRUPT,
new_query=query,
)
if event_ctx.event.alter is not None:
query.message_chain = mirai.MessageChain([
mirai.Plain(event_ctx.event.alter)
@ -83,7 +89,7 @@ class ChatMessageHandler(handler.MessageHandler):
yield entities.StageProcessResult(
result_type=entities.ResultType.INTERRUPT,
new_query=query,
user_notice=self.ap.tips_mgr.data['alter_tip_message'] if self.ap.cfg_mgr.data['hide_exce_info_to_user'] else f'{e}',
user_notice='请求失败' if self.ap.platform_cfg.data['hide-exception-info'] else f'{e}',
error_notice=f'{e}',
debug_notice=traceback.format_exc()
)

View File

@ -23,8 +23,8 @@ class CommandHandler(handler.MessageHandler):
privilege = 1
if query.sender_id == self.ap.cfg_mgr.data['admin_qq'] \
or query.sender_id in self.ap.cfg_mgr['admin_qq']:
if f'{query.launcher_type.value}_{query.launcher_id}' in self.ap.system_cfg.data['admin-sessions']:
privilege = 2
spt = str(query.message_chain).strip().split(' ')

View File

@ -55,16 +55,16 @@ class FixedWindowAlgo(algo.ReteLimitAlgo):
# 获取当前分钟的访问次数
count = container.records.get(now, 0)
limitation = self.ap.cfg_mgr.data['rate_limitation']['default']
limitation = self.ap.pipeline_cfg.data['rate-limit']['fixwin']['default']
if session_name in self.ap.cfg_mgr.data['rate_limitation']:
limitation = self.ap.cfg_mgr.data['rate_limitation'][session_name]
if session_name in self.ap.pipeline_cfg.data['rate-limit']['fixwin']:
limitation = self.ap.pipeline_cfg.data['rate-limit']['fixwin'][session_name]
# 如果访问次数超过了限制
if count >= limitation:
if self.ap.cfg_mgr.data['rate_limit_strategy'] == 'drop':
if self.ap.pipeline_cfg.data['rate-limit']['strategy'] == 'drop':
return False
elif self.ap.cfg_mgr.data['rate_limit_strategy'] == 'wait':
elif self.ap.pipeline_cfg.data['rate-limit']['strategy'] == 'wait':
# 等待下一分钟
await asyncio.sleep(60 - time.time() % 60)

View File

@ -42,7 +42,7 @@ class RateLimit(stage.PipelineStage):
result_type=entities.ResultType.INTERRUPT,
new_query=query,
console_notice=f"根据限速规则忽略 {query.launcher_type.value}:{query.launcher_id} 消息",
user_notice=self.ap.tips_mgr.data['rate_limit_drop_tip']
user_notice=f"请求数超过限速器设定值,已丢弃本消息。"
)
elif stage_inst_name == "ReleaseRateLimitOccupancy":
await self.algo.release_access(

View File

@ -20,7 +20,7 @@ class SendResponseBackStage(stage.PipelineStage):
async def process(self, query: core_entities.Query, stage_inst_name: str) -> entities.StageProcessResult:
"""处理
"""
random_delay = random.uniform(*self.ap.cfg_mgr.data['force_delay_range'])
random_delay = random.uniform(*self.ap.platform_cfg.data['force-delay'])
self.ap.logger.debug(
"根据规则强制延迟回复: %s s",

View File

@ -33,13 +33,13 @@ class GroupRespondRuleCheckStage(stage.PipelineStage):
async def process(self, query: core_entities.Query, stage_inst_name: str) -> entities.StageProcessResult:
if query.launcher_type != 'group':
if query.launcher_type.value != 'group':
return entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE,
new_query=query
)
rules = self.ap.cfg_mgr.data['response_rules']
rules = self.ap.pipeline_cfg.data['respond-rules']
use_rule = rules['default']

View File

@ -16,6 +16,7 @@ class PrefixRule(rule_model.GroupRespondRule):
for prefix in prefixes:
if message_text.startswith(prefix):
return entities.RuleJudgeResult(
matching=True,
replacement=mirai.MessageChain([

View File

@ -14,7 +14,7 @@ class RandomRespRule(rule_model.GroupRespondRule):
message_chain: mirai.MessageChain,
rule_dict: dict
) -> entities.RuleJudgeResult:
random_rate = rule_dict['random_rate']
random_rate = rule_dict['random']
return entities.RuleJudgeResult(
matching=random.random() < random_rate,

View File

@ -85,7 +85,7 @@ class ResponseWrapper(stage.PipelineStage):
query.resp_message_chain = mirai.MessageChain([mirai.Plain(reply_text)])
if self.ap.cfg_mgr.data['trace_function_calls']:
if self.ap.platform_cfg.data['track-function-calls']:
event_ctx = await self.ap.plugin_mgr.emit_event(
event=events.NormalMessageResponded(

View File

@ -31,19 +31,16 @@ class PlatformManager:
async def initialize(self):
config = self.ap.cfg_mgr.data
logging.debug("Use adapter:" + config['msg_source_adapter'])
if config['msg_source_adapter'] == 'yirimirai':
if self.ap.platform_cfg.data['platform-adapter'] == 'yiri-mirai':
from pkg.platform.sources.yirimirai import YiriMiraiAdapter
mirai_http_api_config = config['mirai_http_api_config']
self.bot_account_id = config['mirai_http_api_config']['qq']
mirai_http_api_config = self.ap.platform_cfg.data['yiri-mirai-config']
self.bot_account_id = mirai_http_api_config['qq']
self.adapter = YiriMiraiAdapter(mirai_http_api_config)
elif config['msg_source_adapter'] == 'nakuru':
from pkg.platform.sources.nakuru import NakuruProjectAdapter
self.adapter = NakuruProjectAdapter(config['nakuru_config'])
self.bot_account_id = self.adapter.bot_account_id
# elif config['msg_source_adapter'] == 'nakuru':
# from pkg.platform.sources.nakuru import NakuruProjectAdapter
# self.adapter = NakuruProjectAdapter(config['nakuru_config'])
# self.bot_account_id = self.adapter.bot_account_id
# 保存 account_id 到审计模块
from ..audit.center import apigroup
@ -99,7 +96,7 @@ class PlatformManager:
)
# nakuru不区分好友和陌生人故仅为yirimirai注册陌生人事件
if config['msg_source_adapter'] == 'yirimirai':
if self.ap.platform_cfg.data['platform-adapter'] == 'yiri-mirai':
self.adapter.register_listener(
StrangerMessage,
on_stranger_message
@ -133,27 +130,26 @@ class PlatformManager:
)
async def send(self, event, msg, check_quote=True, check_at_sender=True):
config = self.ap.cfg_mgr.data
if check_at_sender and config['at_sender']:
if check_at_sender and self.ap.platform_cfg.data['at-sender']:
msg.insert(
0,
Plain(" \n")
)
# 当回复的正文中包含换行时quote可能会自带at此时就不再单独添加at只添加换行
if "\n" not in str(msg[1]) or config['msg_source_adapter'] == 'nakuru':
msg.insert(
0,
At(
event.sender.id
)
# if "\n" not in str(msg[1]) or self.ap.platform_cfg.data['platform-adapter'] == 'nakuru':
msg.insert(
0,
At(
event.sender.id
)
)
await self.adapter.reply_message(
event,
msg,
quote_origin=True if config['quote_origin'] and check_quote else False
quote_origin=True if self.ap.platform_cfg.data['quote-origin'] and check_quote else False
)
# 通知系统管理员
@ -161,19 +157,16 @@ class PlatformManager:
await self.notify_admin_message_chain(MessageChain([Plain("[bot]{}".format(message))]))
async def notify_admin_message_chain(self, message: mirai.MessageChain):
config = self.ap.cfg_mgr.data
if config['admin_qq'] != 0 and config['admin_qq'] != []:
logging.info("通知管理员:{}".format(message))
if self.ap.system_cfg.data['admin-sessions'] != []:
admin_list = []
if type(config['admin_qq']) == int:
admin_list.append(config['admin_qq'])
for admin in self.ap.system_cfg.data['admin-sessions']:
admin_list.append(admin)
for adm in admin_list:
self.adapter.send_message(
"person",
adm,
adm.split("_")[0],
adm.split("_")[1],
message
)

View File

@ -22,8 +22,8 @@ class OpenAIChatCompletion(api.LLMAPIRequester):
async def initialize(self):
self.client = openai.AsyncClient(
api_key="",
base_url=self.ap.cfg_mgr.data["openai_config"]["reverse_proxy"],
timeout=self.ap.cfg_mgr.data["process_message_timeout"],
base_url=self.ap.provider_cfg.data['openai-config']['base_url'],
timeout=self.ap.provider_cfg.data['openai-config']['request-timeout'],
)
async def _req(
@ -51,7 +51,7 @@ class OpenAIChatCompletion(api.LLMAPIRequester):
) -> llm_entities.Message:
self.client.api_key = use_model.token_mgr.get_token()
args = self.ap.cfg_mgr.data["completion_api_params"].copy()
args = self.ap.provider_cfg.data['openai-config']['chat-completions-params'].copy()
args["model"] = use_model.name if use_model.model_name is None else use_model.model_name
if use_model.tool_call_supported:

View File

@ -29,7 +29,7 @@ class ModelManager:
async def initialize(self):
openai_chat_completion = chatcmpl.OpenAIChatCompletion(self.ap)
await openai_chat_completion.initialize()
openai_token_mgr = token.TokenManager(self.ap, list(self.ap.cfg_mgr.data['openai_config']['api_key'].values()))
openai_token_mgr = token.TokenManager(self.ap, list(self.ap.provider_cfg.data['openai-config']['api-keys']))
tiktoken_tokenizer = tiktoken.Tiktoken(self.ap)

View File

@ -25,10 +25,15 @@ class SessionManager:
if query.launcher_type == session.launcher_type and query.launcher_id == session.launcher_id:
return session
session_concurrency = self.ap.system_cfg.data['session-concurrency']['default']
if f'{query.launcher_type.value}_{query.launcher_id}' in self.ap.system_cfg.data['session-concurrency']:
session_concurrency = self.ap.system_cfg.data['session-concurrency'][f'{query.launcher_type.value}_{query.launcher_id}']
session = core_entities.Session(
launcher_type=query.launcher_type,
launcher_id=query.launcher_id,
semaphore=asyncio.Semaphore(1) if self.ap.cfg_mgr.data['wait_last_done'] else asyncio.Semaphore(10000),
semaphore=asyncio.Semaphore(session_concurrency),
)
self.session_list.append(session)
return session
@ -41,7 +46,7 @@ class SessionManager:
conversation = core_entities.Conversation(
prompt=await self.ap.prompt_mgr.get_prompt(session.use_prompt_name),
messages=[],
use_model=await self.ap.model_mgr.get_model_by_name(self.ap.cfg_mgr.data['completion_api_params']['model']),
use_model=await self.ap.model_mgr.get_model_by_name(self.ap.provider_cfg.data['openai-config']['chat-completions-params']['model']),
use_funcs=await self.ap.tool_mgr.get_all_functions(),
)
session.conversations.append(conversation)

View File

@ -14,8 +14,8 @@ class ScenarioPromptLoader(loader.PromptLoader):
async def load(self):
"""加载Prompt
"""
for file in os.listdir("scenarios"):
with open("scenarios/{}".format(file), "r", encoding="utf-8") as f:
for file in os.listdir("data/scenarios"):
with open("data/scenarios/{}".format(file), "r", encoding="utf-8") as f:
file_str = f.read()
file_name = file.split(".")[0]
file_json = json.loads(file_str)

View File

@ -14,7 +14,7 @@ class SingleSystemPromptLoader(loader.PromptLoader):
"""加载Prompt
"""
for name, cnt in self.ap.cfg_mgr.data['default_prompt'].items():
for name, cnt in self.ap.provider_cfg.data['prompt'].items():
prompt = entities.Prompt(
name=name,
messages=[
@ -26,8 +26,8 @@ class SingleSystemPromptLoader(loader.PromptLoader):
)
self.prompts.append(prompt)
for file in os.listdir("prompts"):
with open("prompts/{}".format(file), "r", encoding="utf-8") as f:
for file in os.listdir("data/prompts"):
with open("data/prompts/{}".format(file), "r", encoding="utf-8") as f:
file_str = f.read()
file_name = file.split(".")[0]
prompt = entities.Prompt(

View File

@ -23,7 +23,7 @@ class PromptManager:
"full_scenario": scenario.ScenarioPromptLoader
}
loader_cls = loader_map[self.ap.cfg_mgr.data['preset_mode']]
loader_cls = loader_map[self.ap.provider_cfg.data['prompt-mode']]
self.loader_inst: loader.PromptLoader = loader_cls(self.ap)

View File

@ -1,5 +1,8 @@
from __future__ import annotations
import os
import sys
from ..core import app
@ -14,17 +17,15 @@ class ProxyManager:
self.forward_proxies = {}
async def initialize(self):
config = self.ap.cfg_mgr.data
self.forward_proxies = {
"http": os.getenv("HTTP_PROXY") or os.getenv("http_proxy"),
"https": os.getenv("HTTPS_PROXY") or os.getenv("https_proxy"),
}
return (
{
"http": config["openai_config"]["proxy"],
"https": config["openai_config"]["proxy"],
}
if "proxy" in config["openai_config"]
and (config["openai_config"]["proxy"] is not None)
else None
)
if 'http' in self.ap.system_cfg.data['network-proxies']:
self.forward_proxies['http'] = self.ap.system_cfg.data['network-proxies']['http']
if 'https' in self.ap.system_cfg.data['network-proxies']:
self.forward_proxies['https'] = self.ap.system_cfg.data['network-proxies']['https']
def get_forward_proxies(self) -> str:
def get_forward_proxies(self) -> dict:
return self.forward_proxies

0
templates/__init__.py Normal file
View File

3
templates/command.json Normal file
View File

@ -0,0 +1,3 @@
{
"privilege": {}
}

36
templates/pipeline.json Normal file
View File

@ -0,0 +1,36 @@
{
"access-control":{
"mode": "blacklist",
"blacklist": [],
"whitelist": []
},
"respond-rules": {
"default": {
"at": true,
"prefix": [
"/ai", "!ai", "ai", "ai"
],
"regexp": [],
"random": 0.0
}
},
"income-msg-check": true,
"ignore-rules": {
"prefix": ["/"],
"regexp": []
},
"check-sensitive-words": true,
"baidu-cloud-examine": {
"enable": false,
"api-key": "",
"api-secret": ""
},
"submit-messages-tokens": 3072,
"rate-limit": {
"strategy": "drop",
"algo": "fixwin",
"fixwin": {
"default": 60
}
}
}

20
templates/platform.json Normal file
View File

@ -0,0 +1,20 @@
{
"platform-adapter": "yiri-mirai",
"yiri-mirai-config": {
"adapter": "WebSocketAdapter",
"host": "localhost",
"port": 8080,
"verifyKey": "yirimirai",
"qq": 123456789
},
"track-function-calls": true,
"quote-origin": false,
"at-sender": false,
"force-delay": [0, 0],
"long-text-process": {
"threshold": 256,
"strategy": "forward",
"font-path": ""
},
"hide-exception-info": true
}

View File

@ -0,0 +1,3 @@
{
"plugins": []
}

17
templates/provider.json Normal file
View File

@ -0,0 +1,17 @@
{
"enable-chat": true,
"openai-config": {
"api-keys": [
"sk-1234567890"
],
"base_url": "https://api.openai.com/v1",
"chat-completions-params": {
"model": "gpt-3.5-turbo"
},
"request-timeout": 120
},
"prompt-mode": "normal",
"prompt": {
        "default": "如果用户之后想获取帮助,请你说“输入!help获取帮助”。"
}
}

View File

@ -0,0 +1,12 @@
{
"prompt": [
{
"role": "system",
"content": "You are a helpful assistant. 如果我需要帮助,你要说“输入!help获得帮助”"
},
{
"role": "assistant",
"content": "好的我是一个能干的AI助手。 如果你需要帮助,我会说“输入!help获得帮助”"
}
]
}

View File

@ -0,0 +1,78 @@
{
"说明": "mask将替换敏感词中的每一个字若mask_word值不为空则将敏感词整个替换为mask_word的值",
"mask": "*",
"mask_word": "",
"words": [
"习近平",
"胡锦涛",
"江泽民",
"温家宝",
"李克强",
"李长春",
"毛泽东",
"邓小平",
"周恩来",
"马克思",
"社会主义",
"共产党",
"共产主义",
"大陆官方",
"北京政权",
"中华帝国",
"中国政府",
"共狗",
"六四事件",
"天安门",
"六四",
"政治局常委",
"两会",
"共青团",
"学潮",
"八九",
"二十大",
"民进党",
"台独",
"台湾独立",
"台湾国",
"国民党",
"台湾民国",
"中华民国",
"pornhub",
"Pornhub",
"[Yy]ou[Pp]orn",
"porn",
"Porn",
"[Xx][Vv]ideos",
"[Rr]ed[Tt]ube",
"[Xx][Hh]amster",
"[Ss]pank[Ww]ire",
"[Ss]pank[Bb]ang",
"[Tt]ube8",
"[Yy]ou[Jj]izz",
"[Bb]razzers",
"[Nn]aughty[ ]?[Aa]merica",
"作爱",
"做爱",
"性交",
"性爱",
"自慰",
"阴茎",
"淫妇",
"肛交",
"交配",
"性关系",
"性活动",
"色情",
"色图",
"涩图",
"裸体",
"小穴",
"淫荡",
"性爱",
"翻墙",
"VPN",
"科学上网",
"挂梯子",
"GFW"
]
}

11
templates/system.json Normal file
View File

@ -0,0 +1,11 @@
{
"admin-sessions": [],
"network-proxies": {},
"report-usage": true,
"logging-level": "info",
"session-concurrency": {
"default": 1
},
"pipeline-concurrency": 20,
"help-message": "QChatGPT - 😎高稳定性、🧩支持插件、🌏实时联网的 ChatGPT QQ 机器人🤖\n链接https://q.rkcn.top"
}