mirror of
https://github.com/RockChinQ/QChatGPT.git
synced 2024-11-16 03:32:33 +08:00
refactor: 命令处理基础
This commit is contained in:
parent
f10af09bd2
commit
2a0cf57303
0
pkg/command/__init__.py
Normal file
0
pkg/command/__init__.py
Normal file
104
pkg/command/cmdmgr.py
Normal file
104
pkg/command/cmdmgr.py
Normal file
|
@ -0,0 +1,104 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import typing
|
||||
|
||||
from ..core import app, entities as core_entities
|
||||
from ..openai import entities as llm_entities
|
||||
from ..openai.session import entities as session_entities
|
||||
from . import entities, operator, errors
|
||||
|
||||
from .operators import func
|
||||
|
||||
|
||||
class CommandManager:
|
||||
"""命令管理器
|
||||
"""
|
||||
|
||||
ap: app.Application
|
||||
|
||||
cmd_list: list[operator.CommandOperator]
|
||||
|
||||
def __init__(self, ap: app.Application):
|
||||
self.ap = ap
|
||||
|
||||
async def initialize(self):
|
||||
# 实例化所有类
|
||||
self.cmd_list = [cls(self.ap) for cls in operator.preregistered_operators]
|
||||
|
||||
# 设置所有类的子节点
|
||||
for cmd in self.cmd_list:
|
||||
cmd.children = [child for child in self.cmd_list if child.parent_class == cmd.__class__]
|
||||
|
||||
# 初始化所有类
|
||||
for cmd in self.cmd_list:
|
||||
await cmd.initialize()
|
||||
|
||||
async def _execute(
|
||||
self,
|
||||
context: entities.ExecuteContext,
|
||||
operator_list: list[operator.CommandOperator],
|
||||
operator: operator.CommandOperator = None
|
||||
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
||||
"""执行命令
|
||||
"""
|
||||
found = False
|
||||
if len(context.crt_params) > 0:
|
||||
for operator in operator_list:
|
||||
if context.crt_params[0] == operator.name \
|
||||
or context.crt_params[0] in operator.alias:
|
||||
found = True
|
||||
context.crt_command = context.params[0]
|
||||
context.crt_params = context.params[1:]
|
||||
|
||||
async for ret in self._execute(
|
||||
context,
|
||||
operator.children,
|
||||
operator
|
||||
):
|
||||
yield ret
|
||||
|
||||
if not found:
|
||||
if operator is None:
|
||||
yield entities.CommandReturn(
|
||||
error=errors.CommandNotFoundError(context.crt_command)
|
||||
)
|
||||
else:
|
||||
if operator.lowest_privilege > context.privilege:
|
||||
yield entities.CommandReturn(
|
||||
error=errors.CommandPrivilegeError(context.crt_command)
|
||||
)
|
||||
else:
|
||||
async for ret in operator.execute(context):
|
||||
yield ret
|
||||
|
||||
|
||||
async def execute(
|
||||
self,
|
||||
command_text: str,
|
||||
query: core_entities.Query,
|
||||
session: session_entities.Session
|
||||
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
||||
"""执行命令
|
||||
"""
|
||||
|
||||
privilege = 1
|
||||
if query.sender_id == self.ap.cfg_mgr.data['admin_qq'] \
|
||||
or query.sender_id in self.ap.cfg_mgr['admin_qq']:
|
||||
privilege = 2
|
||||
|
||||
ctx = entities.ExecuteContext(
|
||||
query=query,
|
||||
session=session,
|
||||
command_text=command_text,
|
||||
command='',
|
||||
crt_command='',
|
||||
params=command_text.split(' '),
|
||||
crt_params=command_text.split(' '),
|
||||
privilege=privilege
|
||||
)
|
||||
|
||||
async for ret in self._execute(
|
||||
ctx,
|
||||
self.cmd_list
|
||||
):
|
||||
yield ret
|
43
pkg/command/entities.py
Normal file
43
pkg/command/entities.py
Normal file
|
@ -0,0 +1,43 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import typing
|
||||
|
||||
import pydantic
|
||||
import mirai
|
||||
|
||||
from ..core import app, entities as core_entities
|
||||
from ..openai.session import entities as session_entities
|
||||
from . import errors, operator
|
||||
|
||||
|
||||
class CommandReturn(pydantic.BaseModel):
|
||||
|
||||
text: typing.Optional[str]
|
||||
"""文本
|
||||
"""
|
||||
|
||||
image: typing.Optional[mirai.Image]
|
||||
|
||||
error: typing.Optional[errors.CommandError]= None
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
|
||||
class ExecuteContext(pydantic.BaseModel):
|
||||
|
||||
query: core_entities.Query
|
||||
|
||||
session: session_entities.Session
|
||||
|
||||
command_text: str
|
||||
|
||||
command: str
|
||||
|
||||
crt_command: str
|
||||
|
||||
params: list[str]
|
||||
|
||||
crt_params: list[str]
|
||||
|
||||
privilege: int
|
12
pkg/command/errors.py
Normal file
12
pkg/command/errors.py
Normal file
|
@ -0,0 +1,12 @@
|
|||
|
||||
|
||||
class CommandError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class CommandNotFoundError(CommandError):
|
||||
pass
|
||||
|
||||
|
||||
class CommandPrivilegeError(CommandError):
|
||||
pass
|
71
pkg/command/operator.py
Normal file
71
pkg/command/operator.py
Normal file
|
@ -0,0 +1,71 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import typing
|
||||
import abc
|
||||
|
||||
from ..core import app, entities as core_entities
|
||||
from ..openai.session import entities as session_entities
|
||||
from . import entities
|
||||
|
||||
|
||||
preregistered_operators: list[typing.Type[CommandOperator]] = []
|
||||
|
||||
|
||||
def operator_class(
|
||||
name: str,
|
||||
alias: list[str],
|
||||
help: str,
|
||||
privilege: int=1, # 1为普通用户,2为管理员
|
||||
parent_class: typing.Type[CommandOperator] = None
|
||||
) -> typing.Callable[[typing.Type[CommandOperator]], typing.Type[CommandOperator]]:
|
||||
def decorator(cls: typing.Type[CommandOperator]) -> typing.Type[CommandOperator]:
|
||||
cls.name = name
|
||||
cls.alias = alias
|
||||
cls.help = help
|
||||
cls.parent_class = parent_class
|
||||
|
||||
preregistered_operators.append(cls)
|
||||
|
||||
return cls
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
class CommandOperator(metaclass=abc.ABCMeta):
|
||||
"""命令算子
|
||||
"""
|
||||
|
||||
ap: app.Application
|
||||
|
||||
name: str
|
||||
"""名称,搜索到时若符合则使用"""
|
||||
|
||||
alias: list[str]
|
||||
"""同name"""
|
||||
|
||||
help: str
|
||||
"""此节点的帮助信息"""
|
||||
|
||||
parent_class: typing.Type[CommandOperator]
|
||||
"""父节点类。标记以供管理器在初始化时编织父子关系。"""
|
||||
|
||||
lowest_privilege: int = 0
|
||||
"""最低权限。若权限低于此值,则不予执行。"""
|
||||
|
||||
children: list[CommandOperator]
|
||||
"""子节点。解析命令时,若节点有子节点,则以下一个参数去匹配子节点,
|
||||
若有匹配中的,转移到子节点中执行,若没有匹配中的或没有子节点,执行此节点。"""
|
||||
|
||||
def __init__(self, ap: app.Application):
|
||||
self.ap = ap
|
||||
self.children = []
|
||||
|
||||
async def initialize(self):
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
async def execute(
|
||||
self,
|
||||
context: entities.ExecuteContext
|
||||
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
||||
pass
|
0
pkg/command/operators/__init__.py
Normal file
0
pkg/command/operators/__init__.py
Normal file
23
pkg/command/operators/func.py
Normal file
23
pkg/command/operators/func.py
Normal file
|
@ -0,0 +1,23 @@
|
|||
from __future__ import annotations
|
||||
from typing import AsyncGenerator
|
||||
|
||||
from .. import operator, entities, cmdmgr
|
||||
from ...plugin import host as plugin_host
|
||||
|
||||
|
||||
@operator.operator_class(name="func", alias=[], help="查看所有以注册的内容函数")
|
||||
class FuncOperator(operator.CommandOperator):
|
||||
async def execute(
|
||||
self,
|
||||
context: entities.ExecuteContext
|
||||
) -> AsyncGenerator[entities.CommandReturn, None]:
|
||||
reply_str = "当前已加载的内容函数: \n\n"
|
||||
|
||||
index = 1
|
||||
for func in plugin_host.__callable_functions__:
|
||||
reply_str += "{}. {}{}:\n{}\n\n".format(index, ("(已禁用) " if not func['enabled'] else ""), func['name'], func['description'])
|
||||
index += 1
|
||||
|
||||
yield entities.CommandReturn(
|
||||
text=reply_str
|
||||
)
|
|
@ -12,6 +12,7 @@ from ..openai.tools import toolmgr as llm_tool_mgr
|
|||
from ..config import manager as config_mgr
|
||||
from ..database import manager as database_mgr
|
||||
from ..utils.center import v2 as center_mgr
|
||||
from ..command import cmdmgr
|
||||
from ..plugin import host as plugin_host
|
||||
from . import pool, controller
|
||||
from ..pipeline import stagemgr
|
||||
|
@ -22,6 +23,8 @@ class Application:
|
|||
|
||||
llm_mgr: openai_mgr.OpenAIInteract = None
|
||||
|
||||
cmd_mgr: cmdmgr.CommandManager = None
|
||||
|
||||
sess_mgr: llm_session_mgr.SessionManager = None
|
||||
|
||||
model_mgr: llm_model_mgr.ModelManager = None
|
||||
|
@ -54,7 +57,6 @@ class Application:
|
|||
|
||||
# 把现有的所有内容函数加到toolmgr里
|
||||
for func in plugin_host.__callable_functions__:
|
||||
print(func)
|
||||
self.tool_mgr.register_legacy_function(
|
||||
name=func['name'],
|
||||
description=func['description'],
|
||||
|
|
|
@ -19,9 +19,8 @@ from ..openai.session import sessionmgr as llm_session_mgr
|
|||
from ..openai.requester import modelmgr as llm_model_mgr
|
||||
from ..openai.sysprompt import sysprompt as llm_prompt_mgr
|
||||
from ..openai.tools import toolmgr as llm_tool_mgr
|
||||
from ..openai import dprompt as llm_dprompt
|
||||
from ..qqbot import manager as im_mgr
|
||||
from ..qqbot.cmds import aamgr as im_cmd_aamgr
|
||||
from ..command import cmdmgr
|
||||
from ..plugin import host as plugin_host
|
||||
from ..utils.center import v2 as center_v2
|
||||
from ..utils import updater
|
||||
|
@ -81,11 +80,6 @@ async def make_app() -> app.Application:
|
|||
if cfg_mgr.data['admin_qq'] == 0:
|
||||
qcg_logger.warning("未设置管理员QQ号,将无法使用管理员命令,请在 config.py 中修改 admin_qq")
|
||||
|
||||
# TODO make it async
|
||||
llm_dprompt.register_all()
|
||||
im_cmd_aamgr.register_all()
|
||||
im_cmd_aamgr.apply_privileges()
|
||||
|
||||
# 构建组建实例
|
||||
ap = app.Application()
|
||||
ap.logger = qcg_logger
|
||||
|
@ -116,6 +110,10 @@ async def make_app() -> app.Application:
|
|||
llm_mgr_inst = llm_mgr.OpenAIInteract(ap)
|
||||
ap.llm_mgr = llm_mgr_inst
|
||||
|
||||
cmd_mgr_inst = cmdmgr.CommandManager(ap)
|
||||
await cmd_mgr_inst.initialize()
|
||||
ap.cmd_mgr = cmd_mgr_inst
|
||||
|
||||
llm_model_mgr_inst = llm_model_mgr.ModelManager(ap)
|
||||
await llm_model_mgr_inst.initialize()
|
||||
ap.model_mgr = llm_model_mgr_inst
|
||||
|
|
|
@ -16,20 +16,31 @@ class CommandHandler(handler.MessageHandler):
|
|||
) -> typing.AsyncGenerator[entities.StageProcessResult, None]:
|
||||
"""处理
|
||||
"""
|
||||
query.resp_message_chain = mirai.MessageChain([
|
||||
mirai.Plain('CommandHandler')
|
||||
])
|
||||
session = await self.ap.sess_mgr.get_session(query)
|
||||
|
||||
yield entities.StageProcessResult(
|
||||
result_type=entities.ResultType.CONTINUE,
|
||||
new_query=query
|
||||
)
|
||||
command_text = str(query.message_chain).strip()[1:]
|
||||
|
||||
query.resp_message_chain = mirai.MessageChain([
|
||||
mirai.Plain('The Second Message')
|
||||
])
|
||||
async for ret in self.ap.cmd_mgr.execute(
|
||||
command_text=command_text,
|
||||
query=query,
|
||||
session=session
|
||||
):
|
||||
if ret.error is not None:
|
||||
query.resp_message_chain = mirai.MessageChain([
|
||||
mirai.Plain(str(ret.error))
|
||||
])
|
||||
|
||||
yield entities.StageProcessResult(
|
||||
result_type=entities.ResultType.CONTINUE,
|
||||
new_query=query
|
||||
)
|
||||
yield entities.StageProcessResult(
|
||||
result_type=entities.ResultType.CONTINUE,
|
||||
new_query=query
|
||||
)
|
||||
else:
|
||||
if ret.text is not None:
|
||||
query.resp_message_chain = mirai.MessageChain([
|
||||
mirai.Plain(ret.text)
|
||||
])
|
||||
|
||||
yield entities.StageProcessResult(
|
||||
result_type=entities.ResultType.CONTINUE,
|
||||
new_query=query
|
||||
)
|
||||
|
|
Loading…
Reference in New Issue
Block a user