refactor: 命令处理基础

This commit is contained in:
RockChinQ 2024-01-28 00:16:42 +08:00
parent f10af09bd2
commit 2a0cf57303
10 changed files with 286 additions and 22 deletions

0
pkg/command/__init__.py Normal file
View File

104
pkg/command/cmdmgr.py Normal file
View File

@ -0,0 +1,104 @@
from __future__ import annotations
import typing
from ..core import app, entities as core_entities
from ..openai import entities as llm_entities
from ..openai.session import entities as session_entities
from . import entities, operator, errors
from .operators import func
class CommandManager:
"""命令管理器
"""
ap: app.Application
cmd_list: list[operator.CommandOperator]
def __init__(self, ap: app.Application):
self.ap = ap
async def initialize(self):
# 实例化所有类
self.cmd_list = [cls(self.ap) for cls in operator.preregistered_operators]
# 设置所有类的子节点
for cmd in self.cmd_list:
cmd.children = [child for child in self.cmd_list if child.parent_class == cmd.__class__]
# 初始化所有类
for cmd in self.cmd_list:
await cmd.initialize()
async def _execute(
self,
context: entities.ExecuteContext,
operator_list: list[operator.CommandOperator],
operator: operator.CommandOperator = None
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
"""执行命令
"""
found = False
if len(context.crt_params) > 0:
for operator in operator_list:
if context.crt_params[0] == operator.name \
or context.crt_params[0] in operator.alias:
found = True
context.crt_command = context.params[0]
context.crt_params = context.params[1:]
async for ret in self._execute(
context,
operator.children,
operator
):
yield ret
if not found:
if operator is None:
yield entities.CommandReturn(
error=errors.CommandNotFoundError(context.crt_command)
)
else:
if operator.lowest_privilege > context.privilege:
yield entities.CommandReturn(
error=errors.CommandPrivilegeError(context.crt_command)
)
else:
async for ret in operator.execute(context):
yield ret
async def execute(
self,
command_text: str,
query: core_entities.Query,
session: session_entities.Session
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
"""执行命令
"""
privilege = 1
if query.sender_id == self.ap.cfg_mgr.data['admin_qq'] \
or query.sender_id in self.ap.cfg_mgr['admin_qq']:
privilege = 2
ctx = entities.ExecuteContext(
query=query,
session=session,
command_text=command_text,
command='',
crt_command='',
params=command_text.split(' '),
crt_params=command_text.split(' '),
privilege=privilege
)
async for ret in self._execute(
ctx,
self.cmd_list
):
yield ret

43
pkg/command/entities.py Normal file
View File

@ -0,0 +1,43 @@
from __future__ import annotations
import typing
import pydantic
import mirai
from ..core import app, entities as core_entities
from ..openai.session import entities as session_entities
from . import errors, operator
class CommandReturn(pydantic.BaseModel):
text: typing.Optional[str]
"""文本
"""
image: typing.Optional[mirai.Image]
error: typing.Optional[errors.CommandError]= None
class Config:
arbitrary_types_allowed = True
class ExecuteContext(pydantic.BaseModel):
query: core_entities.Query
session: session_entities.Session
command_text: str
command: str
crt_command: str
params: list[str]
crt_params: list[str]
privilege: int

12
pkg/command/errors.py Normal file
View File

@ -0,0 +1,12 @@
class CommandError(Exception):
pass
class CommandNotFoundError(CommandError):
pass
class CommandPrivilegeError(CommandError):
pass

71
pkg/command/operator.py Normal file
View File

@ -0,0 +1,71 @@
from __future__ import annotations
import typing
import abc
from ..core import app, entities as core_entities
from ..openai.session import entities as session_entities
from . import entities
preregistered_operators: list[typing.Type[CommandOperator]] = []
def operator_class(
name: str,
alias: list[str],
help: str,
privilege: int=1, # 1为普通用户2为管理员
parent_class: typing.Type[CommandOperator] = None
) -> typing.Callable[[typing.Type[CommandOperator]], typing.Type[CommandOperator]]:
def decorator(cls: typing.Type[CommandOperator]) -> typing.Type[CommandOperator]:
cls.name = name
cls.alias = alias
cls.help = help
cls.parent_class = parent_class
preregistered_operators.append(cls)
return cls
return decorator
class CommandOperator(metaclass=abc.ABCMeta):
"""命令算子
"""
ap: app.Application
name: str
"""名称,搜索到时若符合则使用"""
alias: list[str]
"""同name"""
help: str
"""此节点的帮助信息"""
parent_class: typing.Type[CommandOperator]
"""父节点类。标记以供管理器在初始化时编织父子关系。"""
lowest_privilege: int = 0
"""最低权限。若权限低于此值,则不予执行。"""
children: list[CommandOperator]
"""子节点。解析命令时,若节点有子节点,则以下一个参数去匹配子节点,
若有匹配中的转移到子节点中执行若没有匹配中的或没有子节点执行此节点"""
def __init__(self, ap: app.Application):
self.ap = ap
self.children = []
async def initialize(self):
pass
@abc.abstractmethod
async def execute(
self,
context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
pass

View File

View File

@ -0,0 +1,23 @@
from __future__ import annotations
from typing import AsyncGenerator
from .. import operator, entities, cmdmgr
from ...plugin import host as plugin_host
@operator.operator_class(name="func", alias=[], help="查看所有以注册的内容函数")
class FuncOperator(operator.CommandOperator):
async def execute(
self,
context: entities.ExecuteContext
) -> AsyncGenerator[entities.CommandReturn, None]:
reply_str = "当前已加载的内容函数: \n\n"
index = 1
for func in plugin_host.__callable_functions__:
reply_str += "{}. {}{}:\n{}\n\n".format(index, ("(已禁用) " if not func['enabled'] else ""), func['name'], func['description'])
index += 1
yield entities.CommandReturn(
text=reply_str
)

View File

@ -12,6 +12,7 @@ from ..openai.tools import toolmgr as llm_tool_mgr
from ..config import manager as config_mgr
from ..database import manager as database_mgr
from ..utils.center import v2 as center_mgr
from ..command import cmdmgr
from ..plugin import host as plugin_host
from . import pool, controller
from ..pipeline import stagemgr
@ -22,6 +23,8 @@ class Application:
llm_mgr: openai_mgr.OpenAIInteract = None
cmd_mgr: cmdmgr.CommandManager = None
sess_mgr: llm_session_mgr.SessionManager = None
model_mgr: llm_model_mgr.ModelManager = None
@ -54,7 +57,6 @@ class Application:
# 把现有的所有内容函数加到toolmgr里
for func in plugin_host.__callable_functions__:
print(func)
self.tool_mgr.register_legacy_function(
name=func['name'],
description=func['description'],

View File

@ -19,9 +19,8 @@ from ..openai.session import sessionmgr as llm_session_mgr
from ..openai.requester import modelmgr as llm_model_mgr
from ..openai.sysprompt import sysprompt as llm_prompt_mgr
from ..openai.tools import toolmgr as llm_tool_mgr
from ..openai import dprompt as llm_dprompt
from ..qqbot import manager as im_mgr
from ..qqbot.cmds import aamgr as im_cmd_aamgr
from ..command import cmdmgr
from ..plugin import host as plugin_host
from ..utils.center import v2 as center_v2
from ..utils import updater
@ -81,11 +80,6 @@ async def make_app() -> app.Application:
if cfg_mgr.data['admin_qq'] == 0:
qcg_logger.warning("未设置管理员QQ号将无法使用管理员命令请在 config.py 中修改 admin_qq")
# TODO make it async
llm_dprompt.register_all()
im_cmd_aamgr.register_all()
im_cmd_aamgr.apply_privileges()
# 构建组建实例
ap = app.Application()
ap.logger = qcg_logger
@ -116,6 +110,10 @@ async def make_app() -> app.Application:
llm_mgr_inst = llm_mgr.OpenAIInteract(ap)
ap.llm_mgr = llm_mgr_inst
cmd_mgr_inst = cmdmgr.CommandManager(ap)
await cmd_mgr_inst.initialize()
ap.cmd_mgr = cmd_mgr_inst
llm_model_mgr_inst = llm_model_mgr.ModelManager(ap)
await llm_model_mgr_inst.initialize()
ap.model_mgr = llm_model_mgr_inst

View File

@ -16,20 +16,31 @@ class CommandHandler(handler.MessageHandler):
) -> typing.AsyncGenerator[entities.StageProcessResult, None]:
"""处理
"""
query.resp_message_chain = mirai.MessageChain([
mirai.Plain('CommandHandler')
])
session = await self.ap.sess_mgr.get_session(query)
yield entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE,
new_query=query
)
command_text = str(query.message_chain).strip()[1:]
query.resp_message_chain = mirai.MessageChain([
mirai.Plain('The Second Message')
])
async for ret in self.ap.cmd_mgr.execute(
command_text=command_text,
query=query,
session=session
):
if ret.error is not None:
query.resp_message_chain = mirai.MessageChain([
mirai.Plain(str(ret.error))
])
yield entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE,
new_query=query
)
yield entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE,
new_query=query
)
else:
if ret.text is not None:
query.resp_message_chain = mirai.MessageChain([
mirai.Plain(ret.text)
])
yield entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE,
new_query=query
)