Merge pull request #735 from RockChinQ/feat/plugin-api

Feat: 插件异步 API
This commit is contained in:
Junyan Qin 2024-03-22 17:10:06 +08:00 committed by GitHub
commit 995d1f61d2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
34 changed files with 476 additions and 58 deletions

View File

@ -20,6 +20,8 @@ class CommandReturn(pydantic.BaseModel):
image: typing.Optional[mirai.Image]
error: typing.Optional[errors.CommandError]= None
"""错误
"""
class Config:
arbitrary_types_allowed = True
@ -30,17 +32,40 @@ class ExecuteContext(pydantic.BaseModel):
"""
query: core_entities.Query
"""本次消息的请求对象"""
session: core_entities.Session
"""本次消息所属的会话对象"""
command_text: str
"""命令完整文本"""
command: str
"""命令名称"""
crt_command: str
"""当前命令
多级命令中crt_command为当前命令command为根命令
例如!plugin on Webwlkr
处理到plugin时command为plugincrt_command为plugin
处理到on时command为plugincrt_command为on
"""
params: list[str]
"""命令参数
整个命令以空格分割后的参数列表
"""
crt_params: list[str]
"""当前命令参数
多级命令中crt_params为当前命令参数params为根命令参数
例如!plugin on Webwlkr
处理到plugin时params为['on', 'Webwlkr']crt_params为['on', 'Webwlkr']
处理到on时params为['on', 'Webwlkr']crt_params为['Webwlkr']
"""
privilege: int
"""发起人权限"""

View File

@ -52,6 +52,11 @@ def operator_class(
class CommandOperator(metaclass=abc.ABCMeta):
"""命令算子抽象类
以下的参数均不需要在子类中设置只需要在使用装饰器注册类时作为参数传递即可
命令支持级联即一个命令可以有多个子命令子命令可以有子命令以此类推
处理命令时若有子命令会以当前参数列表的第一个参数去匹配子命令若匹配成功则转移到子命令中执行
若没有匹配成功或没有子命令则执行当前命令
"""
ap: app.Application
@ -60,7 +65,8 @@ class CommandOperator(metaclass=abc.ABCMeta):
"""名称,搜索到时若符合则使用"""
path: str
"""路径所有父节点的name的连接用于定义命令权限"""
"""路径所有父节点的name的连接用于定义命令权限由管理器在初始化时自动设置。
"""
alias: list[str]
"""同name"""
@ -69,6 +75,7 @@ class CommandOperator(metaclass=abc.ABCMeta):
"""此节点的帮助信息"""
usage: str = None
"""用法"""
parent_class: typing.Union[typing.Type[CommandOperator], None] = None
"""父节点类。标记以供管理器在初始化时编织父子关系。"""
@ -92,4 +99,15 @@ class CommandOperator(metaclass=abc.ABCMeta):
self,
context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
"""实现此方法以执行命令
支持多次yield以返回多个结果
例如一个安装插件的命令可能会有下载解压安装等多个步骤每个步骤都可以返回一个结果
Args:
context (entities.ExecuteContext): 命令执行上下文
Yields:
entities.CommandReturn: 命令返回封装
"""
pass

View File

@ -8,15 +8,12 @@ from .. import model as file_model
class JSONConfigFile(file_model.ConfigFile):
"""JSON配置文件"""
config_file_name: str = None
"""配置文件名"""
template_file_name: str = None
"""模板文件名"""
def __init__(self, config_file_name: str, template_file_name: str) -> None:
def __init__(
self, config_file_name: str, template_file_name: str = None, template_data: dict = None
) -> None:
self.config_file_name = config_file_name
self.template_file_name = template_file_name
self.template_data = template_data
def exists(self) -> bool:
return os.path.exists(self.config_file_name)
@ -29,23 +26,24 @@ class JSONConfigFile(file_model.ConfigFile):
if not self.exists():
await self.create()
with open(self.config_file_name, 'r', encoding='utf-8') as f:
cfg = json.load(f)
if self.template_file_name is not None:
with open(self.config_file_name, "r", encoding="utf-8") as f:
cfg = json.load(f)
# 从模板文件中进行补全
with open(self.template_file_name, 'r', encoding='utf-8') as f:
template_cfg = json.load(f)
with open(self.template_file_name, "r", encoding="utf-8") as f:
self.template_data = json.load(f)
for key in template_cfg:
for key in self.template_data:
if key not in cfg:
cfg[key] = template_cfg[key]
cfg[key] = self.template_data[key]
return cfg
async def save(self, cfg: dict):
with open(self.config_file_name, 'w', encoding='utf-8') as f:
with open(self.config_file_name, "w", encoding="utf-8") as f:
json.dump(cfg, f, indent=4, ensure_ascii=False)
def save_sync(self, cfg: dict):
with open(self.config_file_name, 'w', encoding='utf-8') as f:
json.dump(cfg, f, indent=4, ensure_ascii=False)
with open(self.config_file_name, "w", encoding="utf-8") as f:
json.dump(cfg, f, indent=4, ensure_ascii=False)

View File

@ -43,11 +43,12 @@ async def load_python_module_config(config_name: str, template_name: str) -> Con
return cfg_mgr
async def load_json_config(config_name: str, template_name: str) -> ConfigManager:
async def load_json_config(config_name: str, template_name: str=None, template_data: dict=None) -> ConfigManager:
"""加载JSON配置文件"""
cfg_inst = json_file.JSONConfigFile(
config_name,
template_name
template_name,
template_data
)
cfg_mgr = ConfigManager(cfg_inst)

View File

@ -7,6 +7,7 @@ from ..core import app
preregistered_migrations: list[typing.Type[Migration]] = []
"""当前阶段暂不支持扩展"""
def migration_class(name: str, number: int):
"""注册一个迁移

