refactor: 恢复所有审计API调用

This commit is contained in:
RockChinQ 2024-01-31 00:02:19 +08:00
parent c1c751a9ab
commit 32162afa65
11 changed files with 172 additions and 48 deletions

View File

@ -38,6 +38,7 @@ class APIGroup(metaclass=abc.ABCMeta):
url = self.prefix + path
data = json.dumps(data)
headers['Content-Type'] = 'application/json'
try:
async with aiohttp.ClientSession() as session:
async with session.request(
@ -49,7 +50,7 @@ class APIGroup(metaclass=abc.ABCMeta):
**kwargs
) as resp:
self.ap.logger.debug("data: %s", data)
self.ap.logger.debug("ret: %s", await resp.json())
self.ap.logger.debug("ret: %s", await resp.text())
except Exception as e:
self.ap.logger.debug(f'上报失败: {e}')

View File

@ -9,7 +9,7 @@ class V2MainDataAPI(apigroup.APIGroup):
def __init__(self, prefix: str, ap: app.Application):
self.ap = ap
super().__init__(prefix+"/usage", ap)
super().__init__(prefix+"/main", ap)
async def do(self, *args, **kwargs):
config = self.ap.cfg_mgr.data

View File

@ -9,7 +9,7 @@ class V2PluginDataAPI(apigroup.APIGroup):
def __init__(self, prefix: str, ap: app.Application):
self.ap = ap
super().__init__(prefix+"/usage", ap)
super().__init__(prefix+"/plugin", ap)
async def do(self, *args, **kwargs):
config = self.ap.cfg_mgr.data

View File

