mirror of
https://github.com/RockChinQ/QChatGPT.git
synced 2024-11-16 03:32:33 +08:00
commit
995d1f61d2
|
@ -20,6 +20,8 @@ class CommandReturn(pydantic.BaseModel):
|
|||
image: typing.Optional[mirai.Image]
|
||||
|
||||
error: typing.Optional[errors.CommandError]= None
|
||||
"""错误
|
||||
"""
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
@ -30,17 +32,40 @@ class ExecuteContext(pydantic.BaseModel):
|
|||
"""
|
||||
|
||||
query: core_entities.Query
|
||||
"""本次消息的请求对象"""
|
||||
|
||||
session: core_entities.Session
|
||||
"""本次消息所属的会话对象"""
|
||||
|
||||
command_text: str
|
||||
"""命令完整文本"""
|
||||
|
||||
command: str
|
||||
"""命令名称"""
|
||||
|
||||
crt_command: str
|
||||
"""当前命令
|
||||
|
||||
多级命令中crt_command为当前命令,command为根命令。
|
||||
例如:!plugin on Webwlkr
|
||||
处理到plugin时,command为plugin,crt_command为plugin
|
||||
处理到on时,command为plugin,crt_command为on
|
||||
"""
|
||||
|
||||
params: list[str]
|
||||
"""命令参数
|
||||
|
||||
整个命令以空格分割后的参数列表
|
||||
"""
|
||||
|
||||
crt_params: list[str]
|
||||
"""当前命令参数
|
||||
|
||||
多级命令中crt_params为当前命令参数,params为根命令参数。
|
||||
例如:!plugin on Webwlkr
|
||||
处理到plugin时,params为['on', 'Webwlkr'],crt_params为['on', 'Webwlkr']
|
||||
处理到on时,params为['on', 'Webwlkr'],crt_params为['Webwlkr']
|
||||
"""
|
||||
|
||||
privilege: int
|
||||
"""发起人权限"""
|
||||
|
|
|
@ -52,6 +52,11 @@ def operator_class(
|
|||
|
||||
class CommandOperator(metaclass=abc.ABCMeta):
|
||||
"""命令算子抽象类
|
||||
|
||||
以下的参数均不需要在子类中设置,只需要在使用装饰器注册类时作为参数传递即可。
|
||||
命令支持级联,即一个命令可以有多个子命令,子命令可以有子命令,以此类推。
|
||||
处理命令时,若有子命令,会以当前参数列表的第一个参数去匹配子命令,若匹配成功,则转移到子命令中执行。
|
||||
若没有匹配成功或没有子命令,则执行当前命令。
|
||||
"""
|
||||
|
||||
ap: app.Application
|
||||
|
@ -60,7 +65,8 @@ class CommandOperator(metaclass=abc.ABCMeta):
|
|||
"""名称,搜索到时若符合则使用"""
|
||||
|
||||
path: str
|
||||
"""路径,所有父节点的name的连接,用于定义命令权限"""
|
||||
"""路径,所有父节点的name的连接,用于定义命令权限,由管理器在初始化时自动设置。
|
||||
"""
|
||||
|
||||
alias: list[str]
|
||||
"""同name"""
|
||||
|
@ -69,6 +75,7 @@ class CommandOperator(metaclass=abc.ABCMeta):
|
|||
"""此节点的帮助信息"""
|
||||
|
||||
usage: str = None
|
||||
"""用法"""
|
||||
|
||||
parent_class: typing.Union[typing.Type[CommandOperator], None] = None
|
||||
"""父节点类。标记以供管理器在初始化时编织父子关系。"""
|
||||
|
@ -92,4 +99,15 @@ class CommandOperator(metaclass=abc.ABCMeta):
|
|||
self,
|
||||
context: entities.ExecuteContext
|
||||
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
||||
"""实现此方法以执行命令
|
||||
|
||||
支持多次yield以返回多个结果。
|
||||
例如:一个安装插件的命令,可能会有下载、解压、安装等多个步骤,每个步骤都可以返回一个结果。
|
||||
|
||||
Args:
|
||||
context (entities.ExecuteContext): 命令执行上下文
|
||||
|
||||
Yields:
|
||||
entities.CommandReturn: 命令返回封装
|
||||
"""
|
||||
pass
|
||||
|
|
|
@ -8,15 +8,12 @@ from .. import model as file_model
|
|||
class JSONConfigFile(file_model.ConfigFile):
|
||||
"""JSON配置文件"""
|
||||
|
||||
config_file_name: str = None
|
||||
"""配置文件名"""
|
||||
|
||||
template_file_name: str = None
|
||||
"""模板文件名"""
|
||||
|
||||
def __init__(self, config_file_name: str, template_file_name: str) -> None:
|
||||
def __init__(
|
||||
self, config_file_name: str, template_file_name: str = None, template_data: dict = None
|
||||
) -> None:
|
||||
self.config_file_name = config_file_name
|
||||
self.template_file_name = template_file_name
|
||||
self.template_data = template_data
|
||||
|
||||
def exists(self) -> bool:
|
||||
return os.path.exists(self.config_file_name)
|
||||
|
@ -29,23 +26,24 @@ class JSONConfigFile(file_model.ConfigFile):
|
|||
if not self.exists():
|
||||
await self.create()
|
||||
|
||||
with open(self.config_file_name, 'r', encoding='utf-8') as f:
|
||||
cfg = json.load(f)
|
||||
if self.template_file_name is not None:
|
||||
with open(self.config_file_name, "r", encoding="utf-8") as f:
|
||||
cfg = json.load(f)
|
||||
|
||||
# 从模板文件中进行补全
|
||||
with open(self.template_file_name, 'r', encoding='utf-8') as f:
|
||||
template_cfg = json.load(f)
|
||||
with open(self.template_file_name, "r", encoding="utf-8") as f:
|
||||
self.template_data = json.load(f)
|
||||
|
||||
for key in template_cfg:
|
||||
for key in self.template_data:
|
||||
if key not in cfg:
|
||||
cfg[key] = template_cfg[key]
|
||||
cfg[key] = self.template_data[key]
|
||||
|
||||
return cfg
|
||||
|
||||
|
||||
async def save(self, cfg: dict):
|
||||
with open(self.config_file_name, 'w', encoding='utf-8') as f:
|
||||
with open(self.config_file_name, "w", encoding="utf-8") as f:
|
||||
json.dump(cfg, f, indent=4, ensure_ascii=False)
|
||||
|
||||
def save_sync(self, cfg: dict):
|
||||
with open(self.config_file_name, 'w', encoding='utf-8') as f:
|
||||
json.dump(cfg, f, indent=4, ensure_ascii=False)
|
||||
with open(self.config_file_name, "w", encoding="utf-8") as f:
|
||||
json.dump(cfg, f, indent=4, ensure_ascii=False)
|
||||
|
|
|
@ -43,11 +43,12 @@ async def load_python_module_config(config_name: str, template_name: str) -> Con
|
|||
return cfg_mgr
|
||||
|
||||
|
||||
async def load_json_config(config_name: str, template_name: str) -> ConfigManager:
|
||||
async def load_json_config(config_name: str, template_name: str=None, template_data: dict=None) -> ConfigManager:
|
||||
"""加载JSON配置文件"""
|
||||
cfg_inst = json_file.JSONConfigFile(
|
||||
config_name,
|
||||
template_name
|
||||
template_name,
|
||||
template_data
|
||||
)
|
||||
|
||||
cfg_mgr = ConfigManager(cfg_inst)
|
||||
|
|
|
@ -7,6 +7,7 @@ from ..core import app
|
|||
|
||||
|
||||
preregistered_migrations: list[typing.Type[Migration]] = []
|
||||
"""当前阶段暂不支持扩展"""
|
||||
|
||||
def migration_class(name: str, number: int):
|
||||
"""注册一个迁移
|
||||
|
|
|
@ -10,6 +10,9 @@ class ConfigFile(metaclass=abc.ABCMeta):
|
|||
template_file_name: str = None
|
||||
"""模板文件名"""
|
||||
|
||||
template_data: dict = None
|
||||
"""模板数据"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def exists(self) -> bool:
|
||||
pass
|
||||
|
|
|
@ -21,7 +21,7 @@ from ..utils import version as version_mgr, proxy as proxy_mgr
|
|||
class Application:
|
||||
"""运行时应用对象和上下文"""
|
||||
|
||||
im_mgr: im_mgr.PlatformManager = None
|
||||
platform_mgr: im_mgr.PlatformManager = None
|
||||
|
||||
cmd_mgr: cmdmgr.CommandManager = None
|
||||
|
||||
|
@ -85,10 +85,9 @@ class Application:
|
|||
tasks = []
|
||||
|
||||
try:
|
||||
|
||||
|
||||
|
||||
tasks = [
|
||||
asyncio.create_task(self.im_mgr.run()),
|
||||
asyncio.create_task(self.platform_mgr.run()),
|
||||
asyncio.create_task(self.ctrl.run())
|
||||
]
|
||||
|
||||
|
|
|
@ -10,7 +10,6 @@ required_deps = {
|
|||
"botpy": "qq-botpy",
|
||||
"PIL": "pillow",
|
||||
"nakuru": "nakuru-project-idk",
|
||||
"CallingGPT": "CallingGPT",
|
||||
"tiktoken": "tiktoken",
|
||||
"yaml": "pyyaml",
|
||||
"aiohttp": "aiohttp",
|
||||
|
|
|
@ -32,43 +32,43 @@ class Query(pydantic.BaseModel):
|
|||
"""请求ID,添加进请求池时生成"""
|
||||
|
||||
launcher_type: LauncherTypes
|
||||
"""会话类型,platform设置"""
|
||||
"""会话类型,platform处理阶段设置"""
|
||||
|
||||
launcher_id: int
|
||||
"""会话ID,platform设置"""
|
||||
"""会话ID,platform处理阶段设置"""
|
||||
|
||||
sender_id: int
|
||||
"""发送者ID,platform设置"""
|
||||
"""发送者ID,platform处理阶段设置"""
|
||||
|
||||
message_event: mirai.MessageEvent
|
||||
"""事件,platform收到的事件"""
|
||||
"""事件,platform收到的原始事件"""
|
||||
|
||||
message_chain: mirai.MessageChain
|
||||
"""消息链,platform收到的消息链"""
|
||||
"""消息链,platform收到的原始消息链"""
|
||||
|
||||
adapter: msadapter.MessageSourceAdapter
|
||||
"""适配器对象"""
|
||||
"""消息平台适配器对象,单个app中可能启用了多个消息平台适配器,此对象表明发起此query的适配器"""
|
||||
|
||||
session: typing.Optional[Session] = None
|
||||
"""会话对象,由前置处理器设置"""
|
||||
"""会话对象,由前置处理器阶段设置"""
|
||||
|
||||
messages: typing.Optional[list[llm_entities.Message]] = []
|
||||
"""历史消息列表,由前置处理器设置"""
|
||||
"""历史消息列表,由前置处理器阶段设置"""
|
||||
|
||||
prompt: typing.Optional[sysprompt_entities.Prompt] = None
|
||||
"""情景预设内容,由前置处理器设置"""
|
||||
"""情景预设内容,由前置处理器阶段设置"""
|
||||
|
||||
user_message: typing.Optional[llm_entities.Message] = None
|
||||
"""此次请求的用户消息对象,由前置处理器设置"""
|
||||
"""此次请求的用户消息对象,由前置处理器阶段设置"""
|
||||
|
||||
use_model: typing.Optional[entities.LLMModelInfo] = None
|
||||
"""使用的模型,由前置处理器设置"""
|
||||
"""使用的模型,由前置处理器阶段设置"""
|
||||
|
||||
use_funcs: typing.Optional[list[tools_entities.LLMFunction]] = None
|
||||
"""使用的函数,由前置处理器设置"""
|
||||
"""使用的函数,由前置处理器阶段设置"""
|
||||
|
||||
resp_messages: typing.Optional[list[llm_entities.Message]] = []
|
||||
"""由provider生成的回复消息对象列表"""
|
||||
"""由Process阶段生成的回复消息对象列表"""
|
||||
|
||||
resp_message_chain: typing.Optional[mirai.MessageChain] = None
|
||||
"""回复消息链,从resp_messages包装而得"""
|
||||
|
|
|
@ -7,7 +7,10 @@ from . import app
|
|||
|
||||
|
||||
preregistered_stages: dict[str, typing.Type[BootingStage]] = {}
|
||||
"""预注册的请求处理阶段。在初始化时,所有请求处理阶段类会被注册到此字典中。"""
|
||||
"""预注册的请求处理阶段。在初始化时,所有请求处理阶段类会被注册到此字典中。
|
||||
|
||||
当前阶段暂不支持扩展
|
||||
"""
|
||||
|
||||
def stage_class(
|
||||
name: str
|
||||
|
|
|
@ -86,7 +86,7 @@ class BuildAppStage(stage.BootingStage):
|
|||
|
||||
im_mgr_inst = im_mgr.PlatformManager(ap=ap)
|
||||
await im_mgr_inst.initialize()
|
||||
ap.im_mgr = im_mgr_inst
|
||||
ap.platform_mgr = im_mgr_inst
|
||||
|
||||
stage_mgr = stagemgr.StageManager(ap)
|
||||
await stage_mgr.initialize()
|
||||
|
|
|
@ -31,15 +31,24 @@ class EnableStage(enum.Enum):
|
|||
|
||||
class FilterResult(pydantic.BaseModel):
|
||||
level: ResultLevel
|
||||
"""结果等级
|
||||
|
||||
对于前置处理阶段,只要有任意一个返回 非PASS 的内容过滤器结果,就会中断处理。
|
||||
对于后置处理阶段,当且内容过滤器返回 BLOCK 时,会中断处理。
|
||||
"""
|
||||
|
||||
replacement: str
|
||||
"""替换后的消息"""
|
||||
"""替换后的消息
|
||||
|
||||
内容过滤器可以进行一些遮掩处理,然后把遮掩后的消息返回。
|
||||
若没有修改内容,也需要返回原消息。
|
||||
"""
|
||||
|
||||
user_notice: str
|
||||
"""不通过时,用户提示消息"""
|
||||
"""不通过时,若此值不为空,将对用户提示消息"""
|
||||
|
||||
console_notice: str
|
||||
"""不通过时,控制台提示消息"""
|
||||
"""不通过时,若此值不为空,将在控制台提示消息"""
|
||||
|
||||
|
||||
class ManagerResultLevel(enum.Enum):
|
||||
|
|
|
@ -46,6 +46,11 @@ class ContentFilter(metaclass=abc.ABCMeta):
|
|||
@property
|
||||
def enable_stages(self):
|
||||
"""启用的阶段
|
||||
|
||||
默认为消息请求AI前后的两个阶段。
|
||||
|
||||
entity.EnableStage.PRE: 消息请求AI前,此时需要检查的内容是用户的输入消息。
|
||||
entity.EnableStage.POST: 消息请求AI后,此时需要检查的内容是AI的回复消息。
|
||||
"""
|
||||
return [
|
||||
entities.EnableStage.PRE,
|
||||
|
@ -60,5 +65,14 @@ class ContentFilter(metaclass=abc.ABCMeta):
|
|||
@abc.abstractmethod
|
||||
async def process(self, message: str) -> entities.FilterResult:
|
||||
"""处理消息
|
||||
|
||||
分为前后阶段,具体取决于 enable_stages 的值。
|
||||
对于内容过滤器来说,不需要考虑消息所处的阶段,只需要检查消息内容即可。
|
||||
|
||||
Args:
|
||||
message (str): 需要检查的内容
|
||||
|
||||
Returns:
|
||||
entities.FilterResult: 过滤结果,具体内容请查看 entities.FilterResult 类的文档
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
|
|
@ -68,7 +68,7 @@ class Controller:
|
|||
"""检查输出
|
||||
"""
|
||||
if result.user_notice:
|
||||
await self.ap.im_mgr.send(
|
||||
await self.ap.platform_mgr.send(
|
||||
query.message_event,
|
||||
result.user_notice,
|
||||
query.adapter
|
||||
|
|
|
@ -15,6 +15,15 @@ preregistered_strategies: list[typing.Type[LongTextStrategy]] = []
|
|||
def strategy_class(
|
||||
name: str
|
||||
) -> typing.Callable[[typing.Type[LongTextStrategy]], typing.Type[LongTextStrategy]]:
|
||||
"""长文本处理策略类装饰器
|
||||
|
||||
Args:
|
||||
name (str): 策略名称
|
||||
|
||||
Returns:
|
||||
typing.Callable[[typing.Type[LongTextStrategy]], typing.Type[LongTextStrategy]]: 装饰器
|
||||
"""
|
||||
|
||||
def decorator(cls: typing.Type[LongTextStrategy]) -> typing.Type[LongTextStrategy]:
|
||||
assert issubclass(cls, LongTextStrategy)
|
||||
|
||||
|
@ -43,4 +52,15 @@ class LongTextStrategy(metaclass=abc.ABCMeta):
|
|||
|
||||
@abc.abstractmethod
|
||||
async def process(self, message: str, query: core_entities.Query) -> list[MessageComponent]:
|
||||
"""处理长文本
|
||||
|
||||
在 platform.json 中配置 long-text-process 字段,只要 文本长度超过了 threshold 就会调用此方法
|
||||
|
||||
Args:
|
||||
message (str): 消息
|
||||
query (core_entities.Query): 此次请求的上下文对象
|
||||
|
||||
Returns:
|
||||
list[mirai.models.messages.MessageComponent]: 转换后的 YiriMirai 消息组件列表
|
||||
"""
|
||||
return []
|
||||
|
|
|
@ -39,7 +39,14 @@ class ChatMessageHandler(handler.MessageHandler):
|
|||
|
||||
if event_ctx.is_prevented_default():
|
||||
if event_ctx.event.reply is not None:
|
||||
query.resp_message_chain = mirai.MessageChain(event_ctx.event.reply)
|
||||
mc = mirai.MessageChain(event_ctx.event.reply)
|
||||
|
||||
query.resp_messages.append(
|
||||
llm_entities.Message(
|
||||
role='plugin',
|
||||
content=str(mc),
|
||||
)
|
||||
)
|
||||
|
||||
yield entities.StageProcessResult(
|
||||
result_type=entities.ResultType.CONTINUE,
|
||||
|
|
|
@ -18,6 +18,7 @@ def algo_class(name: str):
|
|||
|
||||
|
||||
class ReteLimitAlgo(metaclass=abc.ABCMeta):
|
||||
"""限流算法抽象类"""
|
||||
|
||||
name: str = None
|
||||
|
||||
|
@ -31,9 +32,27 @@ class ReteLimitAlgo(metaclass=abc.ABCMeta):
|
|||
|
||||
@abc.abstractmethod
|
||||
async def require_access(self, launcher_type: str, launcher_id: int) -> bool:
|
||||
"""进入处理流程
|
||||
|
||||
这个方法对等待是友好的,意味着算法可以实现在这里等待一段时间以控制速率。
|
||||
|
||||
Args:
|
||||
launcher_type (str): 请求者类型 群聊为 group 私聊为 person
|
||||
launcher_id (int): 请求者ID
|
||||
|
||||
Returns:
|
||||
bool: 是否允许进入处理流程,若返回false,则直接丢弃该请求
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
async def release_access(self, launcher_type: str, launcher_id: int):
|
||||
"""退出处理流程
|
||||
|
||||
Args:
|
||||
launcher_type (str): 请求者类型 群聊为 group 私聊为 person
|
||||
launcher_id (int): 请求者ID
|
||||
"""
|
||||
|
||||
raise NotImplementedError
|
||||
|
|
@ -29,7 +29,7 @@ class SendResponseBackStage(stage.PipelineStage):
|
|||
|
||||
await asyncio.sleep(random_delay)
|
||||
|
||||
await self.ap.im_mgr.send(
|
||||
await self.ap.platform_mgr.send(
|
||||
query.message_event,
|
||||
query.resp_message_chain,
|
||||
adapter=query.adapter
|
||||
|
|
|
@ -29,6 +29,13 @@ class ResponseWrapper(stage.PipelineStage):
|
|||
if query.resp_messages[-1].role == 'command':
|
||||
query.resp_message_chain = mirai.MessageChain("[bot] "+query.resp_messages[-1].content)
|
||||
|
||||
yield entities.StageProcessResult(
|
||||
result_type=entities.ResultType.CONTINUE,
|
||||
new_query=query
|
||||
)
|
||||
elif query.resp_messages[-1].role == 'plugin':
|
||||
query.resp_message_chain = mirai.MessageChain(query.resp_messages[-1].content)
|
||||
|
||||
yield entities.StageProcessResult(
|
||||
result_type=entities.ResultType.CONTINUE,
|
||||
new_query=query
|
||||
|
|
|
@ -14,6 +14,14 @@ preregistered_adapters: list[typing.Type[MessageSourceAdapter]] = []
|
|||
def adapter_class(
|
||||
name: str
|
||||
):
|
||||
"""消息平台适配器类装饰器
|
||||
|
||||
Args:
|
||||
name (str): 适配器名称
|
||||
|
||||
Returns:
|
||||
typing.Callable[[typing.Type[MessageSourceAdapter]], typing.Type[MessageSourceAdapter]]: 装饰器
|
||||
"""
|
||||
def decorator(cls: typing.Type[MessageSourceAdapter]) -> typing.Type[MessageSourceAdapter]:
|
||||
cls.name = name
|
||||
preregistered_adapters.append(cls)
|
||||
|
@ -27,12 +35,19 @@ class MessageSourceAdapter(metaclass=abc.ABCMeta):
|
|||
name: str
|
||||
|
||||
bot_account_id: int
|
||||
"""机器人账号ID,需要在初始化时设置"""
|
||||
|
||||
config: dict
|
||||
|
||||
ap: app.Application
|
||||
|
||||
def __init__(self, config: dict, ap: app.Application):
|
||||
"""初始化适配器
|
||||
|
||||
Args:
|
||||
config (dict): 对应的配置
|
||||
ap (app.Application): 应用上下文
|
||||
"""
|
||||
self.config = config
|
||||
self.ap = ap
|
||||
|
||||
|
|
|
@ -146,7 +146,7 @@ class PlatformManager:
|
|||
if len(self.adapters) == 0:
|
||||
self.ap.logger.warning('未运行平台适配器,请根据文档配置并启用平台适配器。')
|
||||
|
||||
async def send(self, event, msg, adapter: msadapter.MessageSourceAdapter, check_quote=True, check_at_sender=True):
|
||||
async def send(self, event: mirai.MessageEvent, msg: mirai.MessageChain, adapter: msadapter.MessageSourceAdapter, check_quote=True, check_at_sender=True):
|
||||
|
||||
if check_at_sender and self.ap.platform_cfg.data['at-sender'] and isinstance(event, GroupMessage):
|
||||
|
||||
|
|
|
@ -9,10 +9,86 @@ from ..provider.tools import entities as tools_entities
|
|||
from ..core import app
|
||||
|
||||
|
||||
def register(
|
||||
name: str,
|
||||
description: str,
|
||||
version: str,
|
||||
author: str
|
||||
) -> typing.Callable[[typing.Type[BasePlugin]], typing.Type[BasePlugin]]:
|
||||
"""注册插件类
|
||||
|
||||
使用示例:
|
||||
|
||||
@register(
|
||||
name="插件名称",
|
||||
description="插件描述",
|
||||
version="插件版本",
|
||||
author="插件作者"
|
||||
)
|
||||
class MyPlugin(BasePlugin):
|
||||
pass
|
||||
"""
|
||||
pass
|
||||
|
||||
def handler(
|
||||
event: typing.Type[events.BaseEventModel]
|
||||
) -> typing.Callable[[typing.Callable], typing.Callable]:
|
||||
"""注册事件监听器
|
||||
|
||||
使用示例:
|
||||
|
||||
class MyPlugin(BasePlugin):
|
||||
|
||||
@handler(NormalMessageResponded)
|
||||
async def on_normal_message_responded(self, ctx: EventContext):
|
||||
pass
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
def llm_func(
|
||||
name: str=None,
|
||||
) -> typing.Callable:
|
||||
"""注册内容函数
|
||||
|
||||
使用示例:
|
||||
|
||||
class MyPlugin(BasePlugin):
|
||||
|
||||
@llm_func("access_the_web_page")
|
||||
async def _(self, query, url: str, brief_len: int):
|
||||
\"""Call this function to search about the question before you answer any questions.
|
||||
- Do not search through google.com at any time.
|
||||
- If you need to search somthing, visit https://www.sogou.com/web?query=<something>.
|
||||
- If user ask you to open a url (start with http:// or https://), visit it directly.
|
||||
- Summary the plain content result by yourself, DO NOT directly output anything in the result you got.
|
||||
|
||||
Args:
|
||||
url(str): url to visit
|
||||
brief_len(int): max length of the plain text content, recommend 1024-4096, prefer 4096
|
||||
|
||||
Returns:
|
||||
str: plain text content of the web page or error message(starts with 'error:')
|
||||
\"""
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class BasePlugin(metaclass=abc.ABCMeta):
|
||||
"""插件基类"""
|
||||
|
||||
host: APIHost
|
||||
"""API宿主"""
|
||||
|
||||
ap: app.Application
|
||||
"""应用程序对象"""
|
||||
|
||||
def __init__(self, host: APIHost):
|
||||
self.host = host
|
||||
|
||||
async def initialize(self):
|
||||
"""初始化插件"""
|
||||
pass
|
||||
|
||||
|
||||
class APIHost:
|
||||
|
@ -61,8 +137,10 @@ class EventContext:
|
|||
"""事件编号"""
|
||||
|
||||
host: APIHost = None
|
||||
"""API宿主"""
|
||||
|
||||
event: events.BaseEventModel = None
|
||||
"""此次事件的对象,具体类型为handler注册时指定监听的类型,可查看events.py中的定义"""
|
||||
|
||||
__prevent_default__ = False
|
||||
"""是否阻止默认行为"""
|
||||
|
|
|
@ -10,8 +10,10 @@ from ..provider import entities as llm_entities
|
|||
|
||||
|
||||
class BaseEventModel(pydantic.BaseModel):
|
||||
"""事件模型基类"""
|
||||
|
||||
query: typing.Union[core_entities.Query, None]
|
||||
"""此次请求的query对象,非请求过程的事件时为None"""
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
|
|
@ -1,3 +1,7 @@
|
|||
# 此模块已过时
|
||||
# 请从 pkg.plugin.context 引入 BasePlugin, EventContext 和 APIHost
|
||||
# 最早将于 v3.4 移除此模块
|
||||
|
||||
from . events import *
|
||||
from . context import EventContext, APIHost as PluginHost
|
||||
|
||||
|
|
|
@ -5,11 +5,10 @@ import pkgutil
|
|||
import importlib
|
||||
import traceback
|
||||
|
||||
from CallingGPT.entities.namespace import get_func_schema
|
||||
|
||||
from .. import loader, events, context, models, host
|
||||
from ...core import entities as core_entities
|
||||
from ...provider.tools import entities as tools_entities
|
||||
from ...utils import funcschema
|
||||
|
||||
|
||||
class PluginLoader(loader.PluginLoader):
|
||||
|
@ -29,6 +28,10 @@ class PluginLoader(loader.PluginLoader):
|
|||
setattr(models, 'on', self.on)
|
||||
setattr(models, 'func', self.func)
|
||||
|
||||
setattr(context, 'register', self.register)
|
||||
setattr(context, 'handler', self.handler)
|
||||
setattr(context, 'llm_func', self.llm_func)
|
||||
|
||||
def register(
|
||||
self,
|
||||
name: str,
|
||||
|
@ -57,6 +60,8 @@ class PluginLoader(loader.PluginLoader):
|
|||
|
||||
return wrapper
|
||||
|
||||
# 过时
|
||||
# 最早将于 v3.4 版本移除
|
||||
def on(
|
||||
self,
|
||||
event: typing.Type[events.BaseEventModel]
|
||||
|
@ -83,6 +88,8 @@ class PluginLoader(loader.PluginLoader):
|
|||
|
||||
return wrapper
|
||||
|
||||
# 过时
|
||||
# 最早将于 v3.4 版本移除
|
||||
def func(
|
||||
self,
|
||||
name: str=None,
|
||||
|
@ -91,10 +98,11 @@ class PluginLoader(loader.PluginLoader):
|
|||
self.ap.logger.debug(f'注册内容函数 {name}')
|
||||
def wrapper(func: typing.Callable) -> typing.Callable:
|
||||
|
||||
function_schema = get_func_schema(func)
|
||||
function_schema = funcschema.get_func_schema(func)
|
||||
function_name = self._current_container.plugin_name + '-' + (func.__name__ if name is None else name)
|
||||
|
||||
async def handler(
|
||||
plugin: context.BasePlugin,
|
||||
query: core_entities.Query,
|
||||
*args,
|
||||
**kwargs
|
||||
|
@ -116,6 +124,46 @@ class PluginLoader(loader.PluginLoader):
|
|||
|
||||
return wrapper
|
||||
|
||||
def handler(
|
||||
self,
|
||||
event: typing.Type[events.BaseEventModel]
|
||||
) -> typing.Callable[[typing.Callable], typing.Callable]:
|
||||
"""注册事件处理器"""
|
||||
self.ap.logger.debug(f'注册事件处理器 {event.__name__}')
|
||||
def wrapper(func: typing.Callable) -> typing.Callable:
|
||||
|
||||
self._current_container.event_handlers[event] = func
|
||||
|
||||
return func
|
||||
|
||||
return wrapper
|
||||
|
||||
def llm_func(
|
||||
self,
|
||||
name: str=None,
|
||||
) -> typing.Callable:
|
||||
"""注册内容函数"""
|
||||
self.ap.logger.debug(f'注册内容函数 {name}')
|
||||
def wrapper(func: typing.Callable) -> typing.Callable:
|
||||
|
||||
function_schema = funcschema.get_func_schema(func)
|
||||
function_name = self._current_container.plugin_name + '-' + (func.__name__ if name is None else name)
|
||||
|
||||
llm_function = tools_entities.LLMFunction(
|
||||
name=function_name,
|
||||
human_desc='',
|
||||
description=function_schema['description'],
|
||||
enable=True,
|
||||
parameters=function_schema['parameters'],
|
||||
func=func,
|
||||
)
|
||||
|
||||
self._current_container.content_functions.append(llm_function)
|
||||
|
||||
return func
|
||||
|
||||
return wrapper
|
||||
|
||||
async def _walk_plugin_path(
|
||||
self,
|
||||
module,
|
|
@ -5,7 +5,7 @@ import traceback
|
|||
|
||||
from ..core import app
|
||||
from . import context, loader, events, installer, setting, models
|
||||
from .loaders import legacy
|
||||
from .loaders import classic
|
||||
from .installers import github
|
||||
|
||||
|
||||
|
@ -26,7 +26,7 @@ class PluginManager:
|
|||
|
||||
def __init__(self, ap: app.Application):
|
||||
self.ap = ap
|
||||
self.loader = legacy.PluginLoader(ap)
|
||||
self.loader = classic.PluginLoader(ap)
|
||||
self.installer = github.GitHubRepoInstaller(ap)
|
||||
self.setting = setting.SettingManager(ap)
|
||||
self.api_host = context.APIHost(ap)
|
||||
|
@ -52,6 +52,9 @@ class PluginManager:
|
|||
for plugin in self.plugins:
|
||||
try:
|
||||
plugin.plugin_inst = plugin.plugin_class(self.api_host)
|
||||
plugin.plugin_inst.ap = self.ap
|
||||
plugin.plugin_inst.host = self.api_host
|
||||
await plugin.plugin_inst.initialize()
|
||||
except Exception as e:
|
||||
self.ap.logger.error(f'插件 {plugin.plugin_name} 初始化失败: {e}')
|
||||
self.ap.logger.exception(e)
|
||||
|
|
|
@ -1,3 +1,7 @@
|
|||
# 此模块已过时,请引入 pkg.plugin.context 中的 register, handler 和 llm_func 来注册插件、事件处理函数和内容函数
|
||||
# 各个事件模型请从 pkg.plugin.events 引入
|
||||
# 最早将于 v3.4 移除此模块
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import typing
|
||||
|
|
|
@ -22,15 +22,20 @@ class ToolCall(pydantic.BaseModel):
|
|||
class Message(pydantic.BaseModel):
|
||||
"""消息"""
|
||||
|
||||
role: str # user, system, assistant, tool, command
|
||||
role: str # user, system, assistant, tool, command, plugin
|
||||
"""消息的角色"""
|
||||
|
||||
name: typing.Optional[str] = None
|
||||
"""名称,仅函数调用返回时设置"""
|
||||
|
||||
content: typing.Optional[str] = None
|
||||
"""内容"""
|
||||
|
||||
function_call: typing.Optional[FunctionCall] = None
|
||||
"""函数调用,不再受支持,请使用tool_calls"""
|
||||
|
||||
tool_calls: typing.Optional[list[ToolCall]] = None
|
||||
"""工具调用"""
|
||||
|
||||
tool_call_id: typing.Optional[str] = None
|
||||
|
||||
|
|
|
@ -38,6 +38,15 @@ class LLMAPIRequester(metaclass=abc.ABCMeta):
|
|||
self,
|
||||
query: core_entities.Query,
|
||||
) -> typing.AsyncGenerator[llm_entities.Message, None]:
|
||||
"""请求
|
||||
"""请求API
|
||||
|
||||
对话前文可以从 query 对象中获取。
|
||||
可以多次yield消息对象。
|
||||
|
||||
Args:
|
||||
query (core_entities.Query): 本次请求的上下文对象
|
||||
|
||||
Yields:
|
||||
pkg.provider.entities.Message: 返回消息对象
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
|
|
@ -10,5 +10,7 @@ class Prompt(pydantic.BaseModel):
|
|||
"""供AI使用的Prompt"""
|
||||
|
||||
name: str
|
||||
"""名称"""
|
||||
|
||||
messages: list[entities.Message]
|
||||
"""消息列表"""
|
||||
|
|
|
@ -36,7 +36,7 @@ class PromptLoader(metaclass=abc.ABCMeta):
|
|||
|
||||
@abc.abstractmethod
|
||||
async def load(self):
|
||||
"""加载Prompt
|
||||
"""加载Prompt,存放到prompts列表中
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
|
|
@ -5,6 +5,7 @@ import traceback
|
|||
|
||||
from ...core import app, entities as core_entities
|
||||
from . import entities
|
||||
from ...plugin import context as plugin_context
|
||||
|
||||
|
||||
class ToolManager:
|
||||
|
@ -28,6 +29,15 @@ class ToolManager:
|
|||
return function
|
||||
return None
|
||||
|
||||
async def get_function_and_plugin(self, name: str) -> typing.Tuple[entities.LLMFunction, plugin_context.BasePlugin]:
|
||||
"""获取函数和插件
|
||||
"""
|
||||
for plugin in self.ap.plugin_mgr.plugins:
|
||||
for function in plugin.content_functions:
|
||||
if function.name == name:
|
||||
return function, plugin
|
||||
return None, None
|
||||
|
||||
async def get_all_functions(self) -> list[entities.LLMFunction]:
|
||||
"""获取所有函数
|
||||
"""
|
||||
|
@ -68,7 +78,7 @@ class ToolManager:
|
|||
|
||||
try:
|
||||
|
||||
function = await self.get_function(name)
|
||||
function, plugin = await self.get_function_and_plugin(name)
|
||||
if function is None:
|
||||
return None
|
||||
|
||||
|
@ -79,7 +89,7 @@ class ToolManager:
|
|||
**parameters
|
||||
}
|
||||
|
||||
return await function.func(**parameters)
|
||||
return await function.func(plugin, **parameters)
|
||||
except Exception as e:
|
||||
self.ap.logger.error(f'执行函数 {name} 时发生错误: {e}')
|
||||
traceback.print_exc()
|
||||
|
|
116
pkg/utils/funcschema.py
Normal file
116
pkg/utils/funcschema.py
Normal file
|
@ -0,0 +1,116 @@
|
|||
import sys
|
||||
import re
|
||||
import inspect
|
||||
|
||||
|
||||
def get_func_schema(function: callable) -> dict:
|
||||
"""
|
||||
Return the data schema of a function.
|
||||
{
|
||||
"function": function,
|
||||
"description": "function description",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"parameter_a": {
|
||||
"type": "str",
|
||||
"description": "parameter_a description"
|
||||
},
|
||||
"parameter_b": {
|
||||
"type": "int",
|
||||
"description": "parameter_b description"
|
||||
},
|
||||
"parameter_c": {
|
||||
"type": "str",
|
||||
"description": "parameter_c description",
|
||||
"enum": ["a", "b", "c"]
|
||||
},
|
||||
},
|
||||
"required": ["parameter_a", "parameter_b"]
|
||||
}
|
||||
}
|
||||
"""
|
||||
func_doc = function.__doc__
|
||||
# Google Style Docstring
|
||||
if func_doc is None:
|
||||
raise Exception("Function {} has no docstring.".format(function.__name__))
|
||||
func_doc = func_doc.strip().replace(' ','').replace('\t', '')
|
||||
# extract doc of args from docstring
|
||||
doc_spt = func_doc.split('\n\n')
|
||||
desc = doc_spt[0]
|
||||
args = doc_spt[1] if len(doc_spt) > 1 else ""
|
||||
returns = doc_spt[2] if len(doc_spt) > 2 else ""
|
||||
|
||||
# extract args
|
||||
# delete the first line of args
|
||||
arg_lines = args.split('\n')[1:]
|
||||
arg_doc_list = re.findall(r'(\w+)(\((\w+)\))?:\s*(.*)', args)
|
||||
args_doc = {}
|
||||
for arg_line in arg_lines:
|
||||
doc_tuple = re.findall(r'(\w+)(\(([\w\[\]]+)\))?:\s*(.*)', arg_line)
|
||||
if len(doc_tuple) == 0:
|
||||
continue
|
||||
args_doc[doc_tuple[0][0]] = doc_tuple[0][3]
|
||||
|
||||
# extract returns
|
||||
return_doc_list = re.findall(r'(\w+):\s*(.*)', returns)
|
||||
|
||||
params = enumerate(inspect.signature(function).parameters.values())
|
||||
parameters = {
|
||||
"type": "object",
|
||||
"required": [],
|
||||
"properties": {},
|
||||
}
|
||||
|
||||
|
||||
for i, param in params:
|
||||
|
||||
# 排除 self, query
|
||||
if param.name in ['self', 'query']:
|
||||
continue
|
||||
|
||||
param_type = param.annotation.__name__
|
||||
|
||||
type_name_mapping = {
|
||||
"str": "string",
|
||||
"int": "integer",
|
||||
"float": "number",
|
||||
"bool": "boolean",
|
||||
"list": "array",
|
||||
"dict": "object",
|
||||
}
|
||||
|
||||
if param_type in type_name_mapping:
|
||||
param_type = type_name_mapping[param_type]
|
||||
|
||||
parameters['properties'][param.name] = {
|
||||
"type": param_type,
|
||||
"description": args_doc[param.name],
|
||||
}
|
||||
|
||||
# add schema for array
|
||||
if param_type == "array":
|
||||
# extract type of array, the int of list[int]
|
||||
# use re
|
||||
array_type_tuple = re.findall(r'list\[(\w+)\]', str(param.annotation))
|
||||
|
||||
array_type = 'string'
|
||||
|
||||
if len(array_type_tuple) > 0:
|
||||
array_type = array_type_tuple[0]
|
||||
|
||||
if array_type in type_name_mapping:
|
||||
array_type = type_name_mapping[array_type]
|
||||
|
||||
parameters['properties'][param.name]["items"] = {
|
||||
"type": array_type,
|
||||
}
|
||||
|
||||
if param.default is inspect.Parameter.empty:
|
||||
parameters["required"].append(param.name)
|
||||
|
||||
return {
|
||||
"function": function,
|
||||
"description": desc,
|
||||
"parameters": parameters,
|
||||
}
|
|
@ -7,7 +7,6 @@ aiocqhttp
|
|||
qq-botpy
|
||||
nakuru-project-idk
|
||||
Pillow
|
||||
CallingGPT
|
||||
tiktoken
|
||||
PyYaml
|
||||
aiohttp
|
||||
|
|
Loading…
Reference in New Issue
Block a user