refactor: 命令基本完成

This commit is contained in:
RockChinQ 2024-01-28 18:21:43 +08:00
parent 2a0cf57303
commit 1368ee22b2
21 changed files with 859 additions and 34 deletions

View File

@ -7,7 +7,7 @@ from ..openai import entities as llm_entities
from ..openai.session import entities as session_entities
from . import entities, operator, errors
from .operators import func
from .operators import func, plugin, default, reset, list as list_cmd, last, next, delc, resend, prompt, cfg, cmd, help, version
class CommandManager:
@ -41,31 +41,35 @@ class CommandManager:
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
"""执行命令
"""
found = False
if len(context.crt_params) > 0:
for operator in operator_list:
if context.crt_params[0] == operator.name \
or context.crt_params[0] in operator.alias:
for oper in operator_list:
if (context.crt_params[0] == oper.name \
or context.crt_params[0] in oper.alias) \
and (oper.parent_class is None or oper.parent_class == operator.__class__):
found = True
context.crt_command = context.params[0]
context.crt_params = context.params[1:]
context.crt_command = context.crt_params[0]
context.crt_params = context.crt_params[1:]
async for ret in self._execute(
context,
operator.children,
operator
oper.children,
oper
):
yield ret
break
if not found:
if operator is None:
yield entities.CommandReturn(
error=errors.CommandNotFoundError(context.crt_command)
error=errors.CommandNotFoundError(context.crt_params[0])
)
else:
if operator.lowest_privilege > context.privilege:
yield entities.CommandReturn(
error=errors.CommandPrivilegeError(context.crt_command)
error=errors.CommandPrivilegeError(operator.name)
)
else:
async for ret in operator.execute(context):

View File

@ -1,12 +1,33 @@
class CommandError(Exception):
pass
def __init__(self, message: str = None):
self.message = message
def __str__(self):
return self.message
class CommandNotFoundError(CommandError):
pass
def __init__(self, message: str = None):
super().__init__("未知命令: "+message)
class CommandPrivilegeError(CommandError):
pass
def __init__(self, message: str = None):
super().__init__("权限不足: "+message)
class ParamNotEnoughError(CommandError):
def __init__(self, message: str = None):
super().__init__("参数不足: "+message)
class CommandOperationError(CommandError):
def __init__(self, message: str = None):
super().__init__("操作失败: "+message)

View File

