refactor: 重构插件系统

This commit is contained in:
RockChinQ 2024-01-29 21:22:27 +08:00
parent b730f17eb6
commit 6cc4688660
53 changed files with 1307 additions and 1993 deletions

View File

@ -1,114 +0,0 @@
"""
使用量统计以及数据上报功能实现
"""
import hashlib
import json
import logging
import threading
import requests
from ..utils import context
from ..utils import updater
class DataGatherer:
"""数据收集器"""
usage = {}
"""各api-key的使用量
以key值md5为key,{
"text": {
"gpt-3.5-turbo": 文字量:int,
},
"image": {
"256x256": 图片数量:int,
}
}为值的字典"""
version_str = "undetermined"
def __init__(self):
self.load_from_db()
try:
self.version_str = updater.get_current_tag() # 从updater模块获取版本号
except:
pass
def get_usage(self, key_md5):
return self.usage[key_md5] if key_md5 in self.usage else {}
def report_text_model_usage(self, model, total_tokens):
"""调用方报告文字模型请求文字使用量"""
key_md5 = context.get_openai_manager().key_mgr.get_using_key_md5() # 以key的md5进行储存
if key_md5 not in self.usage:
self.usage[key_md5] = {}
if "text" not in self.usage[key_md5]:
self.usage[key_md5]["text"] = {}
if model not in self.usage[key_md5]["text"]:
self.usage[key_md5]["text"][model] = 0
length = total_tokens
self.usage[key_md5]["text"][model] += length
self.dump_to_db()
def report_image_model_usage(self, size):
"""调用方报告图片模型请求图片使用量"""
key_md5 = context.get_openai_manager().key_mgr.get_using_key_md5()
if key_md5 not in self.usage:
self.usage[key_md5] = {}
if "image" not in self.usage[key_md5]:
self.usage[key_md5]["image"] = {}
if size not in self.usage[key_md5]["image"]:
self.usage[key_md5]["image"][size] = 0
self.usage[key_md5]["image"][size] += 1
self.dump_to_db()
def get_text_length_of_key(self, key):
"""获取指定api-key (明文) 的文字总使用量(本地记录)"""
key_md5 = hashlib.md5(key.encode('utf-8')).hexdigest()
if key_md5 not in self.usage:
return 0
if "text" not in self.usage[key_md5]:
return 0
# 遍历其中所有模型,求和
return sum(self.usage[key_md5]["text"].values())
def get_image_count_of_key(self, key):
"""获取指定api-key (明文) 的图片总使用量(本地记录)"""
key_md5 = hashlib.md5(key.encode('utf-8')).hexdigest()
if key_md5 not in self.usage:
return 0
if "image" not in self.usage[key_md5]:
return 0
# 遍历其中所有模型,求和
return sum(self.usage[key_md5]["image"].values())
def get_total_text_length(self):
"""获取所有api-key的文字总使用量(本地记录)"""
total = 0
for key in self.usage:
if "text" not in self.usage[key]:
continue
total += sum(self.usage[key]["text"].values())
return total
def dump_to_db(self):
context.get_database_manager().dump_usage_json(self.usage)
def load_from_db(self):
json_str = context.get_database_manager().load_usage_json()
if json_str is not None:
self.usage = json.loads(json_str)

View File

@ -4,7 +4,6 @@ import typing
from ..core import app, entities as core_entities from ..core import app, entities as core_entities
from ..provider import entities as llm_entities from ..provider import entities as llm_entities
from ..provider.session import entities as session_entities
from . import entities, operator, errors from . import entities, operator, errors
from .operators import func, plugin, default, reset, list as list_cmd, last, next, delc, resend, prompt, cfg, cmd, help, version, update from .operators import func, plugin, default, reset, list as list_cmd, last, next, delc, resend, prompt, cfg, cmd, help, version, update
@ -80,7 +79,7 @@ class CommandManager:
self, self,
command_text: str, command_text: str,
query: core_entities.Query, query: core_entities.Query,
session: session_entities.Session session: core_entities.Session
) -> typing.AsyncGenerator[entities.CommandReturn, None]: ) -> typing.AsyncGenerator[entities.CommandReturn, None]:
"""执行命令 """执行命令
""" """

View File

@ -6,7 +6,6 @@ import pydantic
import mirai import mirai
from ..core import app, entities as core_entities from ..core import app, entities as core_entities
from ..provider.session import entities as session_entities
from . import errors, operator from . import errors, operator
@ -28,7 +27,7 @@ class ExecuteContext(pydantic.BaseModel):
query: core_entities.Query query: core_entities.Query
session: session_entities.Session session: core_entities.Session
command_text: str command_text: str

View File

@ -4,7 +4,6 @@ import typing
import abc import abc
from ..core import app, entities as core_entities from ..core import app, entities as core_entities
from ..provider.session import entities as session_entities
from . import entities from . import entities

View File

@ -2,7 +2,6 @@ from __future__ import annotations
from typing import AsyncGenerator from typing import AsyncGenerator
from .. import operator, entities, cmdmgr from .. import operator, entities, cmdmgr
from ...plugin import host as plugin_host
@operator.operator_class(name="func", help="查看所有已注册的内容函数", usage='!func') @operator.operator_class(name="func", help="查看所有已注册的内容函数", usage='!func')
@ -13,7 +12,10 @@ class FuncOperator(operator.CommandOperator):
reply_str = "当前已加载的内容函数: \n\n" reply_str = "当前已加载的内容函数: \n\n"
index = 1 index = 1
for func in self.ap.tool_mgr.all_functions:
all_functions = await self.ap.tool_mgr.get_all_functions()
for func in all_functions:
reply_str += "{}. {}{}:\n{}\n\n".format( reply_str += "{}. {}{}:\n{}\n\n".format(
index, index,
("(已禁用) " if not func.enable else ""), ("(已禁用) " if not func.enable else ""),

View File

@ -3,8 +3,6 @@ import typing
import traceback import traceback
from .. import operator, entities, cmdmgr, errors from .. import operator, entities, cmdmgr, errors
from ...plugin import host as plugin_host
from ...utils import updater
from ...core import app from ...core import app
@ -20,16 +18,15 @@ class PluginOperator(operator.CommandOperator):
context: entities.ExecuteContext context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]: ) -> typing.AsyncGenerator[entities.CommandReturn, None]:
plugin_list = plugin_host.__plugins__ plugin_list = self.ap.plugin_mgr.plugins
reply_str = "所有插件({}):\n".format(len(plugin_host.__plugins__)) reply_str = "所有插件({}):\n".format(len(plugin_list))
idx = 0 idx = 0
for key in plugin_host.iter_plugins_name(): for plugin in plugin_list:
plugin = plugin_list[key]
reply_str += "\n#{} {} {}\n{}\nv{}\n作者: {}\n"\ reply_str += "\n#{} {} {}\n{}\nv{}\n作者: {}\n"\
.format((idx+1), plugin['name'], .format((idx+1), plugin.plugin_name,
"[已禁用]" if not plugin['enabled'] else "", "[已禁用]" if not plugin.enabled else "",
plugin['description'], plugin.plugin_description,
plugin['version'], plugin['author']) plugin.plugin_version, plugin.plugin_author)
# TODO 从元数据调远程地址 # TODO 从元数据调远程地址
# if updater.is_repo("/".join(plugin['path'].split('/')[:-1])): # if updater.is_repo("/".join(plugin['path'].split('/')[:-1])):
@ -63,7 +60,7 @@ class PluginGetOperator(operator.CommandOperator):
yield entities.CommandReturn(text="正在安装插件...") yield entities.CommandReturn(text="正在安装插件...")
try: try:
plugin_host.install_plugin(repo) await self.ap.plugin_mgr.install_plugin(repo)
yield entities.CommandReturn(text="插件安装成功,请重启程序以加载插件") yield entities.CommandReturn(text="插件安装成功,请重启程序以加载插件")
except Exception as e: except Exception as e:
traceback.print_exc() traceback.print_exc()
@ -89,11 +86,11 @@ class PluginUpdateOperator(operator.CommandOperator):
plugin_name = context.crt_params[0] plugin_name = context.crt_params[0]
try: try:
plugin_path_name = plugin_host.get_plugin_path_name_by_plugin_name(plugin_name) plugin_container = self.ap.plugin_mgr.get_plugin_by_name(plugin_name)
if plugin_path_name is not None: if plugin_container is not None:
yield entities.CommandReturn(text="正在更新插件...") yield entities.CommandReturn(text="正在更新插件...")
plugin_host.update_plugin(plugin_name) await self.ap.plugin_mgr.update_plugin(plugin_name)
yield entities.CommandReturn(text="插件更新成功,请重启程序以加载插件") yield entities.CommandReturn(text="插件更新成功,请重启程序以加载插件")
else: else:
yield entities.CommandReturn(error=errors.CommandError("插件更新失败: 未找到插件")) yield entities.CommandReturn(error=errors.CommandError("插件更新失败: 未找到插件"))
@ -115,17 +112,17 @@ class PluginUpdateAllOperator(operator.CommandOperator):
) -> typing.AsyncGenerator[entities.CommandReturn, None]: ) -> typing.AsyncGenerator[entities.CommandReturn, None]:
try: try:
plugins = [] plugins = [
p.plugin_name
for key in plugin_host.__plugins__: for p in self.ap.plugin_mgr.plugins
plugins.append(key) ]
if plugins: if plugins:
yield entities.CommandReturn(text="正在更新插件...") yield entities.CommandReturn(text="正在更新插件...")
updated = [] updated = []
try: try:
for plugin_name in plugins: for plugin_name in plugins:
plugin_host.update_plugin(plugin_name) await self.ap.plugin_mgr.update_plugin(plugin_name)
updated.append(plugin_name) updated.append(plugin_name)
except Exception as e: except Exception as e:
traceback.print_exc() traceback.print_exc()
@ -157,11 +154,11 @@ class PluginDelOperator(operator.CommandOperator):
plugin_name = context.crt_params[0] plugin_name = context.crt_params[0]
try: try:
plugin_path_name = plugin_host.get_plugin_path_name_by_plugin_name(plugin_name) plugin_container = self.ap.plugin_mgr.get_plugin_by_name(plugin_name)
if plugin_path_name is not None: if plugin_container is not None:
yield entities.CommandReturn(text="正在删除插件...") yield entities.CommandReturn(text="正在删除插件...")
plugin_host.uninstall_plugin(plugin_name) await self.ap.plugin_mgr.uninstall_plugin(plugin_name)
yield entities.CommandReturn(text="插件删除成功,请重启程序以加载插件") yield entities.CommandReturn(text="插件删除成功,请重启程序以加载插件")
else: else:
yield entities.CommandReturn(error=errors.CommandError("插件删除失败: 未找到插件")) yield entities.CommandReturn(error=errors.CommandError("插件删除失败: 未找到插件"))
@ -171,12 +168,15 @@ class PluginDelOperator(operator.CommandOperator):
def update_plugin_status(plugin_name: str, new_status: bool, ap: app.Application): def update_plugin_status(plugin_name: str, new_status: bool, ap: app.Application):
if plugin_name in plugin_host.__plugins__: if ap.plugin_mgr.get_plugin_by_name(plugin_name) is not None:
plugin_host.__plugins__[plugin_name]['enabled'] = new_status for plugin in ap.plugin_mgr.plugins:
if plugin.plugin_name == plugin_name:
plugin.enabled = new_status
for func in ap.tool_mgr.all_functions: for func in plugin.content_functions:
if func.name.startswith(plugin_name+'-'): func.enable = new_status
func.enable = new_status
break
return True return True
else: else:

View File

