QChatGPT/pkg/command/cmdmgr.py

130 lines
4.3 KiB
Python
Raw Normal View History

2024-01-28 00:16:42 +08:00
from __future__ import annotations
import typing
from ..core import app, entities as core_entities
2024-01-28 19:20:10 +08:00
from ..provider import entities as llm_entities
2024-01-28 00:16:42 +08:00
from . import entities, operator, errors
2024-02-06 23:57:21 +08:00
from ..config import manager as cfg_mgr
2024-01-28 00:16:42 +08:00
2024-03-03 16:34:59 +08:00
# 引入所有算子以便注册
2024-02-06 21:26:03 +08:00
from .operators import func, plugin, default, reset, list as list_cmd, last, next, delc, resend, prompt, cmd, help, version, update
2024-01-28 00:16:42 +08:00
class CommandManager:
    """Command manager.

    Instantiates the pre-registered command operators and dispatches
    command text to them.
    """

    ap: app.Application

    # Runtime command list: a flat list holding every operator instance;
    # each instance also keeps references to its child operators in
    # its `children` attribute.
    cmd_list: list[operator.CommandOperator]

    def __init__(self, ap: app.Application):
        self.ap = ap

    async def initialize(self):
        """Prepare every pre-registered command operator.

        Steps, in order:
        1. compute each operator class's dotted ``path`` from its ancestry
           (``parent_class`` chain);
        2. override ``lowest_privilege`` for every path listed in the
           ``privilege`` section of the command config;
        3. instantiate each operator class with the application;
        4. link each instance to the instances whose ``parent_class`` is
           its class;
        5. await ``initialize()`` on every instance.
        """

        def set_path(cls: typing.Type[operator.CommandOperator], ancestors: list[str]):
            # A class's path is its ancestors' names plus its own, dot-joined;
            # recurse into every registered class that names it as parent.
            cls.path = '.'.join(ancestors + [cls.name])
            for op in operator.preregistered_operators:
                if op.parent_class == cls:
                    set_path(op, ancestors + [cls.name])

        # Paths are assigned top-down starting from root operators.
        for cls in operator.preregistered_operators:
            if cls.parent_class is None:
                set_path(cls, [])

        # Apply privilege overrides from the command config (keyed by path).
        for cls in operator.preregistered_operators:
            if cls.path in self.ap.command_cfg.data['privilege']:
                cls.lowest_privilege = self.ap.command_cfg.data['privilege'][cls.path]

        # Instantiate every operator class.
        self.cmd_list = [cls(self.ap) for cls in operator.preregistered_operators]

        # Attach each instance's children (instances whose parent_class is
        # this instance's class).
        for command in self.cmd_list:
            command.children = [
                child for child in self.cmd_list
                if child.parent_class == command.__class__
            ]

        # Run each operator's own initialization.
        for command in self.cmd_list:
            await command.initialize()

    async def _execute(
        self,
        context: entities.ExecuteContext,
        operator_list: list[operator.CommandOperator],
        operator: typing.Optional[operator.CommandOperator] = None
    ) -> typing.AsyncGenerator[entities.CommandReturn, None]:
        """Resolve and run the command described by ``context.crt_params``.

        Walks the operator tree one parameter at a time. If the next
        parameter equals the name (or one of the aliases) of an operator in
        ``operator_list`` that sits under the current node, that parameter
        is consumed and the search recurses into that operator's children.
        Otherwise the current node is executed, subject to its privilege
        requirement.

        Args:
            context: execution context; ``crt_command`` and ``crt_params``
                are updated as parameters are consumed.
            operator_list: candidate operators at this level of the tree.
            operator: the operator matched so far, or ``None`` at the root.

        Yields:
            The ``CommandReturn`` objects produced by the executed operator,
            or a single ``CommandReturn`` carrying a ``CommandNotFoundError``
            (nothing matched at the root) or ``CommandPrivilegeError``
            (privilege too low for the matched operator).
        """
        # NOTE: the `operator` parameter shadows the imported `operator`
        # module inside this method; the name is kept so existing callers
        # (including keyword callers) keep working.

        if context.crt_params:
            next_param = context.crt_params[0]
            for oper in operator_list:
                if (next_param == oper.name or next_param in oper.alias) \
                        and (oper.parent_class is None or oper.parent_class == operator.__class__):
                    # Consume the matched parameter and descend into the
                    # operator's children; only the first match is used.
                    context.crt_command = next_param
                    context.crt_params = context.crt_params[1:]

                    async for ret in self._execute(context, oper.children, oper):
                        yield ret
                    return

        # The next parameter (if any) did not name a child of this node:
        # execute this node, or report an unknown command at the root.
        if operator is None:
            # Guard against an empty parameter list instead of raising IndexError.
            unknown = context.crt_params[0] if context.crt_params else ''
            yield entities.CommandReturn(
                error=errors.CommandNotFoundError(unknown)
            )
            return

        if operator.lowest_privilege > context.privilege:
            yield entities.CommandReturn(
                error=errors.CommandPrivilegeError(operator.name)
            )
            return

        async for ret in operator.execute(context):
            yield ret

    async def execute(
        self,
        command_text: str,
        query: core_entities.Query,
        session: core_entities.Session
    ) -> typing.AsyncGenerator[entities.CommandReturn, None]:
        """Parse ``command_text`` and run the matching command.

        The text is split on single spaces into parameters. The caller's
        privilege is 2 when ``"<launcher_type>_<launcher_id>"`` is listed in
        the system config's ``admin-sessions``, and 1 otherwise.

        Args:
            command_text: the command text to execute.
            query: the query that triggered the command.
            session: the session the query belongs to.

        Yields:
            ``CommandReturn`` objects produced while executing the command.
        """
        session_key = f'{query.launcher_type.value}_{query.launcher_id}'
        privilege = 2 if session_key in self.ap.system_cfg.data['admin-sessions'] else 1

        params = command_text.split(' ')

        ctx = entities.ExecuteContext(
            query=query,
            session=session,
            command_text=command_text,
            command='',
            crt_command='',
            params=params,
            # Separate list so consuming crt_params never aliases params.
            crt_params=list(params),
            privilege=privilege
        )

        async for ret in self._execute(ctx, self.cmd_list):
            yield ret