From 2a0cf573035ede31a592bea137dd4735592047b3 Mon Sep 17 00:00:00 2001 From: RockChinQ <1010553892@qq.com> Date: Sun, 28 Jan 2024 00:16:42 +0800 Subject: [PATCH] =?UTF-8?q?refactor:=20=E5=91=BD=E4=BB=A4=E5=A4=84?= =?UTF-8?q?=E7=90=86=E5=9F=BA=E7=A1=80?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pkg/command/__init__.py | 0 pkg/command/cmdmgr.py | 104 +++++++++++++++++++++++ pkg/command/entities.py | 43 ++++++++++ pkg/command/errors.py | 12 +++ pkg/command/operator.py | 71 ++++++++++++++++ pkg/command/operators/__init__.py | 0 pkg/command/operators/func.py | 23 +++++ pkg/core/app.py | 4 +- pkg/core/boot.py | 12 ++- pkg/pipeline/process/handlers/command.py | 39 ++++++--- 10 files changed, 286 insertions(+), 22 deletions(-) create mode 100644 pkg/command/__init__.py create mode 100644 pkg/command/cmdmgr.py create mode 100644 pkg/command/entities.py create mode 100644 pkg/command/errors.py create mode 100644 pkg/command/operator.py create mode 100644 pkg/command/operators/__init__.py create mode 100644 pkg/command/operators/func.py diff --git a/pkg/command/__init__.py b/pkg/command/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/pkg/command/cmdmgr.py b/pkg/command/cmdmgr.py new file mode 100644 index 0000000..ae183e9 --- /dev/null +++ b/pkg/command/cmdmgr.py @@ -0,0 +1,104 @@ +from __future__ import annotations + +import typing + +from ..core import app, entities as core_entities +from ..openai import entities as llm_entities +from ..openai.session import entities as session_entities +from . import entities, operator, errors + +from .operators import func + + +class CommandManager: + """命令管理器 + """ + + ap: app.Application + + cmd_list: list[operator.CommandOperator] + + def __init__(self, ap: app.Application): + self.ap = ap + + async def initialize(self): + # 实例化所有类 + self.cmd_list = [cls(self.ap) for cls in operator.preregistered_operators] + + # 设置所有类的子节点 + for cmd in self.cmd_list: + cmd.children = [child for child in self.cmd_list if child.parent_class == cmd.__class__] + + # 初始化所有类 + for cmd in self.cmd_list: + await cmd.initialize() + + async def _execute( + self, + context: entities.ExecuteContext, + operator_list: list[operator.CommandOperator], + operator: operator.CommandOperator = None + ) -> typing.AsyncGenerator[entities.CommandReturn, None]: + """执行命令 + """ + found = False + if len(context.crt_params) > 0: + for operator in operator_list: + if context.crt_params[0] == operator.name \ + or context.crt_params[0] in operator.alias: + found = True + context.crt_command = context.params[0] + context.crt_params = context.params[1:] + + async for ret in self._execute( + context, + operator.children, + operator + ): + yield ret + + if not found: + if operator is None: + yield entities.CommandReturn( + error=errors.CommandNotFoundError(context.crt_command) + ) + else: + if operator.lowest_privilege > context.privilege: + yield entities.CommandReturn( + error=errors.CommandPrivilegeError(context.crt_command) + ) + else: + async for ret in operator.execute(context): + yield ret + + + async def execute( + self, + command_text: str, + query: core_entities.Query, + session: session_entities.Session + ) -> typing.AsyncGenerator[entities.CommandReturn, None]: + """执行命令 + """ + + privilege = 1 + if query.sender_id == self.ap.cfg_mgr.data['admin_qq'] \ + or query.sender_id in self.ap.cfg_mgr['admin_qq']: + privilege = 2 + + ctx = entities.ExecuteContext( + query=query, + session=session, + command_text=command_text, + command='', + crt_command='', + params=command_text.split(' '), + crt_params=command_text.split(' '), + privilege=privilege + ) + + async for ret in self._execute( + ctx, + self.cmd_list + ): + yield ret diff --git a/pkg/command/entities.py b/pkg/command/entities.py new file mode 100644 index 0000000..7fba96e --- /dev/null +++ b/pkg/command/entities.py @@ -0,0 +1,43 @@ +from __future__ import annotations + +import typing + +import pydantic +import mirai + +from ..core import app, entities as core_entities +from ..openai.session import entities as session_entities +from . import errors, operator + + +class CommandReturn(pydantic.BaseModel): + + text: typing.Optional[str] + """文本 + """ + + image: typing.Optional[mirai.Image] + + error: typing.Optional[errors.CommandError]= None + + class Config: + arbitrary_types_allowed = True + + +class ExecuteContext(pydantic.BaseModel): + + query: core_entities.Query + + session: session_entities.Session + + command_text: str + + command: str + + crt_command: str + + params: list[str] + + crt_params: list[str] + + privilege: int diff --git a/pkg/command/errors.py b/pkg/command/errors.py new file mode 100644 index 0000000..42c5a8b --- /dev/null +++ b/pkg/command/errors.py @@ -0,0 +1,12 @@ + + +class CommandError(Exception): + pass + + +class CommandNotFoundError(CommandError): + pass + + +class CommandPrivilegeError(CommandError): + pass \ No newline at end of file diff --git a/pkg/command/operator.py b/pkg/command/operator.py new file mode 100644 index 0000000..319da55 --- /dev/null +++ b/pkg/command/operator.py @@ -0,0 +1,71 @@ +from __future__ import annotations + +import typing +import abc + +from ..core import app, entities as core_entities +from ..openai.session import entities as session_entities +from . import entities + + +preregistered_operators: list[typing.Type[CommandOperator]] = [] + + +def operator_class( + name: str, + alias: list[str], + help: str, + privilege: int=1, # 1为普通用户,2为管理员 + parent_class: typing.Type[CommandOperator] = None +) -> typing.Callable[[typing.Type[CommandOperator]], typing.Type[CommandOperator]]: + def decorator(cls: typing.Type[CommandOperator]) -> typing.Type[CommandOperator]: + cls.name = name + cls.alias = alias + cls.help = help + cls.parent_class = parent_class + + preregistered_operators.append(cls) + + return cls + + return decorator + + +class CommandOperator(metaclass=abc.ABCMeta): + """命令算子 + """ + + ap: app.Application + + name: str + """名称,搜索到时若符合则使用""" + + alias: list[str] + """同name""" + + help: str + """此节点的帮助信息""" + + parent_class: typing.Type[CommandOperator] + """父节点类。标记以供管理器在初始化时编织父子关系。""" + + lowest_privilege: int = 0 + """最低权限。若权限低于此值,则不予执行。""" + + children: list[CommandOperator] + """子节点。解析命令时,若节点有子节点,则以下一个参数去匹配子节点, + 若有匹配中的,转移到子节点中执行,若没有匹配中的或没有子节点,执行此节点。""" + + def __init__(self, ap: app.Application): + self.ap = ap + self.children = [] + + async def initialize(self): + pass + + @abc.abstractmethod + async def execute( + self, + context: entities.ExecuteContext + ) -> typing.AsyncGenerator[entities.CommandReturn, None]: + pass diff --git a/pkg/command/operators/__init__.py b/pkg/command/operators/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/pkg/command/operators/func.py b/pkg/command/operators/func.py new file mode 100644 index 0000000..c888831 --- /dev/null +++ b/pkg/command/operators/func.py @@ -0,0 +1,23 @@ +from __future__ import annotations +from typing import AsyncGenerator + +from .. import operator, entities, cmdmgr +from ...plugin import host as plugin_host + + +@operator.operator_class(name="func", alias=[], help="查看所有以注册的内容函数") +class FuncOperator(operator.CommandOperator): + async def execute( + self, + context: entities.ExecuteContext + ) -> AsyncGenerator[entities.CommandReturn, None]: + reply_str = "当前已加载的内容函数: \n\n" + + index = 1 + for func in plugin_host.__callable_functions__: + reply_str += "{}. {}{}:\n{}\n\n".format(index, ("(已禁用) " if not func['enabled'] else ""), func['name'], func['description']) + index += 1 + + yield entities.CommandReturn( + text=reply_str + ) \ No newline at end of file diff --git a/pkg/core/app.py b/pkg/core/app.py index c9d06e1..f069395 100644 --- a/pkg/core/app.py +++ b/pkg/core/app.py @@ -12,6 +12,7 @@ from ..openai.tools import toolmgr as llm_tool_mgr from ..config import manager as config_mgr from ..database import manager as database_mgr from ..utils.center import v2 as center_mgr +from ..command import cmdmgr from ..plugin import host as plugin_host from . import pool, controller from ..pipeline import stagemgr @@ -22,6 +23,8 @@ class Application: llm_mgr: openai_mgr.OpenAIInteract = None + cmd_mgr: cmdmgr.CommandManager = None + sess_mgr: llm_session_mgr.SessionManager = None model_mgr: llm_model_mgr.ModelManager = None @@ -54,7 +57,6 @@ class Application: # 把现有的所有内容函数加到toolmgr里 for func in plugin_host.__callable_functions__: - print(func) self.tool_mgr.register_legacy_function( name=func['name'], description=func['description'], diff --git a/pkg/core/boot.py b/pkg/core/boot.py index c06cc6c..2b03a15 100644 --- a/pkg/core/boot.py +++ b/pkg/core/boot.py @@ -19,9 +19,8 @@ from ..openai.session import sessionmgr as llm_session_mgr from ..openai.requester import modelmgr as llm_model_mgr from ..openai.sysprompt import sysprompt as llm_prompt_mgr from ..openai.tools import toolmgr as llm_tool_mgr -from ..openai import dprompt as llm_dprompt from ..qqbot import manager as im_mgr -from ..qqbot.cmds import aamgr as im_cmd_aamgr +from ..command import cmdmgr from ..plugin import host as plugin_host from ..utils.center import v2 as center_v2 from ..utils import updater @@ -81,11 +80,6 @@ async def make_app() -> app.Application: if cfg_mgr.data['admin_qq'] == 0: qcg_logger.warning("未设置管理员QQ号,将无法使用管理员命令,请在 config.py 中修改 admin_qq") - # TODO make it async - llm_dprompt.register_all() - im_cmd_aamgr.register_all() - im_cmd_aamgr.apply_privileges() - # 构建组建实例 ap = app.Application() ap.logger = qcg_logger @@ -116,6 +110,10 @@ async def make_app() -> app.Application: llm_mgr_inst = llm_mgr.OpenAIInteract(ap) ap.llm_mgr = llm_mgr_inst + cmd_mgr_inst = cmdmgr.CommandManager(ap) + await cmd_mgr_inst.initialize() + ap.cmd_mgr = cmd_mgr_inst + llm_model_mgr_inst = llm_model_mgr.ModelManager(ap) await llm_model_mgr_inst.initialize() ap.model_mgr = llm_model_mgr_inst diff --git a/pkg/pipeline/process/handlers/command.py b/pkg/pipeline/process/handlers/command.py index c5fecb6..cf3e074 100644 --- a/pkg/pipeline/process/handlers/command.py +++ b/pkg/pipeline/process/handlers/command.py @@ -16,20 +16,31 @@ class CommandHandler(handler.MessageHandler): ) -> typing.AsyncGenerator[entities.StageProcessResult, None]: """处理 """ - query.resp_message_chain = mirai.MessageChain([ - mirai.Plain('CommandHandler') - ]) + session = await self.ap.sess_mgr.get_session(query) - yield entities.StageProcessResult( - result_type=entities.ResultType.CONTINUE, - new_query=query - ) + command_text = str(query.message_chain).strip()[1:] - query.resp_message_chain = mirai.MessageChain([ - mirai.Plain('The Second Message') - ]) + async for ret in self.ap.cmd_mgr.execute( + command_text=command_text, + query=query, + session=session + ): + if ret.error is not None: + query.resp_message_chain = mirai.MessageChain([ + mirai.Plain(str(ret.error)) + ]) - yield entities.StageProcessResult( - result_type=entities.ResultType.CONTINUE, - new_query=query - ) \ No newline at end of file + yield entities.StageProcessResult( + result_type=entities.ResultType.CONTINUE, + new_query=query + ) + else: + if ret.text is not None: + query.resp_message_chain = mirai.MessageChain([ + mirai.Plain(ret.text) + ]) + + yield entities.StageProcessResult( + result_type=entities.ResultType.CONTINUE, + new_query=query + )