@ -4,7 +4,6 @@ import typing
import traceback import traceback
from .. import operator, entities, cmdmgr, errors from .. import operator, entities, cmdmgr, errors
from ...utils import updater
@operator.operator_class( @operator.operator_class(
@ -22,7 +21,7 @@ class UpdateCommand(operator.CommandOperator):
try: try:
yield entities.CommandReturn(text="正在进行更新...") yield entities.CommandReturn(text="正在进行更新...")
if updater.update_all(): if await self.ap.ver_mgr.update_all():
yield entities.CommandReturn(text="更新完成,请重启程序以应用更新") yield entities.CommandReturn(text="更新完成,请重启程序以应用更新")
else: else:
yield entities.CommandReturn(text="当前已是最新版本") yield entities.CommandReturn(text="当前已是最新版本")

View File

@ -3,7 +3,6 @@ from __future__ import annotations
import typing import typing
from .. import operator, cmdmgr, entities, errors from .. import operator, cmdmgr, entities, errors
from ...utils import updater
@operator.operator_class( @operator.operator_class(
@ -17,10 +16,10 @@ class VersionCommand(operator.CommandOperator):
self, self,
context: entities.ExecuteContext context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]: ) -> typing.AsyncGenerator[entities.CommandReturn, None]:
reply_str = f"当前版本: \n{updater.get_current_version_info()}" reply_str = f"当前版本: \n{await self.ap.ver_mgr.get_current_version_info()}"
try: try:
if updater.is_new_version_available(): if await self.ap.ver_mgr.is_new_version_available():
reply_str += "\n\n有新版本可用, 使用 !update 更新" reply_str += "\n\n有新版本可用, 使用 !update 更新"
except: except:
pass pass

View File

@ -26,6 +26,9 @@ class JSONConfigFile(file_model.ConfigFile):
async def load(self) -> dict: async def load(self) -> dict:
if not self.exists():
await self.create()
with open(self.config_file_name, 'r', encoding='utf-8') as f: with open(self.config_file_name, 'r', encoding='utf-8') as f:
cfg = json.load(f) cfg = json.load(f)

View File

@ -3,22 +3,22 @@ from __future__ import annotations
import logging import logging
import asyncio import asyncio
from ..platform import manager as qqbot_mgr from ..platform import manager as im_mgr
from ..provider.session import sessionmgr as llm_session_mgr from ..provider.session import sessionmgr as llm_session_mgr
from ..provider.requester import modelmgr as llm_model_mgr from ..provider.requester import modelmgr as llm_model_mgr
from ..provider.sysprompt import sysprompt as llm_prompt_mgr from ..provider.sysprompt import sysprompt as llm_prompt_mgr
from ..provider.tools import toolmgr as llm_tool_mgr from ..provider.tools import toolmgr as llm_tool_mgr
from ..config import manager as config_mgr from ..config import manager as config_mgr
from ..database import manager as database_mgr # from ..utils.center import v2 as center_mgr
from ..utils.center import v2 as center_mgr
from ..command import cmdmgr from ..command import cmdmgr
from ..plugin import host as plugin_host from ..plugin import manager as plugin_mgr
from . import pool, controller from . import pool, controller
from ..pipeline import stagemgr from ..pipeline import stagemgr
from ..utils import version as version_mgr, proxy as proxy_mgr
class Application: class Application:
im_mgr: qqbot_mgr.QQBotManager = None im_mgr: im_mgr.QQBotManager = None
cmd_mgr: cmdmgr.CommandManager = None cmd_mgr: cmdmgr.CommandManager = None
@ -34,9 +34,9 @@ class Application:
tips_mgr: config_mgr.ConfigManager = None tips_mgr: config_mgr.ConfigManager = None
db_mgr: database_mgr.DatabaseManager = None # ctr_mgr: center_mgr.V2CenterAPI = None
ctr_mgr: center_mgr.V2CenterAPI = None plugin_mgr: plugin_mgr.PluginManager = None
query_pool: pool.QueryPool = None query_pool: pool.QueryPool = None
@ -44,24 +44,29 @@ class Application:
stage_mgr: stagemgr.StageManager = None stage_mgr: stagemgr.StageManager = None
ver_mgr: version_mgr.VersionManager = None
proxy_mgr: proxy_mgr.ProxyManager = None
logger: logging.Logger = None logger: logging.Logger = None
def __init__(self): def __init__(self):
pass pass
async def initialize(self): async def initialize(self):
plugin_host.initialize_plugins() pass
# 把现有的所有内容函数加到toolmgr里 # 把现有的所有内容函数加到toolmgr里
for func in plugin_host.__callable_functions__: # for func in plugin_host.__callable_functions__:
self.tool_mgr.register_legacy_function( # self.tool_mgr.register_legacy_function(
name=func['name'], # name=func['name'],
description=func['description'], # description=func['description'],
parameters=func['parameters'], # parameters=func['parameters'],
func=plugin_host.__function_inst_map__[func['name']] # func=plugin_host.__function_inst_map__[func['name']]
) # )
async def run(self): async def run(self):
await self.plugin_mgr.load_plugins()
tasks = [ tasks = [
asyncio.create_task(self.im_mgr.run()), asyncio.create_task(self.im_mgr.run()),

View File

@ -13,17 +13,15 @@ from . import pool
from . import controller from . import controller
from ..pipeline import stagemgr from ..pipeline import stagemgr
from ..audit import identifier from ..audit import identifier
from ..database import manager as db_mgr
from ..provider.session import sessionmgr as llm_session_mgr from ..provider.session import sessionmgr as llm_session_mgr
from ..provider.requester import modelmgr as llm_model_mgr from ..provider.requester import modelmgr as llm_model_mgr
from ..provider.sysprompt import sysprompt as llm_prompt_mgr from ..provider.sysprompt import sysprompt as llm_prompt_mgr
from ..provider.tools import toolmgr as llm_tool_mgr from ..provider.tools import toolmgr as llm_tool_mgr
from ..platform import manager as im_mgr from ..platform import manager as im_mgr
from ..command import cmdmgr from ..command import cmdmgr
from ..plugin import host as plugin_host from ..plugin import manager as plugin_mgr
from ..utils.center import v2 as center_v2 from ..utils.center import v2 as center_v2
from ..utils import updater from ..utils import version, proxy
from ..utils import context
use_override = False use_override = False
@ -58,7 +56,6 @@ async def make_app() -> app.Application:
"config.py", "config.py",
"config-template.py" "config-template.py"
) )
context.set_config_manager(cfg_mgr)
cfg = cfg_mgr.data cfg = cfg_mgr.data
# 检查是否携带了 --override 或 -r 参数 # 检查是否携带了 --override 或 -r 参数
@ -87,11 +84,20 @@ async def make_app() -> app.Application:
ap.query_pool = pool.QueryPool() ap.query_pool = pool.QueryPool()
proxy_mgr = proxy.ProxyManager(ap)
await proxy_mgr.initialize()
ap.proxy_mgr = proxy_mgr
ver_mgr = version.VersionManager(ap)
await ver_mgr.initialize()
ap.ver_mgr = ver_mgr
center_v2_api = center_v2.V2CenterAPI( center_v2_api = center_v2.V2CenterAPI(
ap,
basic_info={ basic_info={
"host_id": identifier.identifier['host_id'], "host_id": identifier.identifier['host_id'],
"instance_id": identifier.identifier['instance_id'], "instance_id": identifier.identifier['instance_id'],
"semantic_version": updater.get_current_tag(), "semantic_version": ver_mgr.get_current_version(),
"platform": sys.platform, "platform": sys.platform,
}, },
runtime_info={ runtime_info={
@ -99,12 +105,7 @@ async def make_app() -> app.Application:
"msg_source": cfg['msg_source_adapter'], "msg_source": cfg['msg_source_adapter'],
} }
) )
ap.ctr_mgr = center_v2_api # ap.ctr_mgr = center_v2_api
db_mgr_inst = db_mgr.DatabaseManager(ap)
# TODO make it async
db_mgr_inst.initialize_database()
ap.db_mgr = db_mgr_inst
cmd_mgr_inst = cmdmgr.CommandManager(ap) cmd_mgr_inst = cmdmgr.CommandManager(ap)
await cmd_mgr_inst.initialize() await cmd_mgr_inst.initialize()
@ -138,7 +139,9 @@ async def make_app() -> app.Application:
ap.ctrl = ctrl ap.ctrl = ctrl
# TODO make it async # TODO make it async
plugin_host.load_plugins() plugin_mgr_inst = plugin_mgr.PluginManager(ap)
await plugin_mgr_inst.initialize()
ap.plugin_mgr = plugin_mgr_inst
await ap.initialize() await ap.initialize()

View File

@ -2,10 +2,17 @@ from __future__ import annotations
import enum import enum
import typing import typing
import datetime
import asyncio
import pydantic import pydantic
import mirai import mirai
from ..provider import entities as llm_entities
from ..provider.requester import entities
from ..provider.sysprompt import entities as sysprompt_entities
from ..provider.tools import entities as tools_entities
class LauncherTypes(enum.Enum): class LauncherTypes(enum.Enum):
@ -39,3 +46,43 @@ class Query(pydantic.BaseModel):
resp_message_chain: typing.Optional[mirai.MessageChain] = None resp_message_chain: typing.Optional[mirai.MessageChain] = None
"""回复消息链""" """回复消息链"""
class Conversation(pydantic.BaseModel):
"""对话"""
prompt: sysprompt_entities.Prompt
messages: list[llm_entities.Message]
create_time: typing.Optional[datetime.datetime] = pydantic.Field(default_factory=datetime.datetime.now)
update_time: typing.Optional[datetime.datetime] = pydantic.Field(default_factory=datetime.datetime.now)
use_model: entities.LLMModelInfo
use_funcs: typing.Optional[list[tools_entities.LLMFunction]]
class Session(pydantic.BaseModel):
"""会话"""
launcher_type: LauncherTypes
launcher_id: int
sender_id: typing.Optional[int] = 0
use_prompt_name: typing.Optional[str] = 'default'
using_conversation: typing.Optional[Conversation] = None
conversations: typing.Optional[list[Conversation]] = []
create_time: typing.Optional[datetime.datetime] = pydantic.Field(default_factory=datetime.datetime.now)
update_time: typing.Optional[datetime.datetime] = pydantic.Field(default_factory=datetime.datetime.now)
semaphore: typing.Optional[asyncio.Semaphore] = None
class Config:
arbitrary_types_allowed = True

View File

@ -12,7 +12,6 @@ import func_timeout
from ..provider import session as openai_session from ..provider import session as openai_session
from ..utils import context
import tips as tips_custom import tips as tips_custom
from ..platform import adapter as msadapter from ..platform import adapter as msadapter
from .ratelim import ratelim from .ratelim import ratelim
@ -40,7 +39,7 @@ class QQBotManager:
async def initialize(self): async def initialize(self):
await self.ratelimiter.initialize() await self.ratelimiter.initialize()
config = context.get_config_manager().data config = self.ap.cfg_mgr.data
logging.debug("Use adapter:" + config['msg_source_adapter']) logging.debug("Use adapter:" + config['msg_source_adapter'])
if config['msg_source_adapter'] == 'yirimirai': if config['msg_source_adapter'] == 'yirimirai':
@ -106,7 +105,7 @@ class QQBotManager:
) )
async def send(self, event, msg, check_quote=True, check_at_sender=True): async def send(self, event, msg, check_quote=True, check_at_sender=True):
config = context.get_config_manager().data config = self.ap.cfg_mgr.data
if check_at_sender and config['at_sender']: if check_at_sender and config['at_sender']:
msg.insert( msg.insert(
@ -134,7 +133,7 @@ class QQBotManager:
await self.notify_admin_message_chain(MessageChain([Plain("[bot]{}".format(message))])) await self.notify_admin_message_chain(MessageChain([Plain("[bot]{}".format(message))]))
async def notify_admin_message_chain(self, message: mirai.MessageChain): async def notify_admin_message_chain(self, message: mirai.MessageChain):
config = context.get_config_manager().data config = self.ap.cfg_mgr.data
if config['admin_qq'] != 0 and config['admin_qq'] != []: if config['admin_qq'] != 0 and config['admin_qq'] != []:
logging.info("通知管理员:{}".format(message)) logging.info("通知管理员:{}".format(message))

207
pkg/plugin/context.py Normal file
View File

@ -0,0 +1,207 @@
from __future__ import annotations
import typing
import abc
import pydantic
from . import events
from ..provider.tools import entities as tools_entities
from ..core import app
class BasePlugin(metaclass=abc.ABCMeta):
"""插件基类"""
host: APIHost
class APIHost:
"""QChatGPT API 宿主"""
ap: app.Application
def __init__(self, ap: app.Application):
self.ap = ap
async def initialize(self):
pass
def require_ver(
self,
ge: str,
le: str='v999.999.999',
) -> bool:
"""插件版本要求装饰器
Args:
ge (str): 最低版本要求
le (str, optional): 最高版本要求
Returns:
bool: 是否满足要求, False时为无法获取版本号True时为满足要求报错为不满足要求
"""
qchatgpt_version = ""
try:
qchatgpt_version = self.ap.ver_mgr.get_current_version() # 从updater模块获取版本号
except:
return False
if self.ap.ver_mgr.compare_version_str(qchatgpt_version, ge) < 0 or \
(self.ap.ver_mgr.compare_version_str(qchatgpt_version, le) > 0):
raise Exception("QChatGPT 版本不满足要求,某些功能(可能是由插件提供的)无法正常使用。(要求版本:{}-{},但当前版本:{}".format(ge, le, qchatgpt_version))
return True
class EventContext:
"""事件上下文, 保存此次事件运行的信息"""
eid = 0
"""事件编号"""
host: APIHost = None
event: events.BaseEventModel = None
__prevent_default__ = False
"""是否阻止默认行为"""
__prevent_postorder__ = False
"""是否阻止后续插件的执行"""
__return_value__ = {}
""" 返回值
示例:
{
"example": [
'value1',
'value2',
3,
4,
{
'key1': 'value1',
},
['value1', 'value2']
]
}
"""
def add_return(self, key: str, ret):
"""添加返回值"""
if key not in self.__return_value__:
self.__return_value__[key] = []
self.__return_value__[key].append(ret)
def get_return(self, key: str) -> list:
"""获取key的所有返回值"""
if key in self.__return_value__:
return self.__return_value__[key]
return None
def get_return_value(self, key: str):
"""获取key的首个返回值"""
if key in self.__return_value__:
return self.__return_value__[key][0]
return None
def prevent_default(self):
"""阻止默认行为"""
self.__prevent_default__ = True
def prevent_postorder(self):
"""阻止后续插件执行"""
self.__prevent_postorder__ = True
def is_prevented_default(self):
"""是否阻止默认行为"""
return self.__prevent_default__
def is_prevented_postorder(self):
"""是否阻止后序插件执行"""
return self.__prevent_postorder__
def __init__(self, host: APIHost, event: events.BaseEventModel):
self.eid = EventContext.eid
self.host = host
self.event = event
self.__prevent_default__ = False
self.__prevent_postorder__ = False
self.__return_value__ = {}
EventContext.eid += 1
class RuntimeContainer(pydantic.BaseModel):
"""运行时的插件容器
运行期间存储单个插件的信息
"""
plugin_name: str
"""插件名称"""
plugin_description: str
"""插件描述"""
plugin_version: str
"""插件版本"""
plugin_author: str
"""插件作者"""
plugin_source: str
"""插件源码地址"""
main_file: str
"""插件主文件路径"""
pkg_path: str
"""插件包路径"""
plugin_class: typing.Type[BasePlugin] = None
"""插件类"""
enabled: typing.Optional[bool] = True
"""是否启用"""
priority: typing.Optional[int] = 0
"""优先级"""
plugin_inst: typing.Optional[BasePlugin] = None
"""插件实例"""
event_handlers: dict[typing.Type[events.BaseEventModel], typing.Callable[
[BasePlugin, EventContext], typing.Awaitable[None]
]] = {}
"""事件处理器"""
content_functions: list[tools_entities.LLMFunction] = []
"""内容函数"""
class Config:
arbitrary_types_allowed = True
def to_setting_dict(self):
return {
'name': self.plugin_name,
'description': self.plugin_description,
'version': self.plugin_version,
'author': self.plugin_author,
'source': self.plugin_source,
'main_file': self.main_file,
'pkg_path': self.pkg_path,
'priority': self.priority,
'enabled': self.enabled,
}
def set_from_setting_dict(
self,
setting: dict
):
self.plugin_source = setting['source']
self.priority = setting['priority']
self.enabled = setting['enabled']
for function in self.content_functions:
function.enable = self.enabled

24
pkg/plugin/errors.py Normal file
View File

@ -0,0 +1,24 @@
from __future__ import annotations
class PluginSystemError(Exception):
message: str
def __init__(self, message: str):
self.message = message
def __str__(self):
return self.message
class PluginNotFoundError(PluginSystemError):
def __init__(self, message: str):
super().__init__(f"未找到插件: {message}")
class PluginInstallerError(PluginSystemError):
def __init__(self, message: str):
super().__init__(f"安装器操作错误: {message}")

96
pkg/plugin/events.py Normal file
View File

@ -0,0 +1,96 @@
from __future__ import annotations
import typing
import pydantic
import mirai
from . import context
from ..core import entities as core_entities
class BaseEventModel(pydantic.BaseModel):
class Config:
arbitrary_types_allowed = True
class PersonMessageReceived(BaseEventModel):
"""收到任何私聊消息时"""
launcher_type: str
"""发起对象类型(group/person)"""
launcher_id: int
"""发起对象ID(群号/QQ号)"""
sender_id: int
"""发送者ID(QQ号)"""
message_chain: mirai.MessageChain
query: core_entities.Query
"""此次请求的上下文"""
class GroupMessageReceived(BaseEventModel):
"""收到任何群聊消息时"""
launcher_type: str
launcher_id: int
sender_id: int
message_chain: mirai.MessageChain
query: core_entities.Query
"""此次请求的上下文"""
class PersonNormalMessageReceived(BaseEventModel):
"""判断为应该处理的私聊普通消息时触发"""
launcher_type: str
launcher_id: int
sender_id: int
text_message: str
query: core_entities.Query
"""此次请求的上下文"""
alter: typing.Optional[str] = None
"""修改后的消息文本"""
reply: typing.Optional[list] = None
"""回复消息组件列表"""
class PersonCommandSent(BaseEventModel):
"""判断为应该处理的私聊命令时触发"""
launcher_type: str
launcher_id: int
sender_id: int
command: str
params: list[str]
text_message: str
is_admin: bool
query: core_entities.Query
"""此次请求的上下文"""
alter: typing.Optional[str] = None
"""修改后的完整命令文本"""
reply: typing.Optional[list] = None
"""回复消息组件列表"""

View File

@ -1,578 +1,5 @@
# 插件管理模块 from . events import *
import asyncio from . context import EventContext, APIHost as PluginHost
import logging
import importlib
import os
import pkgutil
import sys
import shutil
import traceback
import time
import re
from ..utils import updater as updater def emit(*args, **kwargs):
from ..utils import network as network print('插件调用了已弃用的函数 pkg.plugin.host.emit()')
from ..utils import context as context
from ..plugin import switch as switch
from ..plugin import settings as settings
from ..platform import adapter as msadapter
from ..plugin import metadata as metadata
from mirai import Mirai
import requests
from CallingGPT.session.session import Session
__plugins__ = {}
"""插件列表
示例:
{
"example": {
"path": "plugins/example/main.py",
"enabled: True,
"name": "example",
"description": "example",
"version": "0.0.1",
"author": "RockChinQ",
"class": <class 'plugins.example.ExamplePlugin'>,
"hooks": {
"person_message": [
<function ExamplePlugin.person_message at 0x0000020E1D1B8D38>
]
},
"instance": None
}
}
"""
__plugins_order__ = []
"""插件顺序"""
__enable_content_functions__ = True
"""是否启用内容函数"""
__callable_functions__ = []
"""供GPT调用的函数结构"""
__function_inst_map__: dict[str, callable] = {}
"""函数名:实例 映射"""
def generate_plugin_order():
"""根据__plugin__生成插件初始顺序无视是否启用"""
global __plugins_order__
__plugins_order__ = []
for plugin_name in __plugins__:
__plugins_order__.append(plugin_name)
def iter_plugins():
"""按照顺序迭代插件"""
for plugin_name in __plugins_order__:
if plugin_name not in __plugins__:
continue
yield __plugins__[plugin_name]
def iter_plugins_name():
"""迭代插件名"""
for plugin_name in __plugins_order__:
yield plugin_name
__current_module_path__ = ""
def walk_plugin_path(module, prefix="", path_prefix=""):
global __current_module_path__
"""遍历插件路径"""
for item in pkgutil.iter_modules(module.__path__):
if item.ispkg:
logging.debug("扫描插件包: plugins/{}".format(path_prefix + item.name))
walk_plugin_path(
__import__(module.__name__ + "." + item.name, fromlist=[""]),
prefix + item.name + ".",
path_prefix + item.name + "/",
)
else:
try:
logging.debug(
"扫描插件模块: plugins/{}".format(path_prefix + item.name + ".py")
)
__current_module_path__ = "plugins/" + path_prefix + item.name + ".py"
importlib.import_module(module.__name__ + "." + item.name)
logging.debug(
"加载模块: plugins/{} 成功".format(path_prefix + item.name + ".py")
)
except:
logging.error(
"加载模块: plugins/{} 失败: {}".format(
path_prefix + item.name + ".py", sys.exc_info()
)
)
traceback.print_exc()
def load_plugins():
"""加载插件"""
logging.debug("加载插件")
PluginHost()
walk_plugin_path(__import__("plugins"))
logging.debug(__plugins__)
# 加载开关数据
switch.load_switch()
# 生成初始顺序
generate_plugin_order()
# 加载插件顺序
settings.load_settings()
logging.debug("registered plugins: {}".format(__plugins__))
# 输出已注册的内容函数列表
logging.debug("registered content functions: {}".format(__callable_functions__))
logging.debug("function instance map: {}".format(__function_inst_map__))
# 迁移插件源地址记录
metadata.do_plugin_git_repo_migrate()
def initialize_plugins():
"""初始化插件"""
logging.debug("初始化插件")
import pkg.plugin.models as models
successfully_initialized_plugins = []
for plugin in iter_plugins():
# if not plugin['enabled']:
# continue
try:
models.__current_registering_plugin__ = plugin["name"]
plugin["instance"] = plugin["class"](plugin_host=context.get_plugin_host())
# logging.info("插件 {} 已初始化".format(plugin['name']))
successfully_initialized_plugins.append(plugin["name"])
except:
logging.error("插件{}初始化时发生错误: {}".format(plugin["name"], sys.exc_info()))
logging.debug(traceback.format_exc())
logging.info("以下插件已初始化: {}".format(", ".join(successfully_initialized_plugins)))
def unload_plugins():
"""卸载插件"""
# 不再显式卸载插件,因为当程序结束时,插件的析构函数会被系统执行
# for plugin in __plugins__.values():
# if plugin['enabled'] and plugin['instance'] is not None:
# if not hasattr(plugin['instance'], '__del__'):
# logging.warning("插件{}没有定义析构函数".format(plugin['name']))
# else:
# try:
# plugin['instance'].__del__()
# logging.info("卸载插件: {}".format(plugin['name']))
# plugin['instance'] = None
# except:
# logging.error("插件{}卸载时发生错误: {}".format(plugin['name'], sys.exc_info()))
def get_github_plugin_repo_label(repo_url: str) -> list[str]:
"""获取username, repo"""
# 提取 username/repo , 正则表达式
repo = re.findall(
r"(?:https?://github\.com/|git@github\.com:)([^/]+/[^/]+?)(?:\.git|/|$)",
repo_url,
)
if len(repo) > 0: # github
return repo[0].split("/")
else:
return None
def download_plugin_source_code(repo_url: str, target_path: str) -> str:
"""下载插件源码"""
# 检查源类型
# 提取 username/repo , 正则表达式
repo = get_github_plugin_repo_label(repo_url)
target_path += repo[1]
if repo is not None: # github
logging.info("从 GitHub 下载插件源码...")
zipball_url = f"https://api.github.com/repos/{'/'.join(repo)}/zipball/HEAD"
zip_resp = requests.get(
url=zipball_url, proxies=network.wrapper_proxies(), stream=True
)
if zip_resp.status_code != 200:
raise Exception("下载源码失败: {}".format(zip_resp.text))
if os.path.exists("temp/" + target_path):
shutil.rmtree("temp/" + target_path)
if os.path.exists(target_path):
shutil.rmtree(target_path)
os.makedirs("temp/" + target_path)
with open("temp/" + target_path + "/source.zip", "wb") as f:
for chunk in zip_resp.iter_content(chunk_size=1024):
if chunk:
f.write(chunk)
logging.info("下载完成, 解压...")
import zipfile
with zipfile.ZipFile("temp/" + target_path + "/source.zip", "r") as zip_ref:
zip_ref.extractall("temp/" + target_path)
os.remove("temp/" + target_path + "/source.zip")
# 目标是 username-repo-hash , 用正则表达式提取完整的文件夹名,复制到 plugins/repo
import glob
# 获取解压后的文件夹名
unzip_dir = glob.glob("temp/" + target_path + "/*")[0]
# 复制到 plugins/repo
shutil.copytree(unzip_dir, target_path + "/")
# 删除解压后的文件夹
shutil.rmtree(unzip_dir)
logging.info("解压完成")
else:
raise Exception("暂不支持的源类型,请使用 GitHub 仓库发行插件。")
return repo[1]
def check_requirements(path: str):
# 检查此目录是否包含requirements.txt
if os.path.exists(path + "/requirements.txt"):
logging.info("检测到requirements.txt正在安装依赖")
import pkg.utils.pkgmgr
pkg.utils.pkgmgr.install_requirements(path + "/requirements.txt")
import pkg.utils.log as log
log.reset_logging()
def install_plugin(repo_url: str):
"""安装插件从git储存库获取并解决依赖"""
repo_label = download_plugin_source_code(repo_url, "plugins/")
check_requirements("plugins/" + repo_label)
metadata.set_plugin_metadata(repo_label, repo_url, int(time.time()), "HEAD")
# 上报安装记录
context.get_center_v2_api().plugin.post_install_record(
plugin={
"name": "unknown",
"remote": repo_url,
"author": "unknown",
"version": "HEAD",
}
)
def uninstall_plugin(plugin_name: str) -> str:
"""卸载插件"""
if plugin_name not in __plugins__:
raise Exception("插件不存在")
plugin_info = get_plugin_info_for_audit(plugin_name)
# 获取文件夹路径
plugin_path = __plugins__[plugin_name]["path"].replace("\\", "/")
# 剪切路径为plugins/插件名
plugin_path = plugin_path.split("plugins/")[1].split("/")[0]
# 删除文件夹
shutil.rmtree("plugins/" + plugin_path)
# 上报卸载记录
context.get_center_v2_api().plugin.post_remove_record(
plugin=plugin_info
)
return "plugins/" + plugin_path
def update_plugin(plugin_name: str):
"""更新插件"""
# 检查是否有远程地址记录
plugin_path_name = get_plugin_path_name_by_plugin_name(plugin_name)
meta = metadata.get_plugin_metadata(plugin_path_name)
if meta == {}:
raise Exception("没有此插件元数据信息,无法更新")
old_plugin_info = get_plugin_info_for_audit(plugin_name)
context.get_center_v2_api().plugin.post_update_record(
plugin=old_plugin_info,
old_version=old_plugin_info['version'],
new_version='HEAD',
)
remote_url = meta["source"]
if (
remote_url == "https://github.com/RockChinQ/QChatGPT"
or remote_url == "https://gitee.com/RockChin/QChatGPT"
or remote_url == ""
or remote_url is None
or remote_url == "http://github.com/RockChinQ/QChatGPT"
or remote_url == "http://gitee.com/RockChin/QChatGPT"
):
raise Exception("插件没有远程地址记录,无法更新")
# 重新安装插件
logging.info("正在重新安装插件以进行更新...")
install_plugin(remote_url)
def get_plugin_name_by_path_name(plugin_path_name: str) -> str:
for k, v in __plugins__.items():
if v["path"] == "plugins/" + plugin_path_name + "/main.py":
return k
return None
def get_plugin_path_name_by_plugin_name(plugin_name: str) -> str:
if plugin_name not in __plugins__:
return None
plugin_main_module_path = __plugins__[plugin_name]["path"]
plugin_main_module_path = plugin_main_module_path.replace("\\", "/")
spt = plugin_main_module_path.split("/")
return spt[1]
def get_plugin_info_for_audit(plugin_name: str) -> dict:
"""获取插件信息"""
if plugin_name not in __plugins__:
return {}
plugin = __plugins__[plugin_name]
name = plugin["name"]
meta = metadata.get_plugin_metadata(get_plugin_path_name_by_plugin_name(name))
remote = meta["source"] if meta != {} else ""
author = plugin["author"]
version = plugin["version"]
return {
"name": name,
"remote": remote,
"author": author,
"version": version,
}
class EventContext:
"""事件上下文"""
eid = 0
"""事件编号"""
name = ""
__prevent_default__ = False
"""是否阻止默认行为"""
__prevent_postorder__ = False
"""是否阻止后续插件的执行"""
__return_value__ = {}
""" 返回值
示例:
{
"example": [
'value1',
'value2',
3,
4,
{
'key1': 'value1',
},
['value1', 'value2']
]
}
"""
def add_return(self, key: str, ret):
"""添加返回值"""
if key not in self.__return_value__:
self.__return_value__[key] = []
self.__return_value__[key].append(ret)
def get_return(self, key: str) -> list:
"""获取key的所有返回值"""
if key in self.__return_value__:
return self.__return_value__[key]
return None
def get_return_value(self, key: str):
"""获取key的首个返回值"""
if key in self.__return_value__:
return self.__return_value__[key][0]
return None
def prevent_default(self):
"""阻止默认行为"""
self.__prevent_default__ = True
def prevent_postorder(self):
"""阻止后续插件执行"""
self.__prevent_postorder__ = True
def is_prevented_default(self):
"""是否阻止默认行为"""
return self.__prevent_default__
def is_prevented_postorder(self):
"""是否阻止后序插件执行"""
return self.__prevent_postorder__
def __init__(self, name: str):
self.name = name
self.eid = EventContext.eid
self.__prevent_default__ = False
self.__prevent_postorder__ = False
self.__return_value__ = {}
EventContext.eid += 1
def emit(event_name: str, **kwargs) -> EventContext:
"""触发事件"""
import pkg.utils.context as context
if context.get_plugin_host() is None:
return None
return context.get_plugin_host().emit(event_name, **kwargs)
class PluginHost:
"""插件宿主"""
def __init__(self):
"""初始化插件宿主"""
context.set_plugin_host(self)
self.calling_gpt_session = Session([])
def get_runtime_context(self) -> context:
"""获取运行时上下文pkg.utils.context模块的对象
此上下文用于和主程序其他模块交互数据库QQ机器人OpenAI接口等
详见pkg.utils.context模块
其中的context变量保存了其他重要模块的类对象可以使用这些对象进行交互
"""
return context
def get_bot(self) -> Mirai:
"""获取机器人对象"""
return context.get_qqbot_manager().bot
def get_bot_adapter(self) -> msadapter.MessageSourceAdapter:
"""获取消息源适配器"""
return context.get_qqbot_manager().adapter
def send_person_message(self, person, message):
"""发送私聊消息"""
self.get_bot_adapter().send_message("person", person, message)
def send_group_message(self, group, message):
"""发送群消息"""
self.get_bot_adapter().send_message("group", group, message)
def notify_admin(self, message):
"""通知管理员"""
context.get_qqbot_manager().notify_admin(message)
def emit(self, event_name: str, **kwargs) -> EventContext:
"""触发事件"""
import json
event_context = EventContext(event_name)
logging.debug("触发事件: {} ({})".format(event_name, event_context.eid))
emitted_plugins = []
for plugin in iter_plugins():
if not plugin["enabled"]:
continue
# if plugin['instance'] is None:
# # 从关闭状态切到开启状态之后,重新加载插件
# try:
# plugin['instance'] = plugin["class"](plugin_host=self)
# logging.info("插件 {} 已初始化".format(plugin['name']))
# except:
# logging.error("插件 {} 初始化时发生错误: {}".format(plugin['name'], sys.exc_info()))
# continue
if "hooks" not in plugin or event_name not in plugin["hooks"]:
continue
emitted_plugins.append(plugin['name'])
hooks = []
if event_name in plugin["hooks"]:
hooks = plugin["hooks"][event_name]
for hook in hooks:
try:
already_prevented_default = event_context.is_prevented_default()
kwargs["host"] = context.get_plugin_host()
kwargs["event"] = event_context
hook(plugin["instance"], **kwargs)
if (
event_context.is_prevented_default()
and not already_prevented_default
):
logging.debug(
"插件 {} 已要求阻止事件 {} 的默认行为".format(plugin["name"], event_name)
)
except Exception as e:
logging.error("插件{}响应事件{}时发生错误".format(plugin["name"], event_name))
logging.error(traceback.format_exc())
# print("done:{}".format(plugin['name']))
if event_context.is_prevented_postorder():
logging.debug("插件 {} 阻止了后序插件的执行".format(plugin["name"]))
break
logging.debug(
"事件 {} ({}) 处理完毕,返回值: {}".format(
event_name, event_context.eid, event_context.__return_value__
)
)
if len(emitted_plugins) > 0:
plugins_info = [get_plugin_info_for_audit(p) for p in emitted_plugins]
context.get_center_v2_api().usage.post_event_record(
plugins=plugins_info,
event_name=event_name,
)
return event_context

45
pkg/plugin/installer.py Normal file
View File

@ -0,0 +1,45 @@
from __future__ import annotations
import typing
import abc
from ..core import app
class PluginInstaller(metaclass=abc.ABCMeta):
ap: app.Application
def __init__(self, ap: app.Application):
self.ap = ap
async def initialize(self):
pass
@abc.abstractmethod
async def install_plugin(
self,
plugin_source: str,
):
"""安装插件
"""
raise NotImplementedError
@abc.abstractmethod
async def uninstall_plugin(
self,
plugin_name: str,
):
"""卸载插件
"""
raise NotImplementedError
@abc.abstractmethod
async def update_plugin(
self,
plugin_name: str,
plugin_source: str=None,
):
"""更新插件
"""
raise NotImplementedError

View File

@ -0,0 +1,137 @@
from __future__ import annotations
import re
import os
import shutil
import zipfile
import requests
from .. import installer, errors
from ...utils import pkgmgr
class GitHubRepoInstaller(installer.PluginInstaller):
def get_github_plugin_repo_label(self, repo_url: str) -> list[str]:
"""获取username, repo"""
# 提取 username/repo , 正则表达式
repo = re.findall(
r"(?:https?://github\.com/|git@github\.com:)([^/]+/[^/]+?)(?:\.git|/|$)",
repo_url,
)
if len(repo) > 0: # github
return repo[0].split("/")
else:
return None
async def download_plugin_source_code(self, repo_url: str, target_path: str) -> str:
"""下载插件源码"""
# 检查源类型
# 提取 username/repo , 正则表达式
repo = self.get_github_plugin_repo_label(repo_url)
target_path += repo[1]
if repo is not None: # github
self.ap.logger.debug("正在下载源码...")
zipball_url = f"https://api.github.com/repos/{'/'.join(repo)}/zipball/HEAD"
zip_resp = requests.get(
url=zipball_url, proxies=self.ap.proxy_mgr.get_forward_proxies(), stream=True
)
if zip_resp.status_code != 200:
raise Exception("下载源码失败: {}".format(zip_resp.text))
if os.path.exists("temp/" + target_path):
shutil.rmtree("temp/" + target_path)
if os.path.exists(target_path):
shutil.rmtree(target_path)
os.makedirs("temp/" + target_path)
with open("temp/" + target_path + "/source.zip", "wb") as f:
for chunk in zip_resp.iter_content(chunk_size=1024):
if chunk:
f.write(chunk)
self.ap.logger.debug("解压中...")
with zipfile.ZipFile("temp/" + target_path + "/source.zip", "r") as zip_ref:
zip_ref.extractall("temp/" + target_path)
os.remove("temp/" + target_path + "/source.zip")
# 目标是 username-repo-hash , 用正则表达式提取完整的文件夹名,复制到 plugins/repo
import glob
# 获取解压后的文件夹名
unzip_dir = glob.glob("temp/" + target_path + "/*")[0]
# 复制到 plugins/repo
shutil.copytree(unzip_dir, target_path + "/")
# 删除解压后的文件夹
shutil.rmtree(unzip_dir)
self.ap.logger.debug("源码下载完成。")
else:
raise errors.PluginInstallerError('仅支持GitHub仓库地址')
return repo[1]
async def install_requirements(self, path: str):
if os.path.exists(path + "/requirements.txt"):
pkgmgr.install_requirements(path + "/requirements.txt")
async def install_plugin(
self,
plugin_source: str,
):
"""安装插件
"""
repo_label = await self.download_plugin_source_code(plugin_source, "plugins/")
await self.install_requirements("plugins/" + repo_label)
await self.ap.plugin_mgr.setting.record_installed_plugin_source(
"plugins/"+repo_label+'/', plugin_source
)
async def uninstall_plugin(
self,
plugin_name: str,
):
"""卸载插件
"""
plugin_container = self.ap.plugin_mgr.get_plugin_by_name(plugin_name)
if plugin_container is None:
raise errors.PluginInstallerError('插件不存在或未成功加载')
else:
shutil.rmtree(plugin_container.pkg_path)
async def update_plugin(
self,
plugin_name: str,
plugin_source: str=None,
):
"""更新插件
"""
plugin_container = self.ap.plugin_mgr.get_plugin_by_name(plugin_name)
if plugin_container is None:
raise errors.PluginInstallerError('插件不存在或未成功加载')
else:
if plugin_container.plugin_source:
plugin_source = plugin_container.plugin_source
await self.install_plugin(plugin_source)
else:
raise errors.PluginInstallerError('插件无源码信息,无法更新')

25
pkg/plugin/loader.py Normal file
View File

@ -0,0 +1,25 @@
from __future__ import annotations
from abc import ABCMeta
import typing
import abc
from ..core import app
from . import context, events
class PluginLoader(metaclass=abc.ABCMeta):
"""插件加载器"""
ap: app.Application
def __init__(self, ap: app.Application):
self.ap = ap
async def initialize(self):
pass
@abc.abstractmethod
async def load_plugins(self) -> list[context.RuntimeContainer]:
pass

View File

View File

@ -0,0 +1,155 @@
from __future__ import annotations
import typing
import pkgutil
import importlib
import traceback
from CallingGPT.entities.namespace import get_func_schema
from .. import loader, events, context, models, host
from ...core import entities as core_entities
from ...provider.tools import entities as tools_entities
class PluginLoader(loader.PluginLoader):
"""加载 plugins/ 目录下的插件"""
_current_pkg_path = ''
_current_module_path = ''
_current_container: context.RuntimeContainer = None
containers: list[context.RuntimeContainer] = []
async def initialize(self):
"""初始化"""
setattr(models, 'register', self.register)
setattr(models, 'on', self.on)
setattr(models, 'func', self.func)
def register(
self,
name: str,
description: str,
version: str,
author: str
) -> typing.Callable[[typing.Type[context.BasePlugin]], typing.Type[context.BasePlugin]]:
self.ap.logger.debug(f'注册插件 {name} {version} by {author}')
container = context.RuntimeContainer(
plugin_name=name,
plugin_description=description,
plugin_version=version,
plugin_author=author,
plugin_source='',
pkg_path=self._current_pkg_path,
main_file=self._current_module_path,
event_handlers={},
content_functions=[],
)
self._current_container = container
def wrapper(cls: context.BasePlugin) -> typing.Type[context.BasePlugin]:
container.plugin_class = cls
return cls
return wrapper
def on(
self,
event: typing.Type[events.BaseEventModel]
) -> typing.Callable[[typing.Callable], typing.Callable]:
"""注册过时的事件处理器"""
self.ap.logger.debug(f'注册事件处理器 {event.__name__}')
def wrapper(func: typing.Callable) -> typing.Callable:
async def handler(plugin: context.BasePlugin, ctx: context.EventContext) -> None:
args = {
'host': ctx.host,
'event': ctx,
}
# 把 ctx.event 所有的属性都放到 args 里
for k, v in ctx.event.dict().items():
args[k] = v
await func(plugin, **args)
self._current_container.event_handlers[event] = handler
return func
return wrapper
def func(
self,
name: str=None,
) -> typing.Callable:
"""注册过时的内容函数"""
self.ap.logger.debug(f'注册内容函数 {name}')
def wrapper(func: typing.Callable) -> typing.Callable:
function_schema = get_func_schema(func)
function_name = self._current_container.plugin_name + '-' + (func.__name__ if name is None else name)
async def handler(
query: core_entities.Query,
*args,
**kwargs
):
return func(*args, **kwargs)
llm_function = tools_entities.LLMFunction(
name=function_name,
human_desc='',
description=function_schema['description'],
enable=True,
parameters=function_schema['parameters'],
func=handler,
)
self._current_container.content_functions.append(llm_function)
return func
return wrapper
async def _walk_plugin_path(
self,
module,
prefix='',
path_prefix=''
):
"""遍历插件路径
"""
for item in pkgutil.iter_modules(module.__path__):
if item.ispkg:
await self._walk_plugin_path(
__import__(module.__name__ + "." + item.name, fromlist=[""]),
prefix + item.name + ".",
path_prefix + item.name + "/",
)
else:
try:
self._current_pkg_path = "plugins/" + path_prefix
self._current_module_path = "plugins/" + path_prefix + item.name + ".py"
self._current_container = None
importlib.import_module(module.__name__ + "." + item.name)
if self._current_container is not None:
self.containers.append(self._current_container)
self.ap.logger.debug(f'插件 {self._current_container} 已加载')
except:
self.ap.logger.error(f'加载插件模块 {prefix + item.name} 时发生错误')
traceback.print_exc()
async def load_plugins(self) -> list[context.RuntimeContainer]:
"""加载插件
"""
await self._walk_plugin_path(__import__("plugins", fromlist=[""]))
return self.containers

112
pkg/plugin/manager.py Normal file
View File

@ -0,0 +1,112 @@
from __future__ import annotations
import typing
from ..core import app
from . import context, loader, events, installer, setting, models
from .loaders import legacy
from .installers import github
class PluginManager:
ap: app.Application
loader: loader.PluginLoader
installer: installer.PluginInstaller
setting: setting.SettingManager
api_host: context.APIHost
plugins: list[context.RuntimeContainer]
def __init__(self, ap: app.Application):
self.ap = ap
self.loader = legacy.PluginLoader(ap)
self.installer = github.GitHubRepoInstaller(ap)
self.setting = setting.SettingManager(ap)
self.api_host = context.APIHost(ap)
self.plugins = []
async def initialize(self):
await self.loader.initialize()
await self.installer.initialize()
await self.setting.initialize()
await self.api_host.initialize()
setattr(models, 'require_ver', self.api_host.require_ver)
async def load_plugins(self):
self.plugins = await self.loader.load_plugins()
await self.setting.sync_setting(self.plugins)
# 按优先级倒序
self.plugins.sort(key=lambda x: x.priority, reverse=True)
async def initialize_plugins(self):
pass
async def install_plugin(
self,
plugin_source: str,
):
"""安装插件
"""
await self.installer.install_plugin(plugin_source)
async def uninstall_plugin(
self,
plugin_name: str,
):
"""卸载插件
"""
await self.installer.uninstall_plugin(plugin_name)
async def update_plugin(
self,
plugin_name: str,
plugin_source: str=None,
):
"""更新插件
"""
await self.installer.update_plugin(plugin_name, plugin_source)
def get_plugin_by_name(self, plugin_name: str) -> context.RuntimeContainer:
"""通过插件名获取插件
"""
for plugin in self.plugins:
if plugin.plugin_name == plugin_name:
return plugin
return None
async def emit_event(self, event: events.BaseEventModel) -> context.EventContext:
"""触发事件
"""
ctx = context.EventContext(
host=self.api_host,
event=event
)
for plugin in self.plugins:
if plugin.enabled:
if event.__class__ in plugin.event_handlers:
try:
await plugin.event_handlers[event.__class__](
plugin.plugin_inst,
ctx
)
except Exception as e:
self.ap.logger.error(f'插件 {plugin.plugin_name} 触发事件 {event.__class__.__name__} 时发生错误: {e}')
self.ap.logger.exception(e)
if ctx.is_prevented_postorder():
self.ap.logger.debug(f'插件 {plugin.plugin_name} 阻止了后序插件的执行')
break
self.ap.logger.debug(f'事件 {event.__class__.__name__}({ctx.eid}) 处理完成,返回值 {ctx.__return_value__}')
return ctx

View File

@ -1,87 +0,0 @@
import os
import shutil
import json
import time
import dulwich.errors as dulwich_err
from ..utils import updater
def read_metadata_file() -> dict:
# 读取 plugins/metadata.json 文件
if not os.path.exists('plugins/metadata.json'):
return {}
with open('plugins/metadata.json', 'r') as f:
return json.load(f)
def write_metadata_file(metadata: dict):
if not os.path.exists('plugins'):
os.mkdir('plugins')
with open('plugins/metadata.json', 'w') as f:
json.dump(metadata, f, indent=4, ensure_ascii=False)
def do_plugin_git_repo_migrate():
# 仅在 plugins/metadata.json 不存在时执行
if os.path.exists('plugins/metadata.json'):
return
metadata = read_metadata_file()
# 遍历 plugins 下所有目录获取目录的git远程地址
for plugin_name in os.listdir('plugins'):
plugin_path = os.path.join('plugins', plugin_name)
if not os.path.isdir(plugin_path):
continue
remote_url = None
try:
remote_url = updater.get_remote_url(plugin_path)
except dulwich_err.NotGitRepository:
continue
if remote_url == "https://github.com/RockChinQ/QChatGPT" or remote_url == "https://gitee.com/RockChin/QChatGPT" \
or remote_url == "" or remote_url is None or remote_url == "http://github.com/RockChinQ/QChatGPT" or remote_url == "http://gitee.com/RockChin/QChatGPT":
continue
from . import host
if plugin_name not in metadata:
metadata[plugin_name] = {
'source': remote_url,
'install_timestamp': int(time.time()),
'ref': 'HEAD',
}
write_metadata_file(metadata)
def set_plugin_metadata(
plugin_name: str,
source: str,
install_timestamp: int,
ref: str,
):
metadata = read_metadata_file()
metadata[plugin_name] = {
'source': source,
'install_timestamp': install_timestamp,
'ref': ref,
}
write_metadata_file(metadata)
def remove_plugin_metadata(plugin_name: str):
metadata = read_metadata_file()
if plugin_name in metadata:
del metadata[plugin_name]
write_metadata_file(metadata)
def get_plugin_metadata(plugin_name: str) -> dict:
metadata = read_metadata_file()
if plugin_name in metadata:
return metadata[plugin_name]
return {}

View File

@ -1,299 +1 @@
import logging from .context import BasePlugin as Plugin
from ..plugin import host
from ..utils import context
PersonMessageReceived = "person_message_received"
"""收到私聊消息时,在判断是否应该响应前触发
kwargs:
launcher_type: str 发起对象类型(group/person)
launcher_id: int 发起对象ID(群号/QQ号)
sender_id: int 发送者ID(QQ号)
message_chain: mirai.models.message.MessageChain 消息链
"""
GroupMessageReceived = "group_message_received"
"""收到群聊消息时,在判断是否应该响应前触发(所有群消息)
kwargs:
launcher_type: str 发起对象类型(group/person)
launcher_id: int 发起对象ID(群号/QQ号)
sender_id: int 发送者ID(QQ号)
message_chain: mirai.models.message.MessageChain 消息链
"""
PersonNormalMessageReceived = "person_normal_message_received"
"""判断为应该处理的私聊普通消息时触发
kwargs:
launcher_type: str 发起对象类型(group/person)
launcher_id: int 发起对象ID(群号/QQ号)
sender_id: int 发送者ID(QQ号)
text_message: str 消息文本
returns (optional):
alter: str 修改后的消息文本
reply: list 回复消息组件列表
"""
PersonCommandSent = "person_command_sent"
"""判断为应该处理的私聊命令时触发
kwargs:
launcher_type: str 发起对象类型(group/person)
launcher_id: int 发起对象ID(群号/QQ号)
sender_id: int 发送者ID(QQ号)
command: str 命令
params: list[str] 参数列表
text_message: str 完整命令文本
is_admin: bool 是否为管理员
returns (optional):
alter: str 修改后的完整命令文本
reply: list 回复消息组件列表
"""
GroupNormalMessageReceived = "group_normal_message_received"
"""判断为应该处理的群聊普通消息时触发
kwargs:
launcher_type: str 发起对象类型(group/person)
launcher_id: int 发起对象ID(群号/QQ号)
sender_id: int 发送者ID(QQ号)
text_message: str 消息文本
returns (optional):
alter: str 修改后的消息文本
reply: list 回复消息组件列表
"""
GroupCommandSent = "group_command_sent"
"""判断为应该处理的群聊命令时触发
kwargs:
launcher_type: str 发起对象类型(group/person)
launcher_id: int 发起对象ID(群号/QQ号)
sender_id: int 发送者ID(QQ号)
command: str 命令
params: list[str] 参数列表
text_message: str 完整命令文本
is_admin: bool 是否为管理员
returns (optional):
alter: str 修改后的完整命令文本
reply: list 回复消息组件列表
"""
NormalMessageResponded = "normal_message_responded"
"""获取到对普通消息的文字响应时触发
kwargs:
launcher_type: str 发起对象类型(group/person)
launcher_id: int 发起对象ID(群号/QQ号)
sender_id: int 发送者ID(QQ号)
session: pkg.openai.session.Session 会话对象
prefix: str 回复文字消息的前缀
response_text: str 响应文本
finish_reason: str 响应结束原因
funcs_called: list[str] 此次响应中调用的函数列表
returns (optional):
prefix: str 修改后的回复文字消息的前缀
reply: list 替换回复消息组件列表
"""
SessionFirstMessageReceived = "session_first_message_received"
"""会话被第一次交互时触发
kwargs:
session_name: str 会话名称(<launcher_type>_<launcher_id>)
session: pkg.openai.session.Session 会话对象
default_prompt: str 预设值
"""
SessionExplicitReset = "session_reset"
"""会话被用户手动重置时触发,此事件不支持阻止默认行为
kwargs:
session_name: str 会话名称(<launcher_type>_<launcher_id>)
session: pkg.openai.session.Session 会话对象
"""
SessionExpired = "session_expired"
"""会话过期时触发
kwargs:
session_name: str 会话名称(<launcher_type>_<launcher_id>)
session: pkg.openai.session.Session 会话对象
session_expire_time: int 已设置的会话过期时间()
"""
KeyExceeded = "key_exceeded"
"""api-key超额时触发
kwargs:
key_name: str 超额的api-key名称
usage: dict 超额的api-key使用情况
exceeded_keys: list[str] 超额的api-key列表
"""
KeySwitched = "key_switched"
"""api-key超额切换成功时触发此事件不支持阻止默认行为
kwargs:
key_name: str 切换成功的api-key名称
key_list: list[str] api-key列表
"""
PromptPreProcessing = "prompt_pre_processing"
"""每回合调用接口前对prompt进行预处理时触发此事件不支持阻止默认行为
kwargs:
session_name: str 会话名称(<launcher_type>_<launcher_id>)
default_prompt: list 此session使用的情景预设内容
prompt: list 此session现有的prompt内容
text_message: str 用户发送的消息文本
returns (optional):
default_prompt: list 修改后的情景预设内容
prompt: list 修改后的prompt内容
text_message: str 修改后的消息文本
"""
def on(*args, **kwargs):
"""注册事件监听器
"""
return Plugin.on(*args, **kwargs)
def func(*args, **kwargs):
"""注册内容函数声明此函数为一个内容函数在对话中将发送此函数给GPT以供其调用
此函数可以具有任意的参数但必须按照[此文档](https://github.com/RockChinQ/CallingGPT/wiki/1.-Function-Format#function-format)
所述的格式编写函数的docstring
此功能仅支持在使用gpt-3.5或gpt-4系列模型时使用
"""
return Plugin.func(*args, **kwargs)
__current_registering_plugin__ = ""
def require_ver(ge: str, le: str="v999.9.9") -> bool:
"""插件版本要求装饰器
Args:
ge (str): 最低版本要求
le (str, optional): 最高版本要求
Returns:
bool: 是否满足要求, False时为无法获取版本号True时为满足要求报错为不满足要求
"""
qchatgpt_version = ""
from pkg.utils.updater import get_current_tag, compare_version_str
try:
qchatgpt_version = get_current_tag() # 从updater模块获取版本号
except:
return False
if compare_version_str(qchatgpt_version, ge) < 0 or \
(compare_version_str(qchatgpt_version, le) > 0):
raise Exception("QChatGPT 版本不满足要求,某些功能(可能是由插件提供的)无法正常使用。(要求版本:{}-{},但当前版本:{}".format(ge, le, qchatgpt_version))
return True
class Plugin:
"""插件基类"""
host: host.PluginHost
"""插件宿主,提供插件的一些基础功能"""
@classmethod
def on(cls, event):
"""事件处理器装饰器
:param
event: 事件类型
:return:
None
"""
global __current_registering_plugin__
def wrapper(func):
plugin_hooks = host.__plugins__[__current_registering_plugin__]["hooks"]
if event not in plugin_hooks:
plugin_hooks[event] = []
plugin_hooks[event].append(func)
# print("registering hook: p='{}', e='{}', f={}".format(__current_registering_plugin__, event, func))
host.__plugins__[__current_registering_plugin__]["hooks"] = plugin_hooks
return func
return wrapper
@classmethod
def func(cls, name: str=None):
"""内容函数装饰器
"""
global __current_registering_plugin__
from CallingGPT.entities.namespace import get_func_schema
def wrapper(func):
function_schema = get_func_schema(func)
function_schema['name'] = __current_registering_plugin__ + '-' + (func.__name__ if name is None else name)
function_schema['enabled'] = True
host.__function_inst_map__[function_schema['name']] = function_schema['function']
del function_schema['function']
# logging.debug("registering content function: p='{}', f='{}', s={}".format(__current_registering_plugin__, func, function_schema))
host.__callable_functions__.append(
function_schema
)
return func
return wrapper
def register(name: str, description: str, version: str, author: str):
"""注册插件, 此函数作为装饰器使用
Args:
name (str): 插件名称
description (str): 插件描述
version (str): 插件版本
author (str): 插件作者
Returns:
None
"""
global __current_registering_plugin__
__current_registering_plugin__ = name
# print("registering plugin: n='{}', d='{}', v={}, a='{}'".format(name, description, version, author))
host.__plugins__[name] = {
"name": name,
"description": description,
"version": version,
"author": author,
"hooks": {},
"path": host.__current_module_path__,
"enabled": True,
"instance": None,
}
def wrapper(cls: Plugin):
cls.name = name
cls.description = description
cls.version = version
cls.author = author
cls.host = context.get_plugin_host()
cls.enabled = True
cls.path = host.__current_module_path__
# 存到插件列表
host.__plugins__[name]["class"] = cls
logging.info("插件注册完成: n='{}', d='{}', v={}, a='{}' ({})".format(name, description, version, author, cls))
return cls
return wrapper

83
pkg/plugin/setting.py Normal file
View File

@ -0,0 +1,83 @@
from __future__ import annotations
from ..core import app
from ..config import manager as cfg_mgr
from . import context
class SettingManager:
ap: app.Application
settings: cfg_mgr.ConfigManager
def __init__(self, ap: app.Application):
self.ap = ap
async def initialize(self):
self.settings = await cfg_mgr.load_json_config(
'plugins/plugins.json',
'res/templates/plugin-setting-template.json'
)
async def sync_setting(
self,
plugin_containers: list[context.RuntimeContainer],
):
"""同步设置
"""
not_matched_source_record = []
for value in self.settings.data['plugins']:
if 'name' not in value: # 只有远程地址的应用到pkg_path相同的插件容器上
matched = False
for plugin_container in plugin_containers:
if plugin_container.pkg_path == value['pkg_path']:
matched = True
plugin_container.plugin_source = value['source']
break
if not matched:
not_matched_source_record.append(value)
else: # 正常的插件设置
for plugin_container in plugin_containers:
if plugin_container.plugin_name == value['name']:
plugin_container.set_from_setting_dict(value)
self.settings.data = {
'plugins': [
p.to_setting_dict()
for p in plugin_containers
]
}
self.settings.data['plugins'].extend(not_matched_source_record)
await self.settings.dump_config()
async def record_installed_plugin_source(
self,
pkg_path: str,
source: str
):
found = False
for value in self.settings.data['plugins']:
if value['pkg_path'] == pkg_path:
value['source'] = source
found = True
break
if not found:
self.settings.data['plugins'].append(
{
'pkg_path': pkg_path,
'source': source
}
)
await self.settings.dump_config()

View File

@ -1,103 +0,0 @@
import json
import os
import logging
from ..plugin import host
def wrapper_dict_from_runtime_context() -> dict:
"""从变量中包装settings.json的数据字典"""
settings = {
"order": [],
"functions": {
"enabled": host.__enable_content_functions__
}
}
for plugin_name in host.__plugins_order__:
settings["order"].append(plugin_name)
return settings
def apply_settings(settings: dict):
"""将settings.json数据应用到变量中"""
if "order" in settings:
host.__plugins_order__ = settings["order"]
if "functions" in settings:
if "enabled" in settings["functions"]:
host.__enable_content_functions__ = settings["functions"]["enabled"]
# logging.debug("set content function enabled: {}".format(host.__enable_content_functions__))
def dump_settings():
"""保存settings.json数据"""
logging.debug("保存plugins/settings.json数据")
settings = wrapper_dict_from_runtime_context()
with open("plugins/settings.json", "w", encoding="utf-8") as f:
json.dump(settings, f, indent=4, ensure_ascii=False)
def load_settings():
"""加载settings.json数据"""
logging.debug("加载plugins/settings.json数据")
# 读取plugins/settings.json
settings = {
}
# 检查文件是否存在
if not os.path.exists("plugins/settings.json"):
# 不存在则创建
with open("plugins/settings.json", "w", encoding="utf-8") as f:
json.dump(wrapper_dict_from_runtime_context(), f, indent=4, ensure_ascii=False)
with open("plugins/settings.json", "r", encoding="utf-8") as f:
settings = json.load(f)
if settings is None:
settings = {
}
# 检查每个设置项
if "order" not in settings:
settings["order"] = []
settings_modified = False
settings_copy = settings.copy()
# 检查settings中多余的插件项
# order
for plugin_name in settings_copy["order"]:
if plugin_name not in host.__plugins_order__:
settings["order"].remove(plugin_name)
settings_modified = True
# 检查settings中缺少的插件项
# order
for plugin_name in host.__plugins_order__:
if plugin_name not in settings_copy["order"]:
settings["order"].append(plugin_name)
settings_modified = True
if "functions" not in settings:
settings["functions"] = {
"enabled": host.__enable_content_functions__
}
settings_modified = True
elif "enabled" not in settings["functions"]:
settings["functions"]["enabled"] = host.__enable_content_functions__
settings_modified = True
logging.info("已全局{}内容函数。".format("启用" if settings["functions"]["enabled"] else "禁用"))
apply_settings(settings)
if settings_modified:
dump_settings()

View File

@ -1,94 +0,0 @@
# 控制插件的开关
import json
import logging
import os
from ..plugin import host
def wrapper_dict_from_plugin_list() -> dict:
"""将插件列表转换为开关json"""
switch = {}
for plugin_name in host.__plugins__:
plugin = host.__plugins__[plugin_name]
switch[plugin_name] = {
"path": plugin["path"],
"enabled": plugin["enabled"],
}
return switch
def apply_switch(switch: dict):
"""将开关数据应用到插件列表中"""
# print("将开关数据应用到插件列表中")
# print(switch)
for plugin_name in switch:
host.__plugins__[plugin_name]["enabled"] = switch[plugin_name]["enabled"]
# 查找此插件的所有内容函数
for func in host.__callable_functions__:
if func['name'].startswith(plugin_name + '-'):
func['enabled'] = switch[plugin_name]["enabled"]
def dump_switch():
"""保存开关数据"""
logging.debug("保存开关数据")
# 将开关数据写入plugins/switch.json
switch = wrapper_dict_from_plugin_list()
with open("plugins/switch.json", "w", encoding="utf-8") as f:
json.dump(switch, f, indent=4, ensure_ascii=False)
def load_switch():
"""加载开关数据"""
logging.debug("加载开关数据")
# 读取plugins/switch.json
switch = {}
# 检查文件是否存在
if not os.path.exists("plugins/switch.json"):
# 不存在则创建
with open("plugins/switch.json", "w", encoding="utf-8") as f:
json.dump(switch, f, indent=4, ensure_ascii=False)
with open("plugins/switch.json", "r", encoding="utf-8") as f:
switch = json.load(f)
if switch is None:
switch = {}
switch_modified = False
switch_copy = switch.copy()
# 检查switch中多余的和path不相符的
for plugin_name in switch_copy:
if plugin_name not in host.__plugins__:
del switch[plugin_name]
switch_modified = True
elif switch[plugin_name]["path"] != host.__plugins__[plugin_name]["path"]:
# 删除此不相符的
del switch[plugin_name]
switch_modified = True
# 检查plugin中多余的
for plugin_name in host.__plugins__:
if plugin_name not in switch:
switch[plugin_name] = {
"path": host.__plugins__[plugin_name]["path"],
"enabled": host.__plugins__[plugin_name]["enabled"],
}
switch_modified = True
# 应用开关数据
apply_switch(switch)
# 如果switch有修改保存
if switch_modified:
dump_switch()

View File

@ -1,232 +0,0 @@
import json
import logging
import openai
from openai.types.chat import chat_completion_message
from .model import RequestBase
from .. import funcmgr
from ...plugin import host
from ...utils import context
class ChatCompletionRequest(RequestBase):
"""调用ChatCompletion接口的请求类。
此类保证每一次返回的角色为assistant的信息的finish_reason一定为stop
若有函数调用响应本类的返回瀑布是函数调用请求->函数调用结果->...->assistant的信息->stop
"""
model: str
messages: list[dict[str, str]]
kwargs: dict
stopped: bool = False
pending_func_call: chat_completion_message.FunctionCall = None
pending_msg: str
def flush_pending_msg(self):
self.append_message(
role="assistant",
content=self.pending_msg
)
self.pending_msg = ""
def append_message(self, role: str, content: str, name: str=None, function_call: dict=None):
msg = {
"role": role,
"content": content
}
if name is not None:
msg['name'] = name
if function_call is not None:
msg['function_call'] = function_call
self.messages.append(msg)
def __init__(
self,
client: openai.Client,
model: str,
messages: list[dict[str, str]],
**kwargs
):
self.client = client
self.model = model
self.messages = messages.copy()
self.kwargs = kwargs
self.req_func = self.client.chat.completions.create
self.pending_func_call = None
self.stopped = False
self.pending_msg = ""
def __iter__(self):
return self
def __next__(self) -> dict:
if self.stopped:
raise StopIteration()
if self.pending_func_call is None: # 没有待处理的函数调用请求
args = {
"model": self.model,
"messages": self.messages,
}
funcs = funcmgr.get_func_schema_list()
if len(funcs) > 0:
args['functions'] = funcs
# 拼接kwargs
args = {**args, **self.kwargs}
from openai.types.chat import chat_completion
resp: chat_completion.ChatCompletion = self._req(**args)
choice0 = resp.choices[0]
# 如果不是函数调用且finish_reason为stop则停止迭代
if choice0.finish_reason == 'stop': # and choice0["finish_reason"] == "stop"
self.stopped = True
if hasattr(choice0.message, 'function_call') and choice0.message.function_call is not None:
self.pending_func_call = choice0.message.function_call
self.append_message(
role="assistant",
content=choice0.message.content,
function_call=choice0.message.function_call
)
return {
"id": resp.id,
"choices": [
{
"index": choice0.index,
"message": {
"role": "assistant",
"type": "function_call",
"content": choice0.message.content,
"function_call": {
"name": choice0.message.function_call.name,
"arguments": choice0.message.function_call.arguments
}
},
"finish_reason": "function_call"
}
],
"usage": {
"prompt_tokens": resp.usage.prompt_tokens,
"completion_tokens": resp.usage.completion_tokens,
"total_tokens": resp.usage.total_tokens
}
}
else:
# self.pending_msg += choice0['message']['content']
# 普通回复一定处于最后方故不用再追加进内部messages
return {
"id": resp.id,
"choices": [
{
"index": choice0.index,
"message": {
"role": "assistant",
"type": "text",
"content": choice0.message.content
},
"finish_reason": choice0.finish_reason
}
],
"usage": {
"prompt_tokens": resp.usage.prompt_tokens,
"completion_tokens": resp.usage.completion_tokens,
"total_tokens": resp.usage.total_tokens
}
}
else: # 处理函数调用请求
cp_pending_func_call = self.pending_func_call.copy()
self.pending_func_call = None
func_name = cp_pending_func_call.name
arguments = {}
try:
try:
arguments = json.loads(cp_pending_func_call.arguments)
# 若不是json格式的异常处理
except json.decoder.JSONDecodeError:
# 获取函数的参数列表
func_schema = funcmgr.get_func_schema(func_name)
arguments = {
func_schema['parameters']['required'][0]: cp_pending_func_call.arguments
}
logging.info("执行函数调用: name={}, arguments={}".format(func_name, arguments))
# 执行函数调用
ret = ""
try:
ret = funcmgr.execute_function(func_name, arguments)
logging.info("函数执行完成。")
except Exception as e:
ret = "error: execute function failed: {}".format(str(e))
logging.error("函数执行失败: {}".format(str(e)))
# 上报数据
plugin_info = host.get_plugin_info_for_audit(func_name.split('-')[0])
audit_func_name = func_name.split('-')[1]
audit_func_desc = funcmgr.get_func_schema(func_name)['description']
context.get_center_v2_api().usage.post_function_record(
plugin=plugin_info,
function_name=audit_func_name,
function_description=audit_func_desc,
)
self.append_message(
role="function",
content=json.dumps(ret, ensure_ascii=False),
name=func_name
)
return {
"id": -1,
"choices": [
{
"index": -1,
"message": {
"role": "function",
"type": "function_return",
"function_name": func_name,
"content": json.dumps(ret, ensure_ascii=False)
},
"finish_reason": "function_return"
}
],
"usage": {
"prompt_tokens": 0,
"completion_tokens": 0,
"total_tokens": 0
}
}
except funcmgr.ContentFunctionNotFoundError:
raise Exception("没有找到函数: {}".format(func_name))

View File

@ -1,100 +0,0 @@
import openai
from openai.types import completion, completion_choice
from . import model
class CompletionRequest(model.RequestBase):
"""调用Completion接口的请求类。
调用方可以一直next completion直到finish_reason为stop
"""
model: str
prompt: str
kwargs: dict
stopped: bool = False
def __init__(
self,
client: openai.Client,
model: str,
messages: list[dict[str, str]],
**kwargs
):
self.client = client
self.model = model
self.prompt = ""
for message in messages:
self.prompt += message["role"] + ": " + message["content"] + "\n"
self.prompt += "assistant: "
self.kwargs = kwargs
self.req_func = self.client.completions.create
def __iter__(self):
return self
def __next__(self) -> dict:
"""调用Completion接口返回生成的文本
{
"id": "id",
"choices": [
{
"index": 0,
"message": {
"role": "assistant",
"type": "text",
"content": "message"
},
"finish_reason": "reason"
}
],
"usage": {
"prompt_tokens": 10,
"completion_tokens": 20,
"total_tokens": 30
}
}
"""
if self.stopped:
raise StopIteration()
resp: completion.Completion = self._req(
model=self.model,
prompt=self.prompt,
**self.kwargs
)
if resp.choices[0].finish_reason == "stop":
self.stopped = True
choice0: completion_choice.CompletionChoice = resp.choices[0]
self.prompt += choice0.text
return {
"id": resp.id,
"choices": [
{
"index": choice0.index,
"message": {
"role": "assistant",
"type": "text",
"content": choice0.text
},
"finish_reason": choice0.finish_reason
}
],
"usage": {
"prompt_tokens": resp.usage.prompt_tokens,
"completion_tokens": resp.usage.completion_tokens,
"total_tokens": resp.usage.total_tokens
}
}

View File

@ -1,40 +0,0 @@
# 定义不同接口请求的模型
import logging
import openai
from ...utils import context
class RequestBase:
client: openai.Client
req_func: callable
def __init__(self, *args, **kwargs):
raise NotImplementedError
def _next_key(self):
switched, name = context.get_openai_manager().key_mgr.auto_switch()
logging.debug("切换api-key: switched={}, name={}".format(switched, name))
self.client.api_key = context.get_openai_manager().key_mgr.get_using_key()
def _req(self, **kwargs):
"""处理代理问题"""
logging.debug("请求接口参数: %s", str(kwargs))
config = context.get_config_manager().data
ret = self.req_func(**kwargs)
logging.debug("接口请求返回:%s", str(ret))
if config['switch_strategy'] == 'active':
self._next_key()
return ret
def __iter__(self):
raise self
def __next__(self):
raise NotImplementedError

View File

@ -6,7 +6,6 @@ import typing
from ...core import app from ...core import app
from ...core import entities as core_entities from ...core import entities as core_entities
from .. import entities as llm_entities from .. import entities as llm_entities
from ..session import entities as session_entities
class LLMAPIRequester(metaclass=abc.ABCMeta): class LLMAPIRequester(metaclass=abc.ABCMeta):
"""LLM API请求器 """LLM API请求器
@ -24,7 +23,7 @@ class LLMAPIRequester(metaclass=abc.ABCMeta):
async def request( async def request(
self, self,
query: core_entities.Query, query: core_entities.Query,
conversation: session_entities.Conversation, conversation: core_entities.Conversation,
) -> typing.AsyncGenerator[llm_entities.Message, None]: ) -> typing.AsyncGenerator[llm_entities.Message, None]:
"""请求 """请求
""" """

View File

@ -10,7 +10,6 @@ import openai.types.chat.chat_completion as chat_completion
from .. import api from .. import api
from ....core import entities as core_entities from ....core import entities as core_entities
from ... import entities as llm_entities from ... import entities as llm_entities
from ...session import entities as session_entities
class OpenAIChatCompletion(api.LLMAPIRequester): class OpenAIChatCompletion(api.LLMAPIRequester):
@ -43,41 +42,18 @@ class OpenAIChatCompletion(api.LLMAPIRequester):
async def _closure( async def _closure(
self, self,
req_messages: list[dict], req_messages: list[dict],
conversation: session_entities.Conversation, conversation: core_entities.Conversation,
user_text: str = None,
function_ret: str = None,
) -> llm_entities.Message: ) -> llm_entities.Message:
self.client.api_key = conversation.use_model.token_mgr.get_token() self.client.api_key = conversation.use_model.token_mgr.get_token()
args = self.ap.cfg_mgr.data["completion_api_params"].copy() args = self.ap.cfg_mgr.data["completion_api_params"].copy()
args["model"] = conversation.use_model.name args["model"] = conversation.use_model.name
tools = await self.ap.tool_mgr.generate_tools_for_openai(conversation) if conversation.use_model.tool_call_supported:
# tools = [ tools = await self.ap.tool_mgr.generate_tools_for_openai(conversation)
# {
# "type": "function", if tools:
# "function": { args["tools"] = tools
# "name": "get_current_weather",
# "description": "Get the current weather in a given location",
# "parameters": {
# "type": "object",
# "properties": {
# "location": {
# "type": "string",
# "description": "The city and state, e.g. San Francisco, CA",
# },
# "unit": {
# "type": "string",
# "enum": ["celsius", "fahrenheit"],
# },
# },
# "required": ["location"],
# },
# },
# }
# ]
if tools:
args["tools"] = tools
# 设置此次请求中的messages # 设置此次请求中的messages
messages = req_messages messages = req_messages
@ -92,7 +68,7 @@ class OpenAIChatCompletion(api.LLMAPIRequester):
return message return message
async def request( async def request(
self, query: core_entities.Query, conversation: session_entities.Conversation self, query: core_entities.Query, conversation: core_entities.Conversation
) -> typing.AsyncGenerator[llm_entities.Message, None]: ) -> typing.AsyncGenerator[llm_entities.Message, None]:
"""请求""" """请求"""

View File

@ -1,9 +1,11 @@
from __future__ import annotations
import typing import typing
import pydantic import pydantic
from . import api from . import api
from . import token from . import token, tokenizer
class LLMModelInfo(pydantic.BaseModel): class LLMModelInfo(pydantic.BaseModel):
@ -17,7 +19,9 @@ class LLMModelInfo(pydantic.BaseModel):
requester: api.LLMAPIRequester requester: api.LLMAPIRequester
function_call_supported: typing.Optional[bool] = False tokenizer: 'tokenizer.LLMTokenizer'
tool_call_supported: typing.Optional[bool] = False
class Config: class Config:
arbitrary_types_allowed = True arbitrary_types_allowed = True

View File

@ -5,6 +5,7 @@ from ...core import app
from .apis import chatcmpl from .apis import chatcmpl
from . import token from . import token
from .tokenizers import tiktoken
class ModelManager: class ModelManager:
@ -17,25 +18,28 @@ class ModelManager:
self.ap = ap self.ap = ap
self.model_list = [] self.model_list = []
async def get_model_by_name(self, name: str) -> entities.LLMModelInfo:
"""通过名称获取模型
"""
for model in self.model_list:
if model.name == name:
return model
raise ValueError(f"Model {name} not found")
async def initialize(self): async def initialize(self):
openai_chat_completion = chatcmpl.OpenAIChatCompletion(self.ap) openai_chat_completion = chatcmpl.OpenAIChatCompletion(self.ap)
await openai_chat_completion.initialize() await openai_chat_completion.initialize()
openai_token_mgr = token.TokenManager(self.ap, list(self.ap.cfg_mgr.data['openai_config']['api_key'].values())) openai_token_mgr = token.TokenManager(self.ap, list(self.ap.cfg_mgr.data['openai_config']['api_key'].values()))
tiktoken_tokenizer = tiktoken.Tiktoken(self.ap)
self.model_list.append( self.model_list.append(
entities.LLMModelInfo( entities.LLMModelInfo(
name="gpt-3.5-turbo", name="gpt-3.5-turbo",
provider="openai", provider="openai",
token_mgr=openai_token_mgr, token_mgr=openai_token_mgr,
requester=openai_chat_completion, requester=openai_chat_completion,
function_call_supported=True tool_call_supported=True,
tokenizer=tiktoken_tokenizer
) )
) )
async def get_model_by_name(self, name: str) -> entities.LLMModelInfo:
"""通过名称获取模型
"""
for model in self.model_list:
if model.name == name:
return model
raise ValueError(f"Model {name} not found")

View File

@ -0,0 +1,29 @@
from __future__ import annotations
import abc
import typing
from ...core import app
from .. import entities as llm_entities
from . import entities
class LLMTokenizer(metaclass=abc.ABCMeta):
ap: app.Application
def __init__(self, ap: app.Application):
self.ap = ap
async def initialize(self):
"""初始化分词器
"""
pass
@abc.abstractmethod
async def count_token(
self,
messages: list[llm_entities.Message],
model: entities.LLMModelInfo
) -> int:
pass

View File

@ -0,0 +1,28 @@
from __future__ import annotations
import tiktoken
from .. import tokenizer
from ... import entities as llm_entities
from .. import entities
class Tiktoken(tokenizer.LLMTokenizer):
async def count_token(
self,
messages: list[llm_entities.Message],
model: entities.LLMModelInfo
) -> int:
try:
encoding = tiktoken.encoding_for_model(model.name)
except KeyError:
print("Warning: model not found. Using cl100k_base encoding.")
encoding = tiktoken.get_encoding("cl100k_base")
num_tokens = 0
for message in messages:
num_tokens += len(encoding.encode(message.role))
num_tokens += len(encoding.encode(message.content))
num_tokens += 3 # every reply is primed with <|start|>assistant<|message|>
return num_tokens

View File

@ -1,53 +0,0 @@
from __future__ import annotations
import datetime
import asyncio
import typing
import pydantic
from ..sysprompt import entities as sysprompt_entities
from .. import entities as llm_entities
from ..requester import entities
from ...core import entities as core_entities
from ..tools import entities as tools_entities
class Conversation(pydantic.BaseModel):
"""对话"""
prompt: sysprompt_entities.Prompt
messages: list[llm_entities.Message]
create_time: typing.Optional[datetime.datetime] = pydantic.Field(default_factory=datetime.datetime.now)
update_time: typing.Optional[datetime.datetime] = pydantic.Field(default_factory=datetime.datetime.now)
use_model: entities.LLMModelInfo
use_funcs: typing.Optional[list[tools_entities.LLMFunction]]
class Session(pydantic.BaseModel):
"""会话"""
launcher_type: core_entities.LauncherTypes
launcher_id: int
sender_id: typing.Optional[int] = 0
use_prompt_name: typing.Optional[str] = 'default'
using_conversation: typing.Optional[Conversation] = None
conversations: typing.Optional[list[Conversation]] = []
create_time: typing.Optional[datetime.datetime] = pydantic.Field(default_factory=datetime.datetime.now)
update_time: typing.Optional[datetime.datetime] = pydantic.Field(default_factory=datetime.datetime.now)
semaphore: typing.Optional[asyncio.Semaphore] = None
class Config:
arbitrary_types_allowed = True

View File

@ -3,14 +3,13 @@ from __future__ import annotations
import asyncio import asyncio
from ...core import app, entities as core_entities from ...core import app, entities as core_entities
from . import entities
class SessionManager: class SessionManager:
ap: app.Application ap: app.Application
session_list: list[entities.Session] session_list: list[core_entities.Session]
def __init__(self, ap: app.Application): def __init__(self, ap: app.Application):
self.ap = ap self.ap = ap
@ -19,14 +18,14 @@ class SessionManager:
async def initialize(self): async def initialize(self):
pass pass
async def get_session(self, query: core_entities.Query) -> entities.Session: async def get_session(self, query: core_entities.Query) -> core_entities.Session:
"""获取会话 """获取会话
""" """
for session in self.session_list: for session in self.session_list:
if query.launcher_type == session.launcher_type and query.launcher_id == session.launcher_id: if query.launcher_type == session.launcher_type and query.launcher_id == session.launcher_id:
return session return session
session = entities.Session( session = core_entities.Session(
launcher_type=query.launcher_type, launcher_type=query.launcher_type,
launcher_id=query.launcher_id, launcher_id=query.launcher_id,
semaphore=asyncio.Semaphore(1) if self.ap.cfg_mgr.data['wait_last_done'] else asyncio.Semaphore(10000), semaphore=asyncio.Semaphore(1) if self.ap.cfg_mgr.data['wait_last_done'] else asyncio.Semaphore(10000),
@ -34,12 +33,12 @@ class SessionManager:
self.session_list.append(session) self.session_list.append(session)
return session return session
async def get_conversation(self, session: entities.Session) -> entities.Conversation: async def get_conversation(self, session: core_entities.Session) -> core_entities.Conversation:
if not session.conversations: if not session.conversations:
session.conversations = [] session.conversations = []
if session.using_conversation is None: if session.using_conversation is None:
conversation = entities.Conversation( conversation = core_entities.Conversation(
prompt=await self.ap.prompt_mgr.get_prompt(session.use_prompt_name), prompt=await self.ap.prompt_mgr.get_prompt(session.use_prompt_name),
messages=[], messages=[],
use_model=await self.ap.model_mgr.get_model_by_name(self.ap.cfg_mgr.data['completion_api_params']['model']), use_model=await self.ap.model_mgr.get_model_by_name(self.ap.cfg_mgr.data['completion_api_params']['model']),

View File

@ -6,6 +6,8 @@ import asyncio
import pydantic import pydantic
from ...core import entities as core_entities
class LLMFunction(pydantic.BaseModel): class LLMFunction(pydantic.BaseModel):
"""函数""" """函数"""

View File

@ -4,7 +4,6 @@ import typing
from ...core import app, entities as core_entities from ...core import app, entities as core_entities
from . import entities from . import entities
from ..session import entities as session_entities
class ToolManager: class ToolManager:
@ -12,8 +11,6 @@ class ToolManager:
""" """
ap: app.Application ap: app.Application
all_functions: list[entities.LLMFunction]
def __init__(self, ap: app.Application): def __init__(self, ap: app.Application):
self.ap = ap self.ap = ap
@ -22,30 +19,10 @@ class ToolManager:
async def initialize(self): async def initialize(self):
pass pass
def register_legacy_function(self, name: str, description: str, parameters: dict, func: callable):
"""注册函数
"""
async def wrapper(query, **kwargs):
return func(**kwargs)
function = entities.LLMFunction(
name=name,
description=description,
human_desc='',
enable=True,
parameters=parameters,
func=wrapper
)
self.all_functions.append(function)
async def register_function(self, function: entities.LLMFunction):
"""添加函数
"""
self.all_functions.append(function)
async def get_function(self, name: str) -> entities.LLMFunction: async def get_function(self, name: str) -> entities.LLMFunction:
"""获取函数 """获取函数
""" """
for function in self.all_functions: for function in await self.get_all_functions():
if function.name == name: if function.name == name:
return function return function
return None return None
@ -53,9 +30,14 @@ class ToolManager:
async def get_all_functions(self) -> list[entities.LLMFunction]: async def get_all_functions(self) -> list[entities.LLMFunction]:
"""获取所有函数 """获取所有函数
""" """
return self.all_functions all_functions: list[entities.LLMFunction] = []
async def generate_tools_for_openai(self, conversation: session_entities.Conversation) -> str: for plugin in self.ap.plugin_mgr.plugins:
all_functions.extend(plugin.content_functions)
return all_functions
async def generate_tools_for_openai(self, conversation: core_entities.Conversation) -> str:
"""生成函数列表 """生成函数列表
""" """
tools = [] tools = []

View File

@ -1,17 +1,20 @@
from __future__ import annotations from __future__ import annotations
from .. import apigroup from .. import apigroup
from ... import context from ....core import app
class V2MainDataAPI(apigroup.APIGroup): class V2MainDataAPI(apigroup.APIGroup):
"""主程序相关 数据API""" """主程序相关 数据API"""
def __init__(self, prefix: str): ap: app.Application
super().__init__(prefix+"/main")
def __init__(self, prefix: str, ap: app.Application):
self.ap = ap
super().__init__(prefix+"/usage")
def do(self, *args, **kwargs): def do(self, *args, **kwargs):
config = context.get_config_manager().data config = self.ap.cfg_mgr.data
if not config['report_usage']: if not config['report_usage']:
return None return None
return super().do(*args, **kwargs) return super().do(*args, **kwargs)

View File

@ -1,17 +1,20 @@
from __future__ import annotations from __future__ import annotations
from ....core import app
from .. import apigroup from .. import apigroup
from ... import context
class V2PluginDataAPI(apigroup.APIGroup): class V2PluginDataAPI(apigroup.APIGroup):
"""插件数据相关 API""" """插件数据相关 API"""
def __init__(self, prefix: str): ap: app.Application
super().__init__(prefix+"/plugin")
def __init__(self, prefix: str, ap: app.Application):
self.ap = ap
super().__init__(prefix+"/usage")
def do(self, *args, **kwargs): def do(self, *args, **kwargs):
config = context.get_config_manager().data config = self.ap.cfg_mgr.data
if not config['report_usage']: if not config['report_usage']:
return None return None
return super().do(*args, **kwargs) return super().do(*args, **kwargs)

View File

@ -1,17 +1,20 @@
from __future__ import annotations from __future__ import annotations
from .. import apigroup from .. import apigroup
from ... import context from ....core import app
class V2UsageDataAPI(apigroup.APIGroup): class V2UsageDataAPI(apigroup.APIGroup):
"""使用量数据相关 API""" """使用量数据相关 API"""
def __init__(self, prefix: str): ap: app.Application
def __init__(self, prefix: str, ap: app.Application):
self.ap = ap
super().__init__(prefix+"/usage") super().__init__(prefix+"/usage")
def do(self, *args, **kwargs): def do(self, *args, **kwargs):
config = context.get_config_manager().data config = self.ap.cfg_mgr.data
if not config['report_usage']: if not config['report_usage']:
return None return None
return super().do(*args, **kwargs) return super().do(*args, **kwargs)

View File

@ -6,7 +6,7 @@ from . import apigroup
from .groups import main from .groups import main
from .groups import usage from .groups import usage
from .groups import plugin from .groups import plugin
from ...utils import context from ...core import app
BACKEND_URL = "https://api.qchatgpt.rockchin.top/api/v2" BACKEND_URL = "https://api.qchatgpt.rockchin.top/api/v2"
@ -23,7 +23,7 @@ class V2CenterAPI:
plugin: plugin.V2PluginDataAPI = None plugin: plugin.V2PluginDataAPI = None
"""插件 API 组""" """插件 API 组"""
def __init__(self, basic_info: dict = None, runtime_info: dict = None): def __init__(self, ap: app.Application, basic_info: dict = None, runtime_info: dict = None):
"""初始化""" """初始化"""
logging.debug("basic_info: %s, runtime_info: %s", basic_info, runtime_info) logging.debug("basic_info: %s, runtime_info: %s", basic_info, runtime_info)
@ -31,8 +31,7 @@ class V2CenterAPI:
apigroup.APIGroup._basic_info = basic_info apigroup.APIGroup._basic_info = basic_info
apigroup.APIGroup._runtime_info = runtime_info apigroup.APIGroup._runtime_info = runtime_info
self.main = main.V2MainDataAPI(BACKEND_URL) self.main = main.V2MainDataAPI(BACKEND_URL, ap)
self.usage = usage.V2UsageDataAPI(BACKEND_URL) self.usage = usage.V2UsageDataAPI(BACKEND_URL, ap)
self.plugin = plugin.V2PluginDataAPI(BACKEND_URL) self.plugin = plugin.V2PluginDataAPI(BACKEND_URL, ap)
context.set_center_v2_api(self)

View File

@ -1,11 +0,0 @@
from . import context
def wrapper_proxies() -> dict:
"""获取代理"""
config = context.get_config_manager().data
return {
"http": config['openai_config']['proxy'],
"https": config['openai_config']['proxy']
} if 'proxy' in config['openai_config'] and (config['openai_config']['proxy'] is not None) else None

View File

@ -1,27 +1,27 @@
from pip._internal import main as pipmain from pip._internal import main as pipmain
from . import log # from . import log
def install(package): def install(package):
pipmain(['install', package]) pipmain(['install', package])
log.reset_logging() # log.reset_logging()
def install_upgrade(package): def install_upgrade(package):
pipmain(['install', '--upgrade', package, "-i", "https://pypi.tuna.tsinghua.edu.cn/simple", pipmain(['install', '--upgrade', package, "-i", "https://pypi.tuna.tsinghua.edu.cn/simple",
"--trusted-host", "pypi.tuna.tsinghua.edu.cn"]) "--trusted-host", "pypi.tuna.tsinghua.edu.cn"])
log.reset_logging() # log.reset_logging()
def run_pip(params: list): def run_pip(params: list):
pipmain(params) pipmain(params)
log.reset_logging() # log.reset_logging()
def install_requirements(file): def install_requirements(file):
pipmain(['install', '-r', file, "-i", "https://pypi.tuna.tsinghua.edu.cn/simple", pipmain(['install', '-r', file, "-i", "https://pypi.tuna.tsinghua.edu.cn/simple",
"--trusted-host", "pypi.tuna.tsinghua.edu.cn"]) "--trusted-host", "pypi.tuna.tsinghua.edu.cn"])
log.reset_logging() # log.reset_logging()
def ensure_dulwich(): def ensure_dulwich():

30
pkg/utils/proxy.py Normal file
View File

@ -0,0 +1,30 @@
from __future__ import annotations
from ..core import app
class ProxyManager:
ap: app.Application
forward_proxies: dict[str, str]
def __init__(self, ap: app.Application):
self.ap = ap
self.forward_proxies = {}
async def initialize(self):
config = self.ap.cfg_mgr.data
return (
{
"http": config["openai_config"]["proxy"],
"https": config["openai_config"]["proxy"],
}
if "proxy" in config["openai_config"]
and (config["openai_config"]["proxy"] is not None)
else None
)
def get_forward_proxies(self) -> str:
return self.forward_proxies

View File

@ -8,21 +8,6 @@ import time
import requests import requests
from . import constants from . import constants
from . import network
from . import context
def check_dulwich_closure():
try:
import pkg.utils.pkgmgr
pkg.utils.pkgmgr.ensure_dulwich()
except:
pass
try:
import dulwich
except ModuleNotFoundError:
raise Exception("dulwich模块未安装,请查看 https://github.com/RockChinQ/QChatGPT/issues/77")
def is_newer(new_tag: str, old_tag: str): def is_newer(new_tag: str, old_tag: str):
@ -47,28 +32,6 @@ def is_newer(new_tag: str, old_tag: str):
return new_tag != old_tag return new_tag != old_tag
def get_release_list() -> list:
"""获取发行列表"""
rls_list_resp = requests.get(
url="https://api.github.com/repos/RockChinQ/QChatGPT/releases",
proxies=network.wrapper_proxies()
)
rls_list = rls_list_resp.json()
return rls_list
def get_current_tag() -> str:
"""获取当前tag"""
current_tag = constants.semantic_version
if os.path.exists("current_tag"):
with open("current_tag", "r") as f:
current_tag = f.read()
return current_tag
def compare_version_str(v0: str, v1: str) -> int: def compare_version_str(v0: str, v1: str) -> int:
"""比较两个版本号""" """比较两个版本号"""
@ -209,79 +172,3 @@ def update_all(cli: bool = False) -> bool:
else: else:
print("已更新到最新版本: {}\n更新日志:\n{}\n完整的更新日志请前往 https://github.com/RockChinQ/QChatGPT/releases 查看。请手动重启程序以使用新版本。".format(current_tag, "\n".join(rls_notes[:-1]))) print("已更新到最新版本: {}\n更新日志:\n{}\n完整的更新日志请前往 https://github.com/RockChinQ/QChatGPT/releases 查看。请手动重启程序以使用新版本。".format(current_tag, "\n".join(rls_notes[:-1])))
return True return True
def is_repo(path: str) -> bool:
"""检查是否是git仓库"""
check_dulwich_closure()
from dulwich import porcelain
try:
porcelain.open_repo(path)
return True
except:
return False
def get_remote_url(repo_path: str) -> str:
"""获取远程仓库地址"""
check_dulwich_closure()
from dulwich import porcelain
repo = porcelain.open_repo(repo_path)
return str(porcelain.get_remote_repo(repo, "origin")[1])
def get_current_version_info() -> str:
"""获取当前版本信息"""
rls_list = get_release_list()
current_tag = get_current_tag()
for rls in rls_list:
if rls['tag_name'] == current_tag:
return rls['name'] + "\n" + rls['body']
return "未知版本"
def is_new_version_available() -> bool:
"""检查是否有新版本"""
# 从github获取release列表
rls_list = get_release_list()
if rls_list is None:
return False
# 获取当前版本
current_tag = get_current_tag()
# 检查是否有新版本
latest_tag_name = ""
for rls in rls_list:
if latest_tag_name == "":
latest_tag_name = rls['tag_name']
break
return is_newer(latest_tag_name, current_tag)
def get_rls_notes() -> list:
"""获取更新日志"""
# 从github获取release列表
rls_list = get_release_list()
if rls_list is None:
return None
# 获取当前版本
current_tag = get_current_tag()
# 检查是否有新版本
rls_notes = []
for rls in rls_list:
if rls['tag_name'] == current_tag:
break
rls_notes.append(rls['name'])
return rls_notes
if __name__ == "__main__":
update_all()

130
pkg/utils/version.py Normal file
View File

@ -0,0 +1,130 @@
from __future__ import annotations
import os
import requests
from ..core import app
from . import constants
class VersionManager:
ap: app.Application
def __init__(
self,
ap: app.Application
):
self.ap = ap
async def initialize(
self
):
pass
def get_current_version(
self
) -> str:
current_tag = constants.semantic_version
if os.path.exists("current_tag"):
with open("current_tag", "r") as f:
current_tag = f.read()
return current_tag
async def get_current_version_info(
self
) -> str:
"""获取当前版本信息"""
rls_list = await self.get_release_list()
current_tag = self.get_current_version()
for rls in rls_list:
if rls['tag_name'] == current_tag:
return rls['name'] + "\n" + rls['body']
return "未知版本"
async def get_release_list(self) -> list:
"""获取发行列表"""
rls_list_resp = requests.get(
url="https://api.github.com/repos/RockChinQ/QChatGPT/releases",
proxies=self.ap.proxy_mgr.get_forward_proxies()
)
rls_list = rls_list_resp.json()
return rls_list
async def update_all(self):
pass
async def is_new_version_available(self) -> bool:
"""检查是否有新版本"""
# 从github获取release列表
rls_list = await self.get_release_list()
if rls_list is None:
return False
# 获取当前版本
current_tag = self.get_current_version()
# 检查是否有新版本
latest_tag_name = ""
for rls in rls_list:
if latest_tag_name == "":
latest_tag_name = rls['tag_name']
break
return self.is_newer(latest_tag_name, current_tag)
def is_newer(self, new_tag: str, old_tag: str):
"""判断版本是否更新,忽略第四位版本和第一位版本"""
if new_tag == old_tag:
return False
new_tag = new_tag.split(".")
old_tag = old_tag.split(".")
# 判断主版本是否相同
if new_tag[0] != old_tag[0]:
return False
if len(new_tag) < 4:
return True
# 合成前三段,判断是否相同
new_tag = ".".join(new_tag[:3])
old_tag = ".".join(old_tag[:3])
return new_tag != old_tag
def compare_version_str(v0: str, v1: str) -> int:
"""比较两个版本号"""
# 删除版本号前的v
if v0.startswith("v"):
v0 = v0[1:]
if v1.startswith("v"):
v1 = v1[1:]
v0:list = v0.split(".")
v1:list = v1.split(".")
# 如果两个版本号节数不同把短的后面用0补齐
if len(v0) < len(v1):
v0.extend(["0"]*(len(v1)-len(v0)))
elif len(v0) > len(v1):
v1.extend(["0"]*(len(v0)-len(v1)))
# 从高位向低位比较
for i in range(len(v0)):
if int(v0[i]) > int(v1[i]):
return 1
elif int(v0[i]) < int(v1[i]):
return -1
return 0

View File

@ -0,0 +1,3 @@
{
"plugins": []
}