View File

@ -10,6 +10,9 @@ class ConfigFile(metaclass=abc.ABCMeta):
template_file_name: str = None
"""模板文件名"""
template_data: dict = None
"""模板数据"""
@abc.abstractmethod
def exists(self) -> bool:
pass

View File

@ -21,7 +21,7 @@ from ..utils import version as version_mgr, proxy as proxy_mgr
class Application:
"""运行时应用对象和上下文"""
im_mgr: im_mgr.PlatformManager = None
platform_mgr: im_mgr.PlatformManager = None
cmd_mgr: cmdmgr.CommandManager = None
@ -85,10 +85,9 @@ class Application:
tasks = []
try:
tasks = [
asyncio.create_task(self.im_mgr.run()),
asyncio.create_task(self.platform_mgr.run()),
asyncio.create_task(self.ctrl.run())
]

View File

@ -10,7 +10,6 @@ required_deps = {
"botpy": "qq-botpy",
"PIL": "pillow",
"nakuru": "nakuru-project-idk",
"CallingGPT": "CallingGPT",
"tiktoken": "tiktoken",
"yaml": "pyyaml",
"aiohttp": "aiohttp",

View File

@ -32,43 +32,43 @@ class Query(pydantic.BaseModel):
"""请求ID添加进请求池时生成"""
launcher_type: LauncherTypes
"""会话类型platform设置"""
"""会话类型platform处理阶段设置"""
launcher_id: int
"""会话IDplatform设置"""
"""会话IDplatform处理阶段设置"""
sender_id: int
"""发送者IDplatform设置"""
"""发送者IDplatform处理阶段设置"""
message_event: mirai.MessageEvent
"""事件platform收到的事件"""
"""事件platform收到的原始事件"""
message_chain: mirai.MessageChain
"""消息链platform收到的消息链"""
"""消息链platform收到的原始消息链"""
adapter: msadapter.MessageSourceAdapter
"""适配器对象"""
"""消息平台适配器对象单个app中可能启用了多个消息平台适配器此对象表明发起此query的适配器"""
session: typing.Optional[Session] = None
"""会话对象,由前置处理器设置"""
"""会话对象,由前置处理器阶段设置"""
messages: typing.Optional[list[llm_entities.Message]] = []
"""历史消息列表,由前置处理器设置"""
"""历史消息列表,由前置处理器阶段设置"""
prompt: typing.Optional[sysprompt_entities.Prompt] = None
"""情景预设内容,由前置处理器设置"""
"""情景预设内容,由前置处理器阶段设置"""
user_message: typing.Optional[llm_entities.Message] = None
"""此次请求的用户消息对象,由前置处理器设置"""
"""此次请求的用户消息对象,由前置处理器阶段设置"""
use_model: typing.Optional[entities.LLMModelInfo] = None
"""使用的模型,由前置处理器设置"""
"""使用的模型,由前置处理器阶段设置"""
use_funcs: typing.Optional[list[tools_entities.LLMFunction]] = None
"""使用的函数,由前置处理器设置"""
"""使用的函数,由前置处理器阶段设置"""
resp_messages: typing.Optional[list[llm_entities.Message]] = []
"""provider生成的回复消息对象列表"""
"""Process阶段生成的回复消息对象列表"""
resp_message_chain: typing.Optional[mirai.MessageChain] = None
"""回复消息链从resp_messages包装而得"""

View File

@ -7,7 +7,10 @@ from . import app
preregistered_stages: dict[str, typing.Type[BootingStage]] = {}
"""预注册的请求处理阶段。在初始化时,所有请求处理阶段类会被注册到此字典中。"""
"""预注册的请求处理阶段。在初始化时,所有请求处理阶段类会被注册到此字典中。
当前阶段暂不支持扩展
"""
def stage_class(
name: str

View File

@ -86,7 +86,7 @@ class BuildAppStage(stage.BootingStage):
im_mgr_inst = im_mgr.PlatformManager(ap=ap)
await im_mgr_inst.initialize()
ap.im_mgr = im_mgr_inst
ap.platform_mgr = im_mgr_inst
stage_mgr = stagemgr.StageManager(ap)
await stage_mgr.initialize()

View File

@ -31,15 +31,24 @@ class EnableStage(enum.Enum):
class FilterResult(pydantic.BaseModel):
level: ResultLevel
"""结果等级
对于前置处理阶段只要有任意一个返回 非PASS 的内容过滤器结果就会中断处理
对于后置处理阶段当且内容过滤器返回 BLOCK 会中断处理
"""
replacement: str
"""替换后的消息"""
"""替换后的消息
内容过滤器可以进行一些遮掩处理然后把遮掩后的消息返回
若没有修改内容也需要返回原消息
"""
user_notice: str
"""不通过时,用户提示消息"""
"""不通过时,若此值不为空,将对用户提示消息"""
console_notice: str
"""不通过时,控制台提示消息"""
"""不通过时,若此值不为空,将在控制台提示消息"""
class ManagerResultLevel(enum.Enum):

View File

@ -46,6 +46,11 @@ class ContentFilter(metaclass=abc.ABCMeta):
@property
def enable_stages(self):
"""启用的阶段
默认为消息请求AI前后的两个阶段
entity.EnableStage.PRE: 消息请求AI前此时需要检查的内容是用户的输入消息
entity.EnableStage.POST: 消息请求AI后此时需要检查的内容是AI的回复消息
"""
return [
entities.EnableStage.PRE,
@ -60,5 +65,14 @@ class ContentFilter(metaclass=abc.ABCMeta):
@abc.abstractmethod
async def process(self, message: str) -> entities.FilterResult:
"""处理消息
分为前后阶段具体取决于 enable_stages 的值
对于内容过滤器来说不需要考虑消息所处的阶段只需要检查消息内容即可
Args:
message (str): 需要检查的内容
Returns:
entities.FilterResult: 过滤结果具体内容请查看 entities.FilterResult 类的文档
"""
raise NotImplementedError

View File

@ -68,7 +68,7 @@ class Controller:
"""检查输出
"""
if result.user_notice:
await self.ap.im_mgr.send(
await self.ap.platform_mgr.send(
query.message_event,
result.user_notice,
query.adapter

View File

@ -15,6 +15,15 @@ preregistered_strategies: list[typing.Type[LongTextStrategy]] = []
def strategy_class(
name: str
) -> typing.Callable[[typing.Type[LongTextStrategy]], typing.Type[LongTextStrategy]]:
"""长文本处理策略类装饰器
Args:
name (str): 策略名称
Returns:
typing.Callable[[typing.Type[LongTextStrategy]], typing.Type[LongTextStrategy]]: 装饰器
"""
def decorator(cls: typing.Type[LongTextStrategy]) -> typing.Type[LongTextStrategy]:
assert issubclass(cls, LongTextStrategy)
@ -43,4 +52,15 @@ class LongTextStrategy(metaclass=abc.ABCMeta):
@abc.abstractmethod
async def process(self, message: str, query: core_entities.Query) -> list[MessageComponent]:
"""处理长文本
platform.json 中配置 long-text-process 字段只要 文本长度超过了 threshold 就会调用此方法
Args:
message (str): 消息
query (core_entities.Query): 此次请求的上下文对象
Returns:
list[mirai.models.messages.MessageComponent]: 转换后的 YiriMirai 消息组件列表
"""
return []

View File

@ -39,7 +39,14 @@ class ChatMessageHandler(handler.MessageHandler):
if event_ctx.is_prevented_default():
if event_ctx.event.reply is not None:
query.resp_message_chain = mirai.MessageChain(event_ctx.event.reply)
mc = mirai.MessageChain(event_ctx.event.reply)
query.resp_messages.append(
llm_entities.Message(
role='plugin',
content=str(mc),
)
)
yield entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE,

View File

@ -18,6 +18,7 @@ def algo_class(name: str):
class ReteLimitAlgo(metaclass=abc.ABCMeta):
"""限流算法抽象类"""
name: str = None
@ -31,9 +32,27 @@ class ReteLimitAlgo(metaclass=abc.ABCMeta):
@abc.abstractmethod
async def require_access(self, launcher_type: str, launcher_id: int) -> bool:
"""进入处理流程
这个方法对等待是友好的意味着算法可以实现在这里等待一段时间以控制速率
Args:
launcher_type (str): 请求者类型 群聊为 group 私聊为 person
launcher_id (int): 请求者ID
Returns:
bool: 是否允许进入处理流程若返回false则直接丢弃该请求
"""
raise NotImplementedError
@abc.abstractmethod
async def release_access(self, launcher_type: str, launcher_id: int):
"""退出处理流程
Args:
launcher_type (str): 请求者类型 群聊为 group 私聊为 person
launcher_id (int): 请求者ID
"""
raise NotImplementedError

View File

@ -29,7 +29,7 @@ class SendResponseBackStage(stage.PipelineStage):
await asyncio.sleep(random_delay)
await self.ap.im_mgr.send(
await self.ap.platform_mgr.send(
query.message_event,
query.resp_message_chain,
adapter=query.adapter

View File

@ -29,6 +29,13 @@ class ResponseWrapper(stage.PipelineStage):
if query.resp_messages[-1].role == 'command':
query.resp_message_chain = mirai.MessageChain("[bot] "+query.resp_messages[-1].content)
yield entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE,
new_query=query
)
elif query.resp_messages[-1].role == 'plugin':
query.resp_message_chain = mirai.MessageChain(query.resp_messages[-1].content)
yield entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE,
new_query=query

View File

@ -14,6 +14,14 @@ preregistered_adapters: list[typing.Type[MessageSourceAdapter]] = []
def adapter_class(
name: str
):
"""消息平台适配器类装饰器
Args:
name (str): 适配器名称
Returns:
typing.Callable[[typing.Type[MessageSourceAdapter]], typing.Type[MessageSourceAdapter]]: 装饰器
"""
def decorator(cls: typing.Type[MessageSourceAdapter]) -> typing.Type[MessageSourceAdapter]:
cls.name = name
preregistered_adapters.append(cls)
@ -27,12 +35,19 @@ class MessageSourceAdapter(metaclass=abc.ABCMeta):
name: str
bot_account_id: int
"""机器人账号ID需要在初始化时设置"""
config: dict
ap: app.Application
def __init__(self, config: dict, ap: app.Application):
"""初始化适配器
Args:
config (dict): 对应的配置
ap (app.Application): 应用上下文
"""
self.config = config
self.ap = ap

View File

@ -146,7 +146,7 @@ class PlatformManager:
if len(self.adapters) == 0:
self.ap.logger.warning('未运行平台适配器,请根据文档配置并启用平台适配器。')
async def send(self, event, msg, adapter: msadapter.MessageSourceAdapter, check_quote=True, check_at_sender=True):
async def send(self, event: mirai.MessageEvent, msg: mirai.MessageChain, adapter: msadapter.MessageSourceAdapter, check_quote=True, check_at_sender=True):
if check_at_sender and self.ap.platform_cfg.data['at-sender'] and isinstance(event, GroupMessage):

View File

@ -9,10 +9,86 @@ from ..provider.tools import entities as tools_entities
from ..core import app
def register(
name: str,
description: str,
version: str,
author: str
) -> typing.Callable[[typing.Type[BasePlugin]], typing.Type[BasePlugin]]:
"""注册插件类
使用示例
@register(
name="插件名称",
description="插件描述",
version="插件版本",
author="插件作者"
)
class MyPlugin(BasePlugin):
pass
"""
pass
def handler(
event: typing.Type[events.BaseEventModel]
) -> typing.Callable[[typing.Callable], typing.Callable]:
"""注册事件监听器
使用示例
class MyPlugin(BasePlugin):
@handler(NormalMessageResponded)
async def on_normal_message_responded(self, ctx: EventContext):
pass
"""
pass
def llm_func(
name: str=None,
) -> typing.Callable:
"""注册内容函数
使用示例
class MyPlugin(BasePlugin):
@llm_func("access_the_web_page")
async def _(self, query, url: str, brief_len: int):
\"""Call this function to search about the question before you answer any questions.
- Do not search through google.com at any time.
- If you need to search somthing, visit https://www.sogou.com/web?query=<something>.
- If user ask you to open a url (start with http:// or https://), visit it directly.
- Summary the plain content result by yourself, DO NOT directly output anything in the result you got.
Args:
url(str): url to visit
brief_len(int): max length of the plain text content, recommend 1024-4096, prefer 4096
Returns:
str: plain text content of the web page or error message(starts with 'error:')
\"""
"""
pass
class BasePlugin(metaclass=abc.ABCMeta):
"""插件基类"""
host: APIHost
"""API宿主"""
ap: app.Application
"""应用程序对象"""
def __init__(self, host: APIHost):
self.host = host
async def initialize(self):
"""初始化插件"""
pass
class APIHost:
@ -61,8 +137,10 @@ class EventContext:
"""事件编号"""
host: APIHost = None
"""API宿主"""
event: events.BaseEventModel = None
"""此次事件的对象具体类型为handler注册时指定监听的类型可查看events.py中的定义"""
__prevent_default__ = False
"""是否阻止默认行为"""

View File

@ -10,8 +10,10 @@ from ..provider import entities as llm_entities
class BaseEventModel(pydantic.BaseModel):
"""事件模型基类"""
query: typing.Union[core_entities.Query, None]
"""此次请求的query对象非请求过程的事件时为None"""
class Config:
arbitrary_types_allowed = True

View File

@ -1,3 +1,7 @@
# 此模块已过时
# 请从 pkg.plugin.context 引入 BasePlugin, EventContext 和 APIHost
# 最早将于 v3.4 移除此模块
from . events import *
from . context import EventContext, APIHost as PluginHost

View File

@ -5,11 +5,10 @@ import pkgutil
import importlib
import traceback
from CallingGPT.entities.namespace import get_func_schema
from .. import loader, events, context, models, host
from ...core import entities as core_entities
from ...provider.tools import entities as tools_entities
from ...utils import funcschema
class PluginLoader(loader.PluginLoader):
@ -29,6 +28,10 @@ class PluginLoader(loader.PluginLoader):
setattr(models, 'on', self.on)
setattr(models, 'func', self.func)
setattr(context, 'register', self.register)
setattr(context, 'handler', self.handler)
setattr(context, 'llm_func', self.llm_func)
def register(
self,
name: str,
@ -57,6 +60,8 @@ class PluginLoader(loader.PluginLoader):
return wrapper
# 过时
# 最早将于 v3.4 版本移除
def on(
self,
event: typing.Type[events.BaseEventModel]
@ -83,6 +88,8 @@ class PluginLoader(loader.PluginLoader):
return wrapper
# 过时
# 最早将于 v3.4 版本移除
def func(
self,
name: str=None,
@ -91,10 +98,11 @@ class PluginLoader(loader.PluginLoader):
self.ap.logger.debug(f'注册内容函数 {name}')
def wrapper(func: typing.Callable) -> typing.Callable:
function_schema = get_func_schema(func)
function_schema = funcschema.get_func_schema(func)
function_name = self._current_container.plugin_name + '-' + (func.__name__ if name is None else name)
async def handler(
plugin: context.BasePlugin,
query: core_entities.Query,
*args,
**kwargs
@ -116,6 +124,46 @@ class PluginLoader(loader.PluginLoader):
return wrapper
def handler(
self,
event: typing.Type[events.BaseEventModel]
) -> typing.Callable[[typing.Callable], typing.Callable]:
"""注册事件处理器"""
self.ap.logger.debug(f'注册事件处理器 {event.__name__}')
def wrapper(func: typing.Callable) -> typing.Callable:
self._current_container.event_handlers[event] = func
return func
return wrapper
def llm_func(
self,
name: str=None,
) -> typing.Callable:
"""注册内容函数"""
self.ap.logger.debug(f'注册内容函数 {name}')
def wrapper(func: typing.Callable) -> typing.Callable:
function_schema = funcschema.get_func_schema(func)
function_name = self._current_container.plugin_name + '-' + (func.__name__ if name is None else name)
llm_function = tools_entities.LLMFunction(
name=function_name,
human_desc='',
description=function_schema['description'],
enable=True,
parameters=function_schema['parameters'],
func=func,
)
self._current_container.content_functions.append(llm_function)
return func
return wrapper
async def _walk_plugin_path(
self,
module,

View File

@ -5,7 +5,7 @@ import traceback
from ..core import app
from . import context, loader, events, installer, setting, models
from .loaders import legacy
from .loaders import classic
from .installers import github
@ -26,7 +26,7 @@ class PluginManager:
def __init__(self, ap: app.Application):
self.ap = ap
self.loader = legacy.PluginLoader(ap)
self.loader = classic.PluginLoader(ap)
self.installer = github.GitHubRepoInstaller(ap)
self.setting = setting.SettingManager(ap)
self.api_host = context.APIHost(ap)
@ -52,6 +52,9 @@ class PluginManager:
for plugin in self.plugins:
try:
plugin.plugin_inst = plugin.plugin_class(self.api_host)
plugin.plugin_inst.ap = self.ap
plugin.plugin_inst.host = self.api_host
await plugin.plugin_inst.initialize()
except Exception as e:
self.ap.logger.error(f'插件 {plugin.plugin_name} 初始化失败: {e}')
self.ap.logger.exception(e)

View File

@ -1,3 +1,7 @@
# 此模块已过时,请引入 pkg.plugin.context 中的 register, handler 和 llm_func 来注册插件、事件处理函数和内容函数
# 各个事件模型请从 pkg.plugin.events 引入
# 最早将于 v3.4 移除此模块
from __future__ import annotations
import typing

View File

@ -22,15 +22,20 @@ class ToolCall(pydantic.BaseModel):
class Message(pydantic.BaseModel):
"""消息"""
role: str # user, system, assistant, tool, command
role: str # user, system, assistant, tool, command, plugin
"""消息的角色"""
name: typing.Optional[str] = None
"""名称,仅函数调用返回时设置"""
content: typing.Optional[str] = None
"""内容"""
function_call: typing.Optional[FunctionCall] = None
"""函数调用不再受支持请使用tool_calls"""
tool_calls: typing.Optional[list[ToolCall]] = None
"""工具调用"""
tool_call_id: typing.Optional[str] = None

View File

@ -38,6 +38,15 @@ class LLMAPIRequester(metaclass=abc.ABCMeta):
self,
query: core_entities.Query,
) -> typing.AsyncGenerator[llm_entities.Message, None]:
"""请求
"""请求API
对话前文可以从 query 对象中获取
可以多次yield消息对象
Args:
query (core_entities.Query): 本次请求的上下文对象
Yields:
pkg.provider.entities.Message: 返回消息对象
"""
raise NotImplementedError

View File

@ -10,5 +10,7 @@ class Prompt(pydantic.BaseModel):
"""供AI使用的Prompt"""
name: str
"""名称"""
messages: list[entities.Message]
"""消息列表"""

View File

@ -36,7 +36,7 @@ class PromptLoader(metaclass=abc.ABCMeta):
@abc.abstractmethod
async def load(self):
"""加载Prompt
"""加载Prompt存放到prompts列表中
"""
raise NotImplementedError

View File

@ -5,6 +5,7 @@ import traceback
from ...core import app, entities as core_entities
from . import entities
from ...plugin import context as plugin_context
class ToolManager:
@ -28,6 +29,15 @@ class ToolManager:
return function
return None
async def get_function_and_plugin(self, name: str) -> typing.Tuple[entities.LLMFunction, plugin_context.BasePlugin]:
"""获取函数和插件
"""
for plugin in self.ap.plugin_mgr.plugins:
for function in plugin.content_functions:
if function.name == name:
return function, plugin
return None, None
async def get_all_functions(self) -> list[entities.LLMFunction]:
"""获取所有函数
"""
@ -68,7 +78,7 @@ class ToolManager:
try:
function = await self.get_function(name)
function, plugin = await self.get_function_and_plugin(name)
if function is None:
return None
@ -79,7 +89,7 @@ class ToolManager:
**parameters
}
return await function.func(**parameters)
return await function.func(plugin, **parameters)
except Exception as e:
self.ap.logger.error(f'执行函数 {name} 时发生错误: {e}')
traceback.print_exc()

116
pkg/utils/funcschema.py Normal file
View File

@ -0,0 +1,116 @@
import sys
import re
import inspect
def get_func_schema(function: callable) -> dict:
"""
Return the data schema of a function.
{
"function": function,
"description": "function description",
"parameters": {
"type": "object",
"properties": {
"parameter_a": {
"type": "str",
"description": "parameter_a description"
},
"parameter_b": {
"type": "int",
"description": "parameter_b description"
},
"parameter_c": {
"type": "str",
"description": "parameter_c description",
"enum": ["a", "b", "c"]
},
},
"required": ["parameter_a", "parameter_b"]
}
}
"""
func_doc = function.__doc__
# Google Style Docstring
if func_doc is None:
raise Exception("Function {} has no docstring.".format(function.__name__))
func_doc = func_doc.strip().replace(' ','').replace('\t', '')
# extract doc of args from docstring
doc_spt = func_doc.split('\n\n')
desc = doc_spt[0]
args = doc_spt[1] if len(doc_spt) > 1 else ""
returns = doc_spt[2] if len(doc_spt) > 2 else ""
# extract args
# delete the first line of args
arg_lines = args.split('\n')[1:]
arg_doc_list = re.findall(r'(\w+)(\((\w+)\))?:\s*(.*)', args)
args_doc = {}
for arg_line in arg_lines:
doc_tuple = re.findall(r'(\w+)(\(([\w\[\]]+)\))?:\s*(.*)', arg_line)
if len(doc_tuple) == 0:
continue
args_doc[doc_tuple[0][0]] = doc_tuple[0][3]
# extract returns
return_doc_list = re.findall(r'(\w+):\s*(.*)', returns)
params = enumerate(inspect.signature(function).parameters.values())
parameters = {
"type": "object",
"required": [],
"properties": {},
}
for i, param in params:
# 排除 self, query
if param.name in ['self', 'query']:
continue
param_type = param.annotation.__name__
type_name_mapping = {
"str": "string",
"int": "integer",
"float": "number",
"bool": "boolean",
"list": "array",
"dict": "object",
}
if param_type in type_name_mapping:
param_type = type_name_mapping[param_type]
parameters['properties'][param.name] = {
"type": param_type,
"description": args_doc[param.name],
}
# add schema for array
if param_type == "array":
# extract type of array, the int of list[int]
# use re
array_type_tuple = re.findall(r'list\[(\w+)\]', str(param.annotation))
array_type = 'string'
if len(array_type_tuple) > 0:
array_type = array_type_tuple[0]
if array_type in type_name_mapping:
array_type = type_name_mapping[array_type]
parameters['properties'][param.name]["items"] = {
"type": array_type,
}
if param.default is inspect.Parameter.empty:
parameters["required"].append(param.name)
return {
"function": function,
"description": desc,
"parameters": parameters,
}

View File

@ -7,7 +7,6 @@ aiocqhttp
qq-botpy
nakuru-project-idk
Pillow
CallingGPT
tiktoken
PyYaml
aiohttp