@ -13,8 +13,9 @@ preregistered_operators: list[typing.Type[CommandOperator]] = []
def operator_class(
name: str,
alias: list[str],
help: str,
usage: str = None,
alias: list[str] = [],
privilege: int=1, # 1为普通用户2为管理员
parent_class: typing.Type[CommandOperator] = None
) -> typing.Callable[[typing.Type[CommandOperator]], typing.Type[CommandOperator]]:
@ -22,6 +23,7 @@ def operator_class(
cls.name = name
cls.alias = alias
cls.help = help
cls.usage = usage
cls.parent_class = parent_class
preregistered_operators.append(cls)
@ -46,7 +48,9 @@ class CommandOperator(metaclass=abc.ABCMeta):
help: str
"""此节点的帮助信息"""
parent_class: typing.Type[CommandOperator]
usage: str = None
parent_class: typing.Type[CommandOperator] | None = None
"""父节点类。标记以供管理器在初始化时编织父子关系。"""
lowest_privilege: int = 0

View File

@ -0,0 +1,98 @@
from __future__ import annotations
import typing
import json
from .. import operator, entities, cmdmgr, errors
@operator.operator_class(
name="cfg",
help="配置项管理",
usage='!cfg <配置项> [配置值]\n!cfg all',
privilege=2
)
class CfgOperator(operator.CommandOperator):
async def execute(
self,
context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
"""执行
"""
reply = ''
params = context.crt_params
cfg_mgr = self.ap.cfg_mgr
false = False
true = True
reply_str = ""
if len(params) == 0:
yield entities.CommandReturn(error=errors.ParamNotEnoughError('请提供配置项名称'))
else:
cfg_name = params[0]
if cfg_name == 'all':
reply_str = "[bot]所有配置项:\n\n"
for cfg in cfg_mgr.data.keys():
if not cfg.startswith('__') and not cfg == 'logging':
# 根据配置项类型进行格式化如果是字典则转换为json并格式化
if isinstance(cfg_mgr.data[cfg], str):
reply_str += "{}: \"{}\"\n".format(cfg, cfg_mgr.data[cfg])
elif isinstance(cfg_mgr.data[cfg], dict):
# 不进行unicode转义并格式化
reply_str += "{}: {}\n".format(cfg,
json.dumps(cfg_mgr.data[cfg],
ensure_ascii=False, indent=4))
else:
reply_str += "{}: {}\n".format(cfg, cfg_mgr.data[cfg])
yield entities.CommandReturn(text=reply_str)
else:
cfg_entry_path = cfg_name.split('.')
try:
if len(params) == 1: # 未指定配置值,返回配置项值
cfg_entry = cfg_mgr.data[cfg_entry_path[0]]
if len(cfg_entry_path) > 1:
for i in range(1, len(cfg_entry_path)):
cfg_entry = cfg_entry[cfg_entry_path[i]]
if isinstance(cfg_entry, str):
reply_str = "[bot]配置项{}: \"{}\"\n".format(cfg_name, cfg_entry)
elif isinstance(cfg_entry, dict):
reply_str = "[bot]配置项{}: {}\n".format(cfg_name,
json.dumps(cfg_entry,
ensure_ascii=False, indent=4))
else:
reply_str = "[bot]配置项{}: {}\n".format(cfg_name, cfg_entry)
yield entities.CommandReturn(text=reply_str)
else:
cfg_value = " ".join(params[1:])
cfg_value = eval(cfg_value)
cfg_entry = cfg_mgr.data[cfg_entry_path[0]]
if len(cfg_entry_path) > 1:
for i in range(1, len(cfg_entry_path) - 1):
cfg_entry = cfg_entry[cfg_entry_path[i]]
if isinstance(cfg_entry[cfg_entry_path[-1]], type(cfg_value)):
cfg_entry[cfg_entry_path[-1]] = cfg_value
yield entities.CommandReturn(text="配置项{}修改成功".format(cfg_name))
else:
# reply = ["[bot]err:配置项{}类型不匹配".format(cfg_name)]
yield entities.CommandReturn(error=errors.CommandOperationError("配置项{}类型不匹配".format(cfg_name)))
else:
cfg_mgr.data[cfg_entry_path[0]] = cfg_value
# reply = ["[bot]配置项{}修改成功".format(cfg_name)]
yield entities.CommandReturn(text="配置项{}修改成功".format(cfg_name))
except KeyError:
# reply = ["[bot]err:未找到配置项 {}".format(cfg_name)]
yield entities.CommandReturn(error=errors.CommandOperationError("未找到配置项 {}".format(cfg_name)))
except NameError:
# reply = ["[bot]err:值{}不合法(字符串需要使用双引号包裹)".format(cfg_value)]
yield entities.CommandReturn(error=errors.CommandOperationError("{}不合法(字符串需要使用双引号包裹)".format(cfg_value)))
except ValueError:
# reply = ["[bot]err:未找到配置项 {}".format(cfg_name)]
yield entities.CommandReturn(error=errors.CommandOperationError("未找到配置项 {}".format(cfg_name)))

View File

@ -0,0 +1,50 @@
from __future__ import annotations
import typing
from .. import operator, entities, cmdmgr, errors
@operator.operator_class(
name="cmd",
help='显示命令列表',
usage='!cmd\n!cmd <命令名称>'
)
class CmdOperator(operator.CommandOperator):
"""命令列表
"""
async def execute(
self,
context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
"""执行
"""
if len(context.crt_params) == 0:
reply_str = "当前所有命令: \n\n"
for cmd in self.ap.cmd_mgr.cmd_list:
if cmd.parent_class is None:
reply_str += f"{cmd.name}: {cmd.help}\n"
reply_str += "\n使用 !cmd <命令名称> 查看命令的详细帮助"
yield entities.CommandReturn(text=reply_str.strip())
else:
cmd_name = context.crt_params[0]
cmd = None
for _cmd in self.ap.cmd_mgr.cmd_list:
if (cmd_name == _cmd.name or cmd_name in _cmd.alias) and (_cmd.parent_class is None):
cmd = _cmd
break
if cmd is None:
yield entities.CommandReturn(error=errors.CommandNotFoundError(cmd_name))
else:
reply_str = f"{cmd.name}: {cmd.help}\n\n"
reply_str += f"使用方法: \n{cmd.usage}"
yield entities.CommandReturn(text=reply_str.strip())

View File

@ -0,0 +1,62 @@
from __future__ import annotations
import typing
import traceback
from .. import operator, entities, cmdmgr, errors
@operator.operator_class(
name="default",
help="操作情景预设",
usage='!default\n!default set <指定情景预设为默认>'
)
class DefaultOperator(operator.CommandOperator):
async def execute(
self,
context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
reply_str = "当前所有情景预设: \n\n"
for prompt in self.ap.prompt_mgr.get_all_prompts():
content = ""
for msg in prompt.messages:
content += f" {msg.role}: {msg.content}"
reply_str += f"名称: {prompt.name}\n内容: \n{content}\n\n"
reply_str += f"当前会话使用的是: {context.session.use_prompt_name}"
yield entities.CommandReturn(text=reply_str.strip())
@operator.operator_class(
name="set",
help="设置当前会话默认情景预设",
parent_class=DefaultOperator
)
class DefaultSetOperator(operator.CommandOperator):
async def execute(
self,
context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
if len(context.crt_params) == 0:
yield entities.CommandReturn(error=errors.ParamNotEnoughError('请提供情景预设名称'))
else:
prompt_name = context.crt_params[0]
try:
prompt = await self.ap.prompt_mgr.get_prompt_by_prefix(prompt_name)
if prompt is None:
yield entities.CommandReturn(error=errors.CommandError("设置当前会话默认情景预设失败: 未找到情景预设 {}".format(prompt_name)))
else:
context.session.use_prompt_name = prompt.name
yield entities.CommandReturn(text=f"已设置当前会话默认情景预设为 {prompt_name}, !reset 后生效")
except Exception as e:
traceback.print_exc()
yield entities.CommandReturn(error=errors.CommandError("设置当前会话默认情景预设失败: "+str(e)))

View File

@ -0,0 +1,62 @@
from __future__ import annotations
import typing
import datetime
from .. import operator, entities, cmdmgr, errors
@operator.operator_class(
name="del",
help="删除当前会话的历史记录",
usage='!del <序号>\n!del all'
)
class DelOperator(operator.CommandOperator):
async def execute(
self,
context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
if context.session.conversations:
delete_index = 0
if len(context.crt_params) > 0:
try:
delete_index = int(context.crt_params[0])
except:
yield entities.CommandReturn(error=errors.CommandOperationError('索引必须是整数'))
return
if delete_index < 0 or delete_index >= len(context.session.conversations):
yield entities.CommandReturn(error=errors.CommandOperationError('索引超出范围'))
return
# 倒序
to_delete_index = len(context.session.conversations)-1-delete_index
if context.session.conversations[to_delete_index] == context.session.using_conversation:
context.session.using_conversation = None
del context.session.conversations[to_delete_index]
yield entities.CommandReturn(text=f"已删除对话: {delete_index}")
else:
yield entities.CommandReturn(error=errors.CommandOperationError('当前没有对话'))
@operator.operator_class(
name="all",
help="删除此会话的所有历史记录",
parent_class=DelOperator
)
class DelAllOperator(operator.CommandOperator):
async def execute(
self,
context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
context.session.conversations = []
context.session.using_conversation = None
yield entities.CommandReturn(text="已删除所有对话")

View File

@ -5,19 +5,21 @@ from .. import operator, entities, cmdmgr
from ...plugin import host as plugin_host
@operator.operator_class(name="func", alias=[], help="查看所有以注册的内容函数")
@operator.operator_class(name="func", help="查看所有已注册的内容函数", usage='!func')
class FuncOperator(operator.CommandOperator):
async def execute(
self,
context: entities.ExecuteContext
self, context: entities.ExecuteContext
) -> AsyncGenerator[entities.CommandReturn, None]:
reply_str = "当前已加载的内容函数: \n\n"
index = 1
for func in plugin_host.__callable_functions__:
reply_str += "{}. {}{}:\n{}\n\n".format(index, ("(已禁用) " if not func['enabled'] else ""), func['name'], func['description'])
for func in self.ap.tool_mgr.all_functions:
reply_str += "{}. {}{}:\n{}\n\n".format(
index,
("(已禁用) " if not func.enable else ""),
func.name,
func.description,
)
index += 1
yield entities.CommandReturn(
text=reply_str
)
yield entities.CommandReturn(text=reply_str)

View File

@ -0,0 +1,23 @@
from __future__ import annotations
import typing
from .. import operator, entities, cmdmgr, errors
@operator.operator_class(
name='help',
help='显示帮助',
usage='!help\n!help <命令名称>'
)
class HelpOperator(operator.CommandOperator):
async def execute(
self,
context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
help = self.ap.tips_mgr.data['help_message']
help += '\n发送命令 !cmd 可查看命令列表'
yield entities.CommandReturn(text=help)

View File

@ -0,0 +1,36 @@
from __future__ import annotations
import typing
import datetime
from .. import operator, entities, cmdmgr, errors
@operator.operator_class(
name="last",
help="切换到前一个对话",
usage='!last'
)
class LastOperator(operator.CommandOperator):
async def execute(
self,
context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
if context.session.conversations:
# 找到当前会话的上一个会话
for index in range(len(context.session.conversations)-1, -1, -1):
if context.session.conversations[index] == context.session.using_conversation:
if index == 0:
yield entities.CommandReturn(error=errors.CommandOperationError('已经是第一个对话了'))
return
else:
context.session.using_conversation = context.session.conversations[index-1]
time_str = context.session.using_conversation.create_time.strftime("%Y-%m-%d %H:%M:%S")
yield entities.CommandReturn(text=f"已切换到上一个对话: {index} {time_str}: {context.session.using_conversation.messages[0].content}")
return
else:
yield entities.CommandReturn(error=errors.CommandOperationError('当前没有对话'))

View File

@ -0,0 +1,56 @@
from __future__ import annotations
import typing
import datetime
from .. import operator, entities, cmdmgr, errors
@operator.operator_class(
name="list",
help="列出此会话中的所有历史对话",
usage='!list\n!list <页码>'
)
class ListOperator(operator.CommandOperator):
async def execute(
self,
context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
page = 0
if len(context.crt_params) > 0:
try:
page = int(context.crt_params[0]-1)
except:
yield entities.CommandReturn(error=errors.CommandOperationError('页码应为整数'))
return
record_per_page = 10
content = ''
index = 0
using_conv_index = 0
for conv in context.session.conversations[::-1]:
time_str = conv.create_time.strftime("%Y-%m-%d %H:%M:%S")
if conv == context.session.using_conversation:
using_conv_index = index
if index >= page * record_per_page and index < (page + 1) * record_per_page:
content += f"{index} {time_str}: {conv.messages[0].content}\n"
index += 1
if content == '':
content = ''
else:
if context.session.using_conversation is None:
content += "\n当前处于新会话"
else:
content += f"\n当前会话: {using_conv_index} {context.session.using_conversation.create_time.strftime('%Y-%m-%d %H:%M:%S')}: {context.session.using_conversation.messages[0].content}"
yield entities.CommandReturn(text=f"{page + 1} 页 (时间倒序):\n{content}")

View File

@ -0,0 +1,35 @@
from __future__ import annotations
import typing
import datetime
from .. import operator, entities, cmdmgr, errors
@operator.operator_class(
name="next",
help="切换到后一个对话",
usage='!next'
)
class NextOperator(operator.CommandOperator):
async def execute(
self,
context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
if context.session.conversations:
# 找到当前会话的下一个会话
for index in range(len(context.session.conversations)):
if context.session.conversations[index] == context.session.using_conversation:
if index == len(context.session.conversations)-1:
yield entities.CommandReturn(error=errors.CommandOperationError('已经是最后一个对话了'))
return
else:
context.session.using_conversation = context.session.conversations[index+1]
time_str = context.session.using_conversation.create_time.strftime("%Y-%m-%d %H:%M:%S")
yield entities.CommandReturn(text=f"已切换到后一个对话: {index} {time_str}: {context.session.using_conversation.messages[0].content}")
return
else:
yield entities.CommandReturn(error=errors.CommandOperationError('当前没有对话'))

View File

@ -0,0 +1,239 @@
from __future__ import annotations
import typing
import traceback
from .. import operator, entities, cmdmgr, errors
from ...plugin import host as plugin_host
from ...utils import updater
from ...core import app
@operator.operator_class(
name="plugin",
help="插件操作",
usage="!plugin\n!plugin get <插件仓库地址>\n!plugin update\n!plugin del <插件名>\n!plugin on <插件名>\n!plugin off <插件名>"
)
class PluginOperator(operator.CommandOperator):
async def execute(
self,
context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
plugin_list = plugin_host.__plugins__
reply_str = "所有插件({}):\n".format(len(plugin_host.__plugins__))
idx = 0
for key in plugin_host.iter_plugins_name():
plugin = plugin_list[key]
reply_str += "\n#{} {} {}\n{}\nv{}\n作者: {}\n"\
.format((idx+1), plugin['name'],
"[已禁用]" if not plugin['enabled'] else "",
plugin['description'],
plugin['version'], plugin['author'])
# TODO 从元数据调远程地址
# if updater.is_repo("/".join(plugin['path'].split('/')[:-1])):
# remote_url = updater.get_remote_url("/".join(plugin['path'].split('/')[:-1]))
# if remote_url != "https://github.com/RockChinQ/QChatGPT" and remote_url != "https://gitee.com/RockChin/QChatGPT":
# reply_str += "源码: "+remote_url+"\n"
idx += 1
yield entities.CommandReturn(text=reply_str)
@operator.operator_class(
name="get",
help="安装插件",
privilege=2,
parent_class=PluginOperator
)
class PluginGetOperator(operator.CommandOperator):
async def execute(
self,
context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
if len(context.crt_params) == 0:
yield entities.CommandReturn(error=errors.ParamNotEnoughError('请提供插件仓库地址'))
else:
repo = context.crt_params[0]
yield entities.CommandReturn(text="正在安装插件...")
try:
plugin_host.install_plugin(repo)
yield entities.CommandReturn(text="插件安装成功,请重启程序以加载插件")
except Exception as e:
traceback.print_exc()
yield entities.CommandReturn(error=errors.CommandError("插件安装失败: "+str(e)))
@operator.operator_class(
name="update",
help="更新插件",
privilege=2,
parent_class=PluginOperator
)
class PluginUpdateOperator(operator.CommandOperator):
async def execute(
self,
context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
if len(context.crt_params) == 0:
yield entities.CommandReturn(error=errors.ParamNotEnoughError('请提供插件名称'))
else:
plugin_name = context.crt_params[0]
try:
plugin_path_name = plugin_host.get_plugin_path_name_by_plugin_name(plugin_name)
if plugin_path_name is not None:
yield entities.CommandReturn(text="正在更新插件...")
plugin_host.update_plugin(plugin_name)
yield entities.CommandReturn(text="插件更新成功,请重启程序以加载插件")
else:
yield entities.CommandReturn(error=errors.CommandError("插件更新失败: 未找到插件"))
except Exception as e:
traceback.print_exc()
yield entities.CommandReturn(error=errors.CommandError("插件更新失败: "+str(e)))
@operator.operator_class(
name="all",
help="更新所有插件",
privilege=2,
parent_class=PluginUpdateOperator
)
class PluginUpdateAllOperator(operator.CommandOperator):
async def execute(
self,
context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
try:
plugins = []
for key in plugin_host.__plugins__:
plugins.append(key)
if plugins:
yield entities.CommandReturn(text="正在更新插件...")
updated = []
try:
for plugin_name in plugins:
plugin_host.update_plugin(plugin_name)
updated.append(plugin_name)
except Exception as e:
traceback.print_exc()
yield entities.CommandReturn(error=errors.CommandError("插件更新失败: "+str(e)))
yield entities.CommandReturn(text="已更新插件: {}".format(", ".join(updated)))
else:
yield entities.CommandReturn(text="没有可更新的插件")
except Exception as e:
traceback.print_exc()
yield entities.CommandReturn(error=errors.CommandError("插件更新失败: "+str(e)))
@operator.operator_class(
name="del",
help="删除插件",
privilege=2,
parent_class=PluginOperator
)
class PluginDelOperator(operator.CommandOperator):
async def execute(
self,
context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
if len(context.crt_params) == 0:
yield entities.CommandReturn(error=errors.ParamNotEnoughError('请提供插件名称'))
else:
plugin_name = context.crt_params[0]
try:
plugin_path_name = plugin_host.get_plugin_path_name_by_plugin_name(plugin_name)
if plugin_path_name is not None:
yield entities.CommandReturn(text="正在删除插件...")
plugin_host.uninstall_plugin(plugin_name)
yield entities.CommandReturn(text="插件删除成功,请重启程序以加载插件")
else:
yield entities.CommandReturn(error=errors.CommandError("插件删除失败: 未找到插件"))
except Exception as e:
traceback.print_exc()
yield entities.CommandReturn(error=errors.CommandError("插件删除失败: "+str(e)))
def update_plugin_status(plugin_name: str, new_status: bool, ap: app.Application):
if plugin_name in plugin_host.__plugins__:
plugin_host.__plugins__[plugin_name]['enabled'] = new_status
for func in ap.tool_mgr.all_functions:
if func.name.startswith(plugin_name+'-'):
func.enable = new_status
return True
else:
return False
@operator.operator_class(
name="on",
help="启用插件",
privilege=2,
parent_class=PluginOperator
)
class PluginEnableOperator(operator.CommandOperator):
async def execute(
self,
context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
if len(context.crt_params) == 0:
yield entities.CommandReturn(error=errors.ParamNotEnoughError('请提供插件名称'))
else:
plugin_name = context.crt_params[0]
try:
if update_plugin_status(plugin_name, True, self.ap):
yield entities.CommandReturn(text="已启用插件: {}".format(plugin_name))
else:
yield entities.CommandReturn(error=errors.CommandError("插件状态修改失败: 未找到插件 {}".format(plugin_name)))
except Exception as e:
traceback.print_exc()
yield entities.CommandReturn(error=errors.CommandError("插件状态修改失败: "+str(e)))
@operator.operator_class(
name="off",
help="禁用插件",
privilege=2,
parent_class=PluginOperator
)
class PluginDisableOperator(operator.CommandOperator):
async def execute(
self,
context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
if len(context.crt_params) == 0:
yield entities.CommandReturn(error=errors.ParamNotEnoughError('请提供插件名称'))
else:
plugin_name = context.crt_params[0]
try:
if update_plugin_status(plugin_name, False, self.ap):
yield entities.CommandReturn(text="已禁用插件: {}".format(plugin_name))
else:
yield entities.CommandReturn(error=errors.CommandError("插件状态修改失败: 未找到插件 {}".format(plugin_name)))
except Exception as e:
traceback.print_exc()
yield entities.CommandReturn(error=errors.CommandError("插件状态修改失败: "+str(e)))

View File

@ -0,0 +1,29 @@
from __future__ import annotations
import typing
from .. import operator, entities, cmdmgr, errors
@operator.operator_class(
name="prompt",
help="查看当前对话的前文",
usage='!prompt'
)
class PromptOperator(operator.CommandOperator):
async def execute(
self,
context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
"""执行
"""
if context.session.using_conversation is None:
yield entities.CommandReturn(error=errors.CommandOperationError('当前没有对话'))
else:
reply_str = '当前对话所有内容:\n\n'
for msg in context.session.using_conversation.messages:
reply_str += f"{msg.role}: {msg.content}\n"
yield entities.CommandReturn(text=reply_str)

View File

@ -0,0 +1,34 @@
from __future__ import annotations
import typing
from .. import operator, entities, cmdmgr, errors
@operator.operator_class(
name="resend",
help="重发当前会话的最后一条消息",
usage='!resend'
)
class ResendOperator(operator.CommandOperator):
async def execute(
self,
context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
# 回滚到最后一条用户message前
if context.session.using_conversation is None:
yield entities.CommandReturn(error=errors.CommandError("当前没有对话"))
else:
conv_msg = context.session.using_conversation.messages
# 倒序一直删到最后一条用户message
while len(conv_msg) > 0 and conv_msg[-1].role != 'user':
conv_msg.pop()
if len(conv_msg) > 0:
# 删除最后一条用户message
conv_msg.pop()
# 不重发了,提示用户已删除就行了
yield entities.CommandReturn(text="已删除最后一次请求记录")

View File

@ -0,0 +1,23 @@
from __future__ import annotations
import typing
from .. import operator, entities, cmdmgr, errors
@operator.operator_class(
name="reset",
help="重置当前会话",
usage='!reset'
)
class ResetOperator(operator.CommandOperator):
async def execute(
self,
context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
"""执行
"""
context.session.using_conversation = None
yield entities.CommandReturn(text="已重置当前会话")

View File

@ -0,0 +1,28 @@
from __future__ import annotations
import typing
from .. import operator, cmdmgr, entities, errors
from ...utils import updater
@operator.operator_class(
name="version",
help="显示版本信息",
usage='!version'
)
class VersionCommand(operator.CommandOperator):
async def execute(
self,
context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
reply_str = f"当前版本: \n{updater.get_current_version_info()}"
try:
if updater.is_new_version_available():
reply_str += "\n\n有新版本可用, 使用 !update 更新"
except:
pass
yield entities.CommandReturn(text=reply_str.strip())

View File

@ -102,7 +102,7 @@ class OpenAIChatCompletion(api.LLMAPIRequester):
m.dict(exclude_none=True) for m in conversation.prompt.messages
] + [m.dict(exclude_none=True) for m in conversation.messages]
req_messages.append({"role": "user", "content": str(query.message_chain)})
# req_messages.append({"role": "user", "content": str(query.message_chain)})
msg = await self._closure(req_messages, conversation)

View File

@ -35,9 +35,16 @@ class PromptManager:
"""
return self.loader_inst.get_prompts()
async def get_prompt(self, name: str) -> loader.entities.Prompt:
async def get_prompt(self, name: str) -> loader.entities.Prompt:
"""获取Prompt
"""
for prompt in self.get_all_prompts():
if prompt.name == name:
return prompt
async def get_prompt_by_prefix(self, prefix: str) -> loader.entities.Prompt:
"""通过前缀获取Prompt
"""
for prompt in self.get_all_prompts():
if prompt.name.startswith(prefix):
return prompt

View File

@ -7,6 +7,7 @@ import mirai
from .. import handler
from ... import entities
from ....core import entities as core_entities
from ....openai import entities as llm_entities
class ChatMessageHandler(handler.MessageHandler):
@ -25,6 +26,13 @@ class ChatMessageHandler(handler.MessageHandler):
conversation = await self.ap.sess_mgr.get_conversation(session)
conversation.messages.append(
llm_entities.Message(
role="user",
content=str(query.message_chain)
)
)
async for result in conversation.use_model.requester.request(query, conversation):
conversation.messages.append(result)

View File

@ -34,13 +34,17 @@ class CommandHandler(handler.MessageHandler):
result_type=entities.ResultType.CONTINUE,
new_query=query
)
else:
if ret.text is not None:
query.resp_message_chain = mirai.MessageChain([
mirai.Plain(ret.text)
])
elif ret.text is not None:
query.resp_message_chain = mirai.MessageChain([
mirai.Plain(ret.text)
])
yield entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE,
new_query=query
)
yield entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE,
new_query=query
)
else:
yield entities.StageProcessResult(
result_type=entities.ResultType.INTERRUPT,
new_query=query
)