@ -167,7 +167,7 @@ class PluginDelOperator(operator.CommandOperator):
yield entities.CommandReturn(error=errors.CommandError("插件删除失败: "+str(e)))
def update_plugin_status(plugin_name: str, new_status: bool, ap: app.Application):
async def update_plugin_status(plugin_name: str, new_status: bool, ap: app.Application):
if ap.plugin_mgr.get_plugin_by_name(plugin_name) is not None:
for plugin in ap.plugin_mgr.plugins:
if plugin.plugin_name == plugin_name:
@ -176,6 +176,8 @@ def update_plugin_status(plugin_name: str, new_status: bool, ap: app.Application
for func in plugin.content_functions:
func.enable = new_status
await ap.plugin_mgr.setting.dump_container_setting(ap.plugin_mgr.plugins)
break
return True
@ -202,7 +204,7 @@ class PluginEnableOperator(operator.CommandOperator):
plugin_name = context.crt_params[0]
try:
if update_plugin_status(plugin_name, True, self.ap):
if await update_plugin_status(plugin_name, True, self.ap):
yield entities.CommandReturn(text="已启用插件: {}".format(plugin_name))
else:
yield entities.CommandReturn(error=errors.CommandError("插件状态修改失败: 未找到插件 {}".format(plugin_name)))
@ -230,7 +232,7 @@ class PluginDisableOperator(operator.CommandOperator):
plugin_name = context.crt_params[0]
try:
if update_plugin_status(plugin_name, False, self.ap):
if await update_plugin_status(plugin_name, False, self.ap):
yield entities.CommandReturn(text="已禁用插件: {}".format(plugin_name))
else:
yield entities.CommandReturn(error=errors.CommandError("插件状态修改失败: 未找到插件 {}".format(plugin_name)))

View File

@ -35,7 +35,7 @@ async def make_app() -> app.Application:
print("以下文件不存在,已自动生成,请修改配置文件后重启:")
for file in generated_files:
print("-", file)
sys.exit(0)
missing_deps = await deps.check_deps()
@ -52,28 +52,24 @@ async def make_app() -> app.Application:
# 生成标识符
identifier.init()
cfg_mgr = await config.load_python_module_config(
"config.py",
"config-template.py"
)
cfg_mgr = await config.load_python_module_config("config.py", "config-template.py")
cfg = cfg_mgr.data
# 检查是否携带了 --override 或 -r 参数
if '--override' in sys.argv or '-r' in sys.argv:
if "--override" in sys.argv or "-r" in sys.argv:
use_override = True
if use_override:
overrided = await config.override_config_manager(cfg_mgr)
if overrided:
qcg_logger.info("以下配置项已使用 override.json 覆盖:" + ",".join(overrided))
tips_mgr = await config.load_python_module_config(
"tips.py",
"tips-custom-template.py"
"tips.py", "tips-custom-template.py"
)
# 检查管理员QQ号
if cfg_mgr.data['admin_qq'] == 0:
if cfg_mgr.data["admin_qq"] == 0:
qcg_logger.warning("未设置管理员QQ号将无法使用管理员命令请在 config.py 中修改 admin_qq")
# 构建组建实例
@ -85,50 +81,38 @@ async def make_app() -> app.Application:
proxy_mgr = proxy.ProxyManager(ap)
await proxy_mgr.initialize()
ap.proxy_mgr = proxy_mgr
# 发送公告
ann_mgr = announce.AnnouncementManager(ap)
try:
announcements = await ann_mgr.fetch_new()
for ann in announcements:
ap.logger.info(f'[公告] {ann.time}: {ann.content}')
except Exception as e:
ap.logger.warning(f'获取公告时出错: {e}')
ap.query_pool = pool.QueryPool()
ver_mgr = version.VersionManager(ap)
await ver_mgr.initialize()
ap.ver_mgr = ver_mgr
try:
if await ap.ver_mgr.is_new_version_available():
ap.logger.info("有新版本可用,请使用 !update 命令更新")
except Exception as e:
ap.logger.warning(f"检查版本更新时出错: {e}")
plugin_mgr_inst = plugin_mgr.PluginManager(ap)
await plugin_mgr_inst.initialize()
ap.plugin_mgr = plugin_mgr_inst
center_v2_api = center_v2.V2CenterAPI(
ap,
basic_info={
"host_id": identifier.identifier['host_id'],
"instance_id": identifier.identifier['instance_id'],
"host_id": identifier.identifier["host_id"],
"instance_id": identifier.identifier["instance_id"],
"semantic_version": ver_mgr.get_current_version(),
"platform": sys.platform,
},
runtime_info={
"admin_id": "{}".format(cfg['admin_qq']),
"msg_source": cfg['msg_source_adapter'],
}
"admin_id": "{}".format(cfg["admin_qq"]),
"msg_source": cfg["msg_source_adapter"],
},
)
ap.ctr_mgr = center_v2_api
# 发送公告
ann_mgr = announce.AnnouncementManager(ap)
await ann_mgr.show_announcements()
ap.query_pool = pool.QueryPool()
await ap.ver_mgr.show_version_update()
plugin_mgr_inst = plugin_mgr.PluginManager(ap)
await plugin_mgr_inst.initialize()
ap.plugin_mgr = plugin_mgr_inst
cmd_mgr_inst = cmdmgr.CommandManager(ap)
await cmd_mgr_inst.initialize()
ap.cmd_mgr = cmd_mgr_inst
@ -159,7 +143,7 @@ async def make_app() -> app.Application:
ctrl = controller.Controller(ap)
ap.ctrl = ctrl
await ap.initialize()
return ap

View File

@ -1,6 +1,7 @@
from __future__ import annotations
import typing
import time
import mirai
@ -84,9 +85,16 @@ class ChatMessageHandler(handler.MessageHandler):
called_functions = []
text_length = 0
start_time = time.time()
async for result in conversation.use_model.requester.request(query, conversation):
conversation.messages.append(result)
if result.content is not None:
text_length += len(result.content)
# 转换成可读消息
if result.role == 'assistant':
@ -172,3 +180,13 @@ class ChatMessageHandler(handler.MessageHandler):
result_type=entities.ResultType.CONTINUE,
new_query=query
)
await self.ap.ctr_mgr.usage.post_query_record(
session_type=session.launcher_type.value,
session_id=str(session.launcher_id),
query_ability_provider="QChatGPT.Chat",
usage=text_length,
model_name=conversation.use_model.name,
response_seconds=int(time.time() - start_time),
retry_times=-1,
)

View File

