From 1368ee22b2c15ab5bcae98efca1e7c8104582a15 Mon Sep 17 00:00:00 2001 From: RockChinQ <1010553892@qq.com> Date: Sun, 28 Jan 2024 18:21:43 +0800 Subject: [PATCH] =?UTF-8?q?refactor:=20=E5=91=BD=E4=BB=A4=E5=9F=BA?= =?UTF-8?q?=E6=9C=AC=E5=AE=8C=E6=88=90?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pkg/command/cmdmgr.py | 24 ++- pkg/command/errors.py | 27 ++- pkg/command/operator.py | 8 +- pkg/command/operators/cfg.py | 98 ++++++++++ pkg/command/operators/cmd.py | 50 +++++ pkg/command/operators/default.py | 62 ++++++ pkg/command/operators/delc.py | 62 ++++++ pkg/command/operators/func.py | 18 +- pkg/command/operators/help.py | 23 +++ pkg/command/operators/last.py | 36 ++++ pkg/command/operators/list.py | 56 ++++++ pkg/command/operators/next.py | 35 ++++ pkg/command/operators/plugin.py | 239 +++++++++++++++++++++++ pkg/command/operators/prompt.py | 29 +++ pkg/command/operators/resend.py | 34 ++++ pkg/command/operators/reset.py | 23 +++ pkg/command/operators/version.py | 28 +++ pkg/openai/requester/apis/chatcmpl.py | 2 +- pkg/openai/sysprompt/sysprompt.py | 9 +- pkg/pipeline/process/handlers/chat.py | 8 + pkg/pipeline/process/handlers/command.py | 22 ++- 21 files changed, 859 insertions(+), 34 deletions(-) create mode 100644 pkg/command/operators/cfg.py create mode 100644 pkg/command/operators/cmd.py create mode 100644 pkg/command/operators/default.py create mode 100644 pkg/command/operators/delc.py create mode 100644 pkg/command/operators/help.py create mode 100644 pkg/command/operators/last.py create mode 100644 pkg/command/operators/list.py create mode 100644 pkg/command/operators/next.py create mode 100644 pkg/command/operators/plugin.py create mode 100644 pkg/command/operators/prompt.py create mode 100644 pkg/command/operators/resend.py create mode 100644 pkg/command/operators/reset.py create mode 100644 pkg/command/operators/version.py diff --git a/pkg/command/cmdmgr.py b/pkg/command/cmdmgr.py index ae183e9..cff5969 100644 --- a/pkg/command/cmdmgr.py +++ b/pkg/command/cmdmgr.py @@ -7,7 +7,7 @@ from ..openai import entities as llm_entities from ..openai.session import entities as session_entities from . import entities, operator, errors -from .operators import func +from .operators import func, plugin, default, reset, list as list_cmd, last, next, delc, resend, prompt, cfg, cmd, help, version class CommandManager: @@ -41,31 +41,35 @@ class CommandManager: ) -> typing.AsyncGenerator[entities.CommandReturn, None]: """执行命令 """ + found = False if len(context.crt_params) > 0: - for operator in operator_list: - if context.crt_params[0] == operator.name \ - or context.crt_params[0] in operator.alias: + for oper in operator_list: + if (context.crt_params[0] == oper.name \ + or context.crt_params[0] in oper.alias) \ + and (oper.parent_class is None or oper.parent_class == operator.__class__): found = True - context.crt_command = context.params[0] - context.crt_params = context.params[1:] + + context.crt_command = context.crt_params[0] + context.crt_params = context.crt_params[1:] async for ret in self._execute( context, - operator.children, - operator + oper.children, + oper ): yield ret + break if not found: if operator is None: yield entities.CommandReturn( - error=errors.CommandNotFoundError(context.crt_command) + error=errors.CommandNotFoundError(context.crt_params[0]) ) else: if operator.lowest_privilege > context.privilege: yield entities.CommandReturn( - error=errors.CommandPrivilegeError(context.crt_command) + error=errors.CommandPrivilegeError(operator.name) ) else: async for ret in operator.execute(context): diff --git a/pkg/command/errors.py b/pkg/command/errors.py index 42c5a8b..5bc253f 100644 --- a/pkg/command/errors.py +++ b/pkg/command/errors.py @@ -1,12 +1,33 @@ class CommandError(Exception): - pass + + def __init__(self, message: str = None): + self.message = message + + def __str__(self): + return self.message class CommandNotFoundError(CommandError): - pass + + def __init__(self, message: str = None): + super().__init__("未知命令: "+message) class CommandPrivilegeError(CommandError): - pass \ No newline at end of file + + def __init__(self, message: str = None): + super().__init__("权限不足: "+message) + + +class ParamNotEnoughError(CommandError): + + def __init__(self, message: str = None): + super().__init__("参数不足: "+message) + + +class CommandOperationError(CommandError): + + def __init__(self, message: str = None): + super().__init__("操作失败: "+message) diff --git a/pkg/command/operator.py b/pkg/command/operator.py index 319da55..af1a5d6 100644 --- a/pkg/command/operator.py +++ b/pkg/command/operator.py @@ -13,8 +13,9 @@ preregistered_operators: list[typing.Type[CommandOperator]] = [] def operator_class( name: str, - alias: list[str], help: str, + usage: str = None, + alias: list[str] = [], privilege: int=1, # 1为普通用户,2为管理员 parent_class: typing.Type[CommandOperator] = None ) -> typing.Callable[[typing.Type[CommandOperator]], typing.Type[CommandOperator]]: @@ -22,6 +23,7 @@ def operator_class( cls.name = name cls.alias = alias cls.help = help + cls.usage = usage cls.parent_class = parent_class preregistered_operators.append(cls) @@ -46,7 +48,9 @@ class CommandOperator(metaclass=abc.ABCMeta): help: str """此节点的帮助信息""" - parent_class: typing.Type[CommandOperator] + usage: str = None + + parent_class: typing.Type[CommandOperator] | None = None """父节点类。标记以供管理器在初始化时编织父子关系。""" lowest_privilege: int = 0 diff --git a/pkg/command/operators/cfg.py b/pkg/command/operators/cfg.py new file mode 100644 index 0000000..b67ff3e --- /dev/null +++ b/pkg/command/operators/cfg.py @@ -0,0 +1,98 @@ +from __future__ import annotations + +import typing +import json + +from .. import operator, entities, cmdmgr, errors + + +@operator.operator_class( + name="cfg", + help="配置项管理", + usage='!cfg <配置项> [配置值]\n!cfg all', + privilege=2 +) +class CfgOperator(operator.CommandOperator): + + async def execute( + self, + context: entities.ExecuteContext + ) -> typing.AsyncGenerator[entities.CommandReturn, None]: + """执行 + """ + reply = '' + + params = context.crt_params + cfg_mgr = self.ap.cfg_mgr + + false = False + true = True + + reply_str = "" + if len(params) == 0: + yield entities.CommandReturn(error=errors.ParamNotEnoughError('请提供配置项名称')) + else: + cfg_name = params[0] + if cfg_name == 'all': + reply_str = "[bot]所有配置项:\n\n" + for cfg in cfg_mgr.data.keys(): + if not cfg.startswith('__') and not cfg == 'logging': + # 根据配置项类型进行格式化,如果是字典则转换为json并格式化 + if isinstance(cfg_mgr.data[cfg], str): + reply_str += "{}: \"{}\"\n".format(cfg, cfg_mgr.data[cfg]) + elif isinstance(cfg_mgr.data[cfg], dict): + # 不进行unicode转义,并格式化 + reply_str += "{}: {}\n".format(cfg, + json.dumps(cfg_mgr.data[cfg], + ensure_ascii=False, indent=4)) + else: + reply_str += "{}: {}\n".format(cfg, cfg_mgr.data[cfg]) + yield entities.CommandReturn(text=reply_str) + else: + cfg_entry_path = cfg_name.split('.') + + try: + if len(params) == 1: # 未指定配置值,返回配置项值 + cfg_entry = cfg_mgr.data[cfg_entry_path[0]] + if len(cfg_entry_path) > 1: + for i in range(1, len(cfg_entry_path)): + cfg_entry = cfg_entry[cfg_entry_path[i]] + + if isinstance(cfg_entry, str): + reply_str = "[bot]配置项{}: \"{}\"\n".format(cfg_name, cfg_entry) + elif isinstance(cfg_entry, dict): + reply_str = "[bot]配置项{}: {}\n".format(cfg_name, + json.dumps(cfg_entry, + ensure_ascii=False, indent=4)) + else: + reply_str = "[bot]配置项{}: {}\n".format(cfg_name, cfg_entry) + + yield entities.CommandReturn(text=reply_str) + else: + cfg_value = " ".join(params[1:]) + + cfg_value = eval(cfg_value) + + cfg_entry = cfg_mgr.data[cfg_entry_path[0]] + if len(cfg_entry_path) > 1: + for i in range(1, len(cfg_entry_path) - 1): + cfg_entry = cfg_entry[cfg_entry_path[i]] + if isinstance(cfg_entry[cfg_entry_path[-1]], type(cfg_value)): + cfg_entry[cfg_entry_path[-1]] = cfg_value + yield entities.CommandReturn(text="配置项{}修改成功".format(cfg_name)) + else: + # reply = ["[bot]err:配置项{}类型不匹配".format(cfg_name)] + yield entities.CommandReturn(error=errors.CommandOperationError("配置项{}类型不匹配".format(cfg_name))) + else: + cfg_mgr.data[cfg_entry_path[0]] = cfg_value + # reply = ["[bot]配置项{}修改成功".format(cfg_name)] + yield entities.CommandReturn(text="配置项{}修改成功".format(cfg_name)) + except KeyError: + # reply = ["[bot]err:未找到配置项 {}".format(cfg_name)] + yield entities.CommandReturn(error=errors.CommandOperationError("未找到配置项 {}".format(cfg_name))) + except NameError: + # reply = ["[bot]err:值{}不合法(字符串需要使用双引号包裹)".format(cfg_value)] + yield entities.CommandReturn(error=errors.CommandOperationError("值{}不合法(字符串需要使用双引号包裹)".format(cfg_value))) + except ValueError: + # reply = ["[bot]err:未找到配置项 {}".format(cfg_name)] + yield entities.CommandReturn(error=errors.CommandOperationError("未找到配置项 {}".format(cfg_name))) diff --git a/pkg/command/operators/cmd.py b/pkg/command/operators/cmd.py new file mode 100644 index 0000000..17b5ed0 --- /dev/null +++ b/pkg/command/operators/cmd.py @@ -0,0 +1,50 @@ +from __future__ import annotations + +import typing + +from .. import operator, entities, cmdmgr, errors + + +@operator.operator_class( + name="cmd", + help='显示命令列表', + usage='!cmd\n!cmd <命令名称>' +) +class CmdOperator(operator.CommandOperator): + """命令列表 + """ + + async def execute( + self, + context: entities.ExecuteContext + ) -> typing.AsyncGenerator[entities.CommandReturn, None]: + """执行 + """ + if len(context.crt_params) == 0: + reply_str = "当前所有命令: \n\n" + + for cmd in self.ap.cmd_mgr.cmd_list: + if cmd.parent_class is None: + reply_str += f"{cmd.name}: {cmd.help}\n" + + reply_str += "\n使用 !cmd <命令名称> 查看命令的详细帮助" + + yield entities.CommandReturn(text=reply_str.strip()) + + else: + cmd_name = context.crt_params[0] + + cmd = None + + for _cmd in self.ap.cmd_mgr.cmd_list: + if (cmd_name == _cmd.name or cmd_name in _cmd.alias) and (_cmd.parent_class is None): + cmd = _cmd + break + + if cmd is None: + yield entities.CommandReturn(error=errors.CommandNotFoundError(cmd_name)) + else: + reply_str = f"{cmd.name}: {cmd.help}\n\n" + reply_str += f"使用方法: \n{cmd.usage}" + + yield entities.CommandReturn(text=reply_str.strip()) diff --git a/pkg/command/operators/default.py b/pkg/command/operators/default.py new file mode 100644 index 0000000..ca7e404 --- /dev/null +++ b/pkg/command/operators/default.py @@ -0,0 +1,62 @@ +from __future__ import annotations + +import typing +import traceback + +from .. import operator, entities, cmdmgr, errors + + +@operator.operator_class( + name="default", + help="操作情景预设", + usage='!default\n!default set <指定情景预设为默认>' +) +class DefaultOperator(operator.CommandOperator): + + async def execute( + self, + context: entities.ExecuteContext + ) -> typing.AsyncGenerator[entities.CommandReturn, None]: + + reply_str = "当前所有情景预设: \n\n" + + for prompt in self.ap.prompt_mgr.get_all_prompts(): + + content = "" + for msg in prompt.messages: + content += f" {msg.role}: {msg.content}" + + reply_str += f"名称: {prompt.name}\n内容: \n{content}\n\n" + + reply_str += f"当前会话使用的是: {context.session.use_prompt_name}" + + yield entities.CommandReturn(text=reply_str.strip()) + + +@operator.operator_class( + name="set", + help="设置当前会话默认情景预设", + parent_class=DefaultOperator +) +class DefaultSetOperator(operator.CommandOperator): + + async def execute( + self, + context: entities.ExecuteContext + ) -> typing.AsyncGenerator[entities.CommandReturn, None]: + + if len(context.crt_params) == 0: + yield entities.CommandReturn(error=errors.ParamNotEnoughError('请提供情景预设名称')) + else: + prompt_name = context.crt_params[0] + + try: + prompt = await self.ap.prompt_mgr.get_prompt_by_prefix(prompt_name) + if prompt is None: + yield entities.CommandReturn(error=errors.CommandError("设置当前会话默认情景预设失败: 未找到情景预设 {}".format(prompt_name))) + else: + context.session.use_prompt_name = prompt.name + yield entities.CommandReturn(text=f"已设置当前会话默认情景预设为 {prompt_name}, !reset 后生效") + except Exception as e: + traceback.print_exc() + yield entities.CommandReturn(error=errors.CommandError("设置当前会话默认情景预设失败: "+str(e))) diff --git a/pkg/command/operators/delc.py b/pkg/command/operators/delc.py new file mode 100644 index 0000000..db865ff --- /dev/null +++ b/pkg/command/operators/delc.py @@ -0,0 +1,62 @@ +from __future__ import annotations + +import typing +import datetime + +from .. import operator, entities, cmdmgr, errors + + +@operator.operator_class( + name="del", + help="删除当前会话的历史记录", + usage='!del <序号>\n!del all' +) +class DelOperator(operator.CommandOperator): + + async def execute( + self, + context: entities.ExecuteContext + ) -> typing.AsyncGenerator[entities.CommandReturn, None]: + + if context.session.conversations: + delete_index = 0 + if len(context.crt_params) > 0: + try: + delete_index = int(context.crt_params[0]) + except: + yield entities.CommandReturn(error=errors.CommandOperationError('索引必须是整数')) + return + + if delete_index < 0 or delete_index >= len(context.session.conversations): + yield entities.CommandReturn(error=errors.CommandOperationError('索引超出范围')) + return + + # 倒序 + to_delete_index = len(context.session.conversations)-1-delete_index + + if context.session.conversations[to_delete_index] == context.session.using_conversation: + context.session.using_conversation = None + + del context.session.conversations[to_delete_index] + + yield entities.CommandReturn(text=f"已删除对话: {delete_index}") + else: + yield entities.CommandReturn(error=errors.CommandOperationError('当前没有对话')) + + +@operator.operator_class( + name="all", + help="删除此会话的所有历史记录", + parent_class=DelOperator +) +class DelAllOperator(operator.CommandOperator): + + async def execute( + self, + context: entities.ExecuteContext + ) -> typing.AsyncGenerator[entities.CommandReturn, None]: + + context.session.conversations = [] + context.session.using_conversation = None + + yield entities.CommandReturn(text="已删除所有对话") \ No newline at end of file diff --git a/pkg/command/operators/func.py b/pkg/command/operators/func.py index c888831..a4e81c3 100644 --- a/pkg/command/operators/func.py +++ b/pkg/command/operators/func.py @@ -5,19 +5,21 @@ from .. import operator, entities, cmdmgr from ...plugin import host as plugin_host -@operator.operator_class(name="func", alias=[], help="查看所有以注册的内容函数") +@operator.operator_class(name="func", help="查看所有已注册的内容函数", usage='!func') class FuncOperator(operator.CommandOperator): async def execute( - self, - context: entities.ExecuteContext + self, context: entities.ExecuteContext ) -> AsyncGenerator[entities.CommandReturn, None]: reply_str = "当前已加载的内容函数: \n\n" index = 1 - for func in plugin_host.__callable_functions__: - reply_str += "{}. {}{}:\n{}\n\n".format(index, ("(已禁用) " if not func['enabled'] else ""), func['name'], func['description']) + for func in self.ap.tool_mgr.all_functions: + reply_str += "{}. {}{}:\n{}\n\n".format( + index, + ("(已禁用) " if not func.enable else ""), + func.name, + func.description, + ) index += 1 - yield entities.CommandReturn( - text=reply_str - ) \ No newline at end of file + yield entities.CommandReturn(text=reply_str) diff --git a/pkg/command/operators/help.py b/pkg/command/operators/help.py new file mode 100644 index 0000000..c99c294 --- /dev/null +++ b/pkg/command/operators/help.py @@ -0,0 +1,23 @@ +from __future__ import annotations + +import typing + +from .. import operator, entities, cmdmgr, errors + + +@operator.operator_class( + name='help', + help='显示帮助', + usage='!help\n!help <命令名称>' +) +class HelpOperator(operator.CommandOperator): + + async def execute( + self, + context: entities.ExecuteContext + ) -> typing.AsyncGenerator[entities.CommandReturn, None]: + help = self.ap.tips_mgr.data['help_message'] + + help += '\n发送命令 !cmd 可查看命令列表' + + yield entities.CommandReturn(text=help) diff --git a/pkg/command/operators/last.py b/pkg/command/operators/last.py new file mode 100644 index 0000000..8e3a523 --- /dev/null +++ b/pkg/command/operators/last.py @@ -0,0 +1,36 @@ +from __future__ import annotations + +import typing +import datetime + + +from .. import operator, entities, cmdmgr, errors + + +@operator.operator_class( + name="last", + help="切换到前一个对话", + usage='!last' +) +class LastOperator(operator.CommandOperator): + + async def execute( + self, + context: entities.ExecuteContext + ) -> typing.AsyncGenerator[entities.CommandReturn, None]: + + if context.session.conversations: + # 找到当前会话的上一个会话 + for index in range(len(context.session.conversations)-1, -1, -1): + if context.session.conversations[index] == context.session.using_conversation: + if index == 0: + yield entities.CommandReturn(error=errors.CommandOperationError('已经是第一个对话了')) + return + else: + context.session.using_conversation = context.session.conversations[index-1] + time_str = context.session.using_conversation.create_time.strftime("%Y-%m-%d %H:%M:%S") + + yield entities.CommandReturn(text=f"已切换到上一个对话: {index} {time_str}: {context.session.using_conversation.messages[0].content}") + return + else: + yield entities.CommandReturn(error=errors.CommandOperationError('当前没有对话')) \ No newline at end of file diff --git a/pkg/command/operators/list.py b/pkg/command/operators/list.py new file mode 100644 index 0000000..a91285e --- /dev/null +++ b/pkg/command/operators/list.py @@ -0,0 +1,56 @@ +from __future__ import annotations + +import typing +import datetime + +from .. import operator, entities, cmdmgr, errors + + +@operator.operator_class( + name="list", + help="列出此会话中的所有历史对话", + usage='!list\n!list <页码>' +) +class ListOperator(operator.CommandOperator): + + async def execute( + self, + context: entities.ExecuteContext + ) -> typing.AsyncGenerator[entities.CommandReturn, None]: + + page = 0 + + if len(context.crt_params) > 0: + try: + page = int(context.crt_params[0]-1) + except: + yield entities.CommandReturn(error=errors.CommandOperationError('页码应为整数')) + return + + record_per_page = 10 + + content = '' + + index = 0 + + using_conv_index = 0 + + for conv in context.session.conversations[::-1]: + time_str = conv.create_time.strftime("%Y-%m-%d %H:%M:%S") + + if conv == context.session.using_conversation: + using_conv_index = index + + if index >= page * record_per_page and index < (page + 1) * record_per_page: + content += f"{index} {time_str}: {conv.messages[0].content}\n" + index += 1 + + if content == '': + content = '无' + else: + if context.session.using_conversation is None: + content += "\n当前处于新会话" + else: + content += f"\n当前会话: {using_conv_index} {context.session.using_conversation.create_time.strftime('%Y-%m-%d %H:%M:%S')}: {context.session.using_conversation.messages[0].content}" + + yield entities.CommandReturn(text=f"第 {page + 1} 页 (时间倒序):\n{content}") diff --git a/pkg/command/operators/next.py b/pkg/command/operators/next.py new file mode 100644 index 0000000..8f4b5a5 --- /dev/null +++ b/pkg/command/operators/next.py @@ -0,0 +1,35 @@ +from __future__ import annotations + +import typing +import datetime + +from .. import operator, entities, cmdmgr, errors + + +@operator.operator_class( + name="next", + help="切换到后一个对话", + usage='!next' +) +class NextOperator(operator.CommandOperator): + + async def execute( + self, + context: entities.ExecuteContext + ) -> typing.AsyncGenerator[entities.CommandReturn, None]: + + if context.session.conversations: + # 找到当前会话的下一个会话 + for index in range(len(context.session.conversations)): + if context.session.conversations[index] == context.session.using_conversation: + if index == len(context.session.conversations)-1: + yield entities.CommandReturn(error=errors.CommandOperationError('已经是最后一个对话了')) + return + else: + context.session.using_conversation = context.session.conversations[index+1] + time_str = context.session.using_conversation.create_time.strftime("%Y-%m-%d %H:%M:%S") + + yield entities.CommandReturn(text=f"已切换到后一个对话: {index} {time_str}: {context.session.using_conversation.messages[0].content}") + return + else: + yield entities.CommandReturn(error=errors.CommandOperationError('当前没有对话')) \ No newline at end of file diff --git a/pkg/command/operators/plugin.py b/pkg/command/operators/plugin.py new file mode 100644 index 0000000..195852a --- /dev/null +++ b/pkg/command/operators/plugin.py @@ -0,0 +1,239 @@ +from __future__ import annotations +import typing +import traceback + +from .. import operator, entities, cmdmgr, errors +from ...plugin import host as plugin_host +from ...utils import updater +from ...core import app + + +@operator.operator_class( + name="plugin", + help="插件操作", + usage="!plugin\n!plugin get <插件仓库地址>\n!plugin update\n!plugin del <插件名>\n!plugin on <插件名>\n!plugin off <插件名>" +) +class PluginOperator(operator.CommandOperator): + + async def execute( + self, + context: entities.ExecuteContext + ) -> typing.AsyncGenerator[entities.CommandReturn, None]: + + plugin_list = plugin_host.__plugins__ + reply_str = "所有插件({}):\n".format(len(plugin_host.__plugins__)) + idx = 0 + for key in plugin_host.iter_plugins_name(): + plugin = plugin_list[key] + reply_str += "\n#{} {} {}\n{}\nv{}\n作者: {}\n"\ + .format((idx+1), plugin['name'], + "[已禁用]" if not plugin['enabled'] else "", + plugin['description'], + plugin['version'], plugin['author']) + + # TODO 从元数据调远程地址 + # if updater.is_repo("/".join(plugin['path'].split('/')[:-1])): + # remote_url = updater.get_remote_url("/".join(plugin['path'].split('/')[:-1])) + # if remote_url != "https://github.com/RockChinQ/QChatGPT" and remote_url != "https://gitee.com/RockChin/QChatGPT": + # reply_str += "源码: "+remote_url+"\n" + + idx += 1 + + yield entities.CommandReturn(text=reply_str) + + +@operator.operator_class( + name="get", + help="安装插件", + privilege=2, + parent_class=PluginOperator +) +class PluginGetOperator(operator.CommandOperator): + + async def execute( + self, + context: entities.ExecuteContext + ) -> typing.AsyncGenerator[entities.CommandReturn, None]: + + if len(context.crt_params) == 0: + yield entities.CommandReturn(error=errors.ParamNotEnoughError('请提供插件仓库地址')) + else: + repo = context.crt_params[0] + + yield entities.CommandReturn(text="正在安装插件...") + + try: + plugin_host.install_plugin(repo) + yield entities.CommandReturn(text="插件安装成功,请重启程序以加载插件") + except Exception as e: + traceback.print_exc() + yield entities.CommandReturn(error=errors.CommandError("插件安装失败: "+str(e))) + + +@operator.operator_class( + name="update", + help="更新插件", + privilege=2, + parent_class=PluginOperator +) +class PluginUpdateOperator(operator.CommandOperator): + + async def execute( + self, + context: entities.ExecuteContext + ) -> typing.AsyncGenerator[entities.CommandReturn, None]: + + if len(context.crt_params) == 0: + yield entities.CommandReturn(error=errors.ParamNotEnoughError('请提供插件名称')) + else: + plugin_name = context.crt_params[0] + + try: + plugin_path_name = plugin_host.get_plugin_path_name_by_plugin_name(plugin_name) + + if plugin_path_name is not None: + yield entities.CommandReturn(text="正在更新插件...") + plugin_host.update_plugin(plugin_name) + yield entities.CommandReturn(text="插件更新成功,请重启程序以加载插件") + else: + yield entities.CommandReturn(error=errors.CommandError("插件更新失败: 未找到插件")) + except Exception as e: + traceback.print_exc() + yield entities.CommandReturn(error=errors.CommandError("插件更新失败: "+str(e))) + +@operator.operator_class( + name="all", + help="更新所有插件", + privilege=2, + parent_class=PluginUpdateOperator +) +class PluginUpdateAllOperator(operator.CommandOperator): + + async def execute( + self, + context: entities.ExecuteContext + ) -> typing.AsyncGenerator[entities.CommandReturn, None]: + + try: + plugins = [] + + for key in plugin_host.__plugins__: + plugins.append(key) + + if plugins: + yield entities.CommandReturn(text="正在更新插件...") + updated = [] + try: + for plugin_name in plugins: + plugin_host.update_plugin(plugin_name) + updated.append(plugin_name) + except Exception as e: + traceback.print_exc() + yield entities.CommandReturn(error=errors.CommandError("插件更新失败: "+str(e))) + yield entities.CommandReturn(text="已更新插件: {}".format(", ".join(updated))) + else: + yield entities.CommandReturn(text="没有可更新的插件") + except Exception as e: + traceback.print_exc() + yield entities.CommandReturn(error=errors.CommandError("插件更新失败: "+str(e))) + + +@operator.operator_class( + name="del", + help="删除插件", + privilege=2, + parent_class=PluginOperator +) +class PluginDelOperator(operator.CommandOperator): + + async def execute( + self, + context: entities.ExecuteContext + ) -> typing.AsyncGenerator[entities.CommandReturn, None]: + + if len(context.crt_params) == 0: + yield entities.CommandReturn(error=errors.ParamNotEnoughError('请提供插件名称')) + else: + plugin_name = context.crt_params[0] + + try: + plugin_path_name = plugin_host.get_plugin_path_name_by_plugin_name(plugin_name) + + if plugin_path_name is not None: + yield entities.CommandReturn(text="正在删除插件...") + plugin_host.uninstall_plugin(plugin_name) + yield entities.CommandReturn(text="插件删除成功,请重启程序以加载插件") + else: + yield entities.CommandReturn(error=errors.CommandError("插件删除失败: 未找到插件")) + except Exception as e: + traceback.print_exc() + yield entities.CommandReturn(error=errors.CommandError("插件删除失败: "+str(e))) + + +def update_plugin_status(plugin_name: str, new_status: bool, ap: app.Application): + if plugin_name in plugin_host.__plugins__: + plugin_host.__plugins__[plugin_name]['enabled'] = new_status + + for func in ap.tool_mgr.all_functions: + if func.name.startswith(plugin_name+'-'): + func.enable = new_status + + return True + else: + return False + + +@operator.operator_class( + name="on", + help="启用插件", + privilege=2, + parent_class=PluginOperator +) +class PluginEnableOperator(operator.CommandOperator): + + async def execute( + self, + context: entities.ExecuteContext + ) -> typing.AsyncGenerator[entities.CommandReturn, None]: + + if len(context.crt_params) == 0: + yield entities.CommandReturn(error=errors.ParamNotEnoughError('请提供插件名称')) + else: + plugin_name = context.crt_params[0] + + try: + if update_plugin_status(plugin_name, True, self.ap): + yield entities.CommandReturn(text="已启用插件: {}".format(plugin_name)) + else: + yield entities.CommandReturn(error=errors.CommandError("插件状态修改失败: 未找到插件 {}".format(plugin_name))) + except Exception as e: + traceback.print_exc() + yield entities.CommandReturn(error=errors.CommandError("插件状态修改失败: "+str(e))) + + +@operator.operator_class( + name="off", + help="禁用插件", + privilege=2, + parent_class=PluginOperator +) +class PluginDisableOperator(operator.CommandOperator): + + async def execute( + self, + context: entities.ExecuteContext + ) -> typing.AsyncGenerator[entities.CommandReturn, None]: + + if len(context.crt_params) == 0: + yield entities.CommandReturn(error=errors.ParamNotEnoughError('请提供插件名称')) + else: + plugin_name = context.crt_params[0] + + try: + if update_plugin_status(plugin_name, False, self.ap): + yield entities.CommandReturn(text="已禁用插件: {}".format(plugin_name)) + else: + yield entities.CommandReturn(error=errors.CommandError("插件状态修改失败: 未找到插件 {}".format(plugin_name))) + except Exception as e: + traceback.print_exc() + yield entities.CommandReturn(error=errors.CommandError("插件状态修改失败: "+str(e))) diff --git a/pkg/command/operators/prompt.py b/pkg/command/operators/prompt.py new file mode 100644 index 0000000..29d688a --- /dev/null +++ b/pkg/command/operators/prompt.py @@ -0,0 +1,29 @@ +from __future__ import annotations + +import typing + +from .. import operator, entities, cmdmgr, errors + + +@operator.operator_class( + name="prompt", + help="查看当前对话的前文", + usage='!prompt' +) +class PromptOperator(operator.CommandOperator): + + async def execute( + self, + context: entities.ExecuteContext + ) -> typing.AsyncGenerator[entities.CommandReturn, None]: + """执行 + """ + if context.session.using_conversation is None: + yield entities.CommandReturn(error=errors.CommandOperationError('当前没有对话')) + else: + reply_str = '当前对话所有内容:\n\n' + + for msg in context.session.using_conversation.messages: + reply_str += f"{msg.role}: {msg.content}\n" + + yield entities.CommandReturn(text=reply_str) \ No newline at end of file diff --git a/pkg/command/operators/resend.py b/pkg/command/operators/resend.py new file mode 100644 index 0000000..6d93041 --- /dev/null +++ b/pkg/command/operators/resend.py @@ -0,0 +1,34 @@ +from __future__ import annotations + +import typing + +from .. import operator, entities, cmdmgr, errors + + +@operator.operator_class( + name="resend", + help="重发当前会话的最后一条消息", + usage='!resend' +) +class ResendOperator(operator.CommandOperator): + + async def execute( + self, + context: entities.ExecuteContext + ) -> typing.AsyncGenerator[entities.CommandReturn, None]: + # 回滚到最后一条用户message前 + if context.session.using_conversation is None: + yield entities.CommandReturn(error=errors.CommandError("当前没有对话")) + else: + conv_msg = context.session.using_conversation.messages + + # 倒序一直删到最后一条用户message + while len(conv_msg) > 0 and conv_msg[-1].role != 'user': + conv_msg.pop() + + if len(conv_msg) > 0: + # 删除最后一条用户message + conv_msg.pop() + + # 不重发了,提示用户已删除就行了 + yield entities.CommandReturn(text="已删除最后一次请求记录") diff --git a/pkg/command/operators/reset.py b/pkg/command/operators/reset.py new file mode 100644 index 0000000..5d1402a --- /dev/null +++ b/pkg/command/operators/reset.py @@ -0,0 +1,23 @@ +from __future__ import annotations + +import typing + +from .. import operator, entities, cmdmgr, errors + + +@operator.operator_class( + name="reset", + help="重置当前会话", + usage='!reset' +) +class ResetOperator(operator.CommandOperator): + + async def execute( + self, + context: entities.ExecuteContext + ) -> typing.AsyncGenerator[entities.CommandReturn, None]: + """执行 + """ + context.session.using_conversation = None + + yield entities.CommandReturn(text="已重置当前会话") diff --git a/pkg/command/operators/version.py b/pkg/command/operators/version.py new file mode 100644 index 0000000..c223580 --- /dev/null +++ b/pkg/command/operators/version.py @@ -0,0 +1,28 @@ +from __future__ import annotations + +import typing + +from .. import operator, cmdmgr, entities, errors +from ...utils import updater + + +@operator.operator_class( + name="version", + help="显示版本信息", + usage='!version' +) +class VersionCommand(operator.CommandOperator): + + async def execute( + self, + context: entities.ExecuteContext + ) -> typing.AsyncGenerator[entities.CommandReturn, None]: + reply_str = f"当前版本: \n{updater.get_current_version_info()}" + + try: + if updater.is_new_version_available(): + reply_str += "\n\n有新版本可用, 使用 !update 更新" + except: + pass + + yield entities.CommandReturn(text=reply_str.strip()) \ No newline at end of file diff --git a/pkg/openai/requester/apis/chatcmpl.py b/pkg/openai/requester/apis/chatcmpl.py index 24ff2d7..1e3da1a 100644 --- a/pkg/openai/requester/apis/chatcmpl.py +++ b/pkg/openai/requester/apis/chatcmpl.py @@ -102,7 +102,7 @@ class OpenAIChatCompletion(api.LLMAPIRequester): m.dict(exclude_none=True) for m in conversation.prompt.messages ] + [m.dict(exclude_none=True) for m in conversation.messages] - req_messages.append({"role": "user", "content": str(query.message_chain)}) + # req_messages.append({"role": "user", "content": str(query.message_chain)}) msg = await self._closure(req_messages, conversation) diff --git a/pkg/openai/sysprompt/sysprompt.py b/pkg/openai/sysprompt/sysprompt.py index 050f663..5df28ee 100644 --- a/pkg/openai/sysprompt/sysprompt.py +++ b/pkg/openai/sysprompt/sysprompt.py @@ -35,9 +35,16 @@ class PromptManager: """ return self.loader_inst.get_prompts() - async def get_prompt(self, name: str) -> loader.entities.Prompt: + async def get_prompt(self, name: str) -> loader.entities.Prompt: """获取Prompt """ for prompt in self.get_all_prompts(): if prompt.name == name: return prompt + + async def get_prompt_by_prefix(self, prefix: str) -> loader.entities.Prompt: + """通过前缀获取Prompt + """ + for prompt in self.get_all_prompts(): + if prompt.name.startswith(prefix): + return prompt diff --git a/pkg/pipeline/process/handlers/chat.py b/pkg/pipeline/process/handlers/chat.py index 629c2b1..889b3bb 100644 --- a/pkg/pipeline/process/handlers/chat.py +++ b/pkg/pipeline/process/handlers/chat.py @@ -7,6 +7,7 @@ import mirai from .. import handler from ... import entities from ....core import entities as core_entities +from ....openai import entities as llm_entities class ChatMessageHandler(handler.MessageHandler): @@ -25,6 +26,13 @@ class ChatMessageHandler(handler.MessageHandler): conversation = await self.ap.sess_mgr.get_conversation(session) + conversation.messages.append( + llm_entities.Message( + role="user", + content=str(query.message_chain) + ) + ) + async for result in conversation.use_model.requester.request(query, conversation): conversation.messages.append(result) diff --git a/pkg/pipeline/process/handlers/command.py b/pkg/pipeline/process/handlers/command.py index cf3e074..f836a2a 100644 --- a/pkg/pipeline/process/handlers/command.py +++ b/pkg/pipeline/process/handlers/command.py @@ -34,13 +34,17 @@ class CommandHandler(handler.MessageHandler): result_type=entities.ResultType.CONTINUE, new_query=query ) - else: - if ret.text is not None: - query.resp_message_chain = mirai.MessageChain([ - mirai.Plain(ret.text) - ]) + elif ret.text is not None: + query.resp_message_chain = mirai.MessageChain([ + mirai.Plain(ret.text) + ]) - yield entities.StageProcessResult( - result_type=entities.ResultType.CONTINUE, - new_query=query - ) + yield entities.StageProcessResult( + result_type=entities.ResultType.CONTINUE, + new_query=query + ) + else: + yield entities.StageProcessResult( + result_type=entities.ResultType.INTERRUPT, + new_query=query + )