@ -64,6 +64,15 @@ class PluginManager:
"""
await self.installer.install_plugin(plugin_source)
await self.ap.ctr_mgr.plugin.post_install_record(
{
"name": "unknown",
"remote": plugin_source,
"author": "unknown",
"version": "HEAD"
}
)
async def uninstall_plugin(
self,
plugin_name: str,
@ -72,6 +81,17 @@ class PluginManager:
"""
await self.installer.uninstall_plugin(plugin_name)
plugin_container = self.get_plugin_by_name(plugin_name)
await self.ap.ctr_mgr.plugin.post_remove_record(
{
"name": plugin_name,
"remote": plugin_container.plugin_source,
"author": plugin_container.plugin_author,
"version": plugin_container.plugin_version
}
)
async def update_plugin(
self,
plugin_name: str,
@ -80,6 +100,19 @@ class PluginManager:
"""更新插件
"""
await self.installer.update_plugin(plugin_name, plugin_source)
plugin_container = self.get_plugin_by_name(plugin_name)
await self.ap.ctr_mgr.plugin.post_update_record(
plugin={
"name": plugin_name,
"remote": plugin_container.plugin_source,
"author": plugin_container.plugin_author,
"version": plugin_container.plugin_version
},
old_version=plugin_container.plugin_version,
new_version="HEAD"
)
def get_plugin_by_name(self, plugin_name: str) -> context.RuntimeContainer:
"""通过插件名获取插件
@ -98,10 +131,14 @@ class PluginManager:
event=event
)
emitted_plugins: list[context.RuntimeContainer] = []
for plugin in self.plugins:
if plugin.enabled:
if event.__class__ in plugin.event_handlers:
emitted_plugins.append(plugin)
is_prevented_default_before_call = ctx.is_prevented_default()
try:
@ -126,4 +163,19 @@ class PluginManager:
self.ap.logger.debug(f'事件 {event.__class__.__name__}({ctx.eid}) 处理完成,返回值 {ctx.__return_value__}')
if emitted_plugins:
plugins_info: list[dict] = [
{
'name': plugin.plugin_name,
'remote': plugin.plugin_source,
'version': plugin.plugin_version,
'author': plugin.plugin_author
} for plugin in emitted_plugins
]
await self.ap.ctr_mgr.usage.post_event_record(
plugins=plugins_info,
event_name=event.__class__.__name__
)
return ctx

View File

@ -59,6 +59,25 @@ class SettingManager:
await self.settings.dump_config()
async def dump_container_setting(
self,
plugin_containers: list[context.RuntimeContainer]
):
"""保存插件容器设置
"""
for plugin in plugin_containers:
for ps in self.settings.data['plugins']:
if ps['name'] == plugin.plugin_name:
plugin_dict = plugin.to_setting_dict()
for key in plugin_dict:
ps[key] = plugin_dict[key]
break
await self.settings.dump_config()
async def record_installed_plugin_source(
self,
pkg_path: str,

View File

@ -84,3 +84,24 @@ class ToolManager:
self.ap.logger.error(f'执行函数 {name} 时发生错误: {e}')
traceback.print_exc()
return f'error occurred when executing function {name}: {e}'
finally:
plugin = None
for p in self.ap.plugin_mgr.plugins:
if function in p.content_functions:
plugin = p
break
if plugin is not None:
await self.ap.ctr_mgr.usage.post_function_record(
plugin={
'name': plugin.plugin_name,
'remote': plugin.plugin_source,
'version': plugin.plugin_version,
'author': plugin.plugin_author
},
function_name=function.name,
function_description=function.description,
)

View File

@ -104,3 +104,20 @@ class AnnouncementManager:
await self.write_saved(all)
return to_show
async def show_announcements(
self
):
"""显示公告"""
try:
announcements = await self.fetch_new()
for ann in announcements:
self.ap.logger.info(f'[公告] {ann.time}: {ann.content}')
if announcements:
await self.ap.ctr_mgr.main.post_announcement_showed(
ids=[item.id for item in announcements]
)
except Exception as e:
self.ap.logger.warning(f'获取公告时出错: {e}')

View File

@ -148,7 +148,7 @@ class VersionManager:
with open("current_tag", "w") as f:
f.write(current_tag)
self.ap.ctr_mgr.main.post_update_record(
await self.ap.ctr_mgr.main.post_update_record(
spent_seconds=int(time.time()-start_time),
infer_reason="update",
old_version=old_tag,
@ -224,3 +224,13 @@ class VersionManager:
return 0
async def show_version_update(
self
):
try:
if await self.ap.ver_mgr.is_new_version_available():
self.ap.logger.info("有新版本可用,请使用 !update 命令更新")
except Exception as e:
self.ap.logger.warning(f"检查版本更新时出错: {e}")