feat: 持久化和 web 接口基础架构

This commit is contained in:
RockChinQ 2024-10-11 22:27:53 +08:00
parent 21f153e5c3
commit 7c3557e943
No known key found for this signature in database
GPG Key ID: 8AC0BEFE1743A015
26 changed files with 462 additions and 22 deletions

7
.gitignore vendored
View File

@ -4,8 +4,8 @@ __pycache__/
database.db database.db
qchatgpt.log qchatgpt.log
/banlist.py /banlist.py
plugins/ /plugins/
!plugins/__init__.py !/plugins/__init__.py
/revcfg.py /revcfg.py
prompts/ prompts/
logs/ logs/
@ -34,4 +34,5 @@ bard.json
res/instance_id.json res/instance_id.json
.DS_Store .DS_Store
/data /data
botpy.log* botpy.log*
/poc

13
main.py
View File

@ -13,7 +13,10 @@ asciiart = r"""
""" """
async def main_entry(): import asyncio
async def main_entry(loop: asyncio.AbstractEventLoop):
print(asciiart) print(asciiart)
import sys import sys
@ -46,7 +49,7 @@ async def main_entry():
sys.exit(0) sys.exit(0)
from pkg.core import boot from pkg.core import boot
await boot.main() await boot.main(loop)
if __name__ == '__main__': if __name__ == '__main__':
@ -65,8 +68,8 @@ if __name__ == '__main__':
if invalid_pwd: if invalid_pwd:
print("请在QChatGPT项目根目录下以命令形式运行此程序。") print("请在QChatGPT项目根目录下以命令形式运行此程序。")
input("按任意键退出...") input("按任意键退出...")
exit(0) exit(1)
import asyncio loop = asyncio.new_event_loop()
asyncio.run(main_entry()) loop.run_until_complete(main_entry(loop))

0
pkg/api/__init__.py Normal file
View File

0
pkg/api/http/__init__.py Normal file
View File

View File

View File

@ -0,0 +1,91 @@
from __future__ import annotations
import abc
import typing
import quart
from quart.typing import RouteCallable
from ....core import app
preregistered_groups: list[type[RouterGroup]] = []
"""RouterGroup 的预注册列表"""
def group_class(name: str, path: str) -> None:
"""注册一个 RouterGroup"""
def decorator(cls: typing.Type[RouterGroup]) -> typing.Type[RouterGroup]:
cls.name = name
cls.path = path
preregistered_groups.append(cls)
return cls
return decorator
class RouterGroup(abc.ABC):
name: str
path: str
ap: app.Application
quart_app: quart.Quart
def __init__(self, ap: app.Application, quart_app: quart.Quart) -> None:
self.ap = ap
self.quart_app = quart_app
@abc.abstractmethod
async def initialize(self) -> None:
pass
def route(self, rule: str, **options: typing.Any) -> typing.Callable[[RouteCallable], RouteCallable]: # decorator
"""注册一个路由"""
def decorator(f: RouteCallable) -> RouteCallable:
nonlocal rule
rule = self.path + rule
async def handler_error(*args, **kwargs):
try:
return await f(*args, **kwargs)
except Exception as e: # 自动 500
return self.http_status(500, -2, str(e))
new_f = handler_error
new_f.__name__ = f.__name__
new_f.__doc__ = f.__doc__
self.quart_app.route(rule, **options)(new_f)
return f
return decorator
def _cors(self, response: quart.Response) -> quart.Response:
# Quart-Cors 似乎很久没维护了,所以自己写
response.headers['Access-Control-Allow-Origin'] = '*'
response.headers['Access-Control-Allow-Headers'] = '*'
response.headers['Access-Control-Allow-Methods'] = '*'
response.headers['Access-Control-Allow-Credentials'] = 'true'
return response
def success(self, data: typing.Any = None) -> quart.Response:
"""返回一个 200 响应"""
return self._cors(quart.jsonify({
'code': 0,
'msg': 'ok',
'data': data,
}))
def fail(self, code: int, msg: str) -> quart.Response:
"""返回一个异常响应"""
return self._cors(quart.jsonify({
'code': code,
'msg': msg,
}))
def http_status(self, status: int, code: int, msg: str) -> quart.Response:
"""返回一个指定状态码的响应"""
return self.fail(code, msg), status

View File

@ -0,0 +1,21 @@
from __future__ import annotations
import traceback
import quart
from .....core import app
from .. import group
@group.group_class('log', '/api/v1/log')
class LogRouterGroup(group.RouterGroup):
async def initialize(self) -> None:
@self.route('', methods=['GET'])
async def _() -> str:
return self.success(
data={
"logs": self.ap.log_cache.get_all_logs()
}
)

View File

@ -0,0 +1,48 @@
from __future__ import annotations
import asyncio
import quart
from ....core import app
from .groups import log
from . import group
class HTTPController:
ap: app.Application
quart_app: quart.Quart
def __init__(self, ap: app.Application) -> None:
self.ap = ap
self.quart_app = quart.Quart(__name__)
async def initialize(self) -> None:
await self.register_routes()
async def run(self) -> None:
if self.ap.system_cfg.data['http-api']['enable']:
async def shutdown_trigger_placeholder():
while True:
await asyncio.sleep(1)
asyncio.create_task(self.quart_app.run_task(
host=self.ap.system_cfg.data['http-api']['host'],
port=self.ap.system_cfg.data['http-api']['port'],
shutdown_trigger=shutdown_trigger_placeholder
))
async def register_routes(self) -> None:
@self.quart_app.route('/healthz')
async def healthz():
return {
"code": 0,
"msg": "ok"
}
for g in group.preregistered_groups:
ginst = g(self.ap, self.quart_app)
await ginst.initialize()

View File

View File

@ -17,11 +17,18 @@ from ..plugin import manager as plugin_mgr
from ..pipeline import pool from ..pipeline import pool
from ..pipeline import controller, stagemgr from ..pipeline import controller, stagemgr
from ..utils import version as version_mgr, proxy as proxy_mgr, announce as announce_mgr from ..utils import version as version_mgr, proxy as proxy_mgr, announce as announce_mgr
from ..persistence import mgr as persistencemgr
from ..api.http.controller import main as http_controller
from ..utils import logcache
class Application: class Application:
"""运行时应用对象和上下文""" """运行时应用对象和上下文"""
event_loop: asyncio.AbstractEventLoop = None
asyncio_tasks: list[asyncio.Task] = []
platform_mgr: im_mgr.PlatformManager = None platform_mgr: im_mgr.PlatformManager = None
cmd_mgr: cmdmgr.CommandManager = None cmd_mgr: cmdmgr.CommandManager = None
@ -78,6 +85,12 @@ class Application:
logger: logging.Logger = None logger: logging.Logger = None
persistence_mgr: persistencemgr.PersistenceManager = None
http_ctrl: http_controller.HTTPController = None
log_cache: logcache.LogCache = None
def __init__(self): def __init__(self):
pass pass
@ -91,13 +104,21 @@ class Application:
try: try:
# 后续可能会允许动态重启其他任务
# 故为了防止程序在非 Ctrl-C 情况下退出,这里创建一个不会结束的协程
async def never_ending():
while True:
await asyncio.sleep(1)
tasks = [ tasks = [
asyncio.create_task(self.platform_mgr.run()), asyncio.create_task(self.platform_mgr.run()), # 消息平台
asyncio.create_task(self.ctrl.run()) asyncio.create_task(self.ctrl.run()), # 消息处理循环
asyncio.create_task(self.http_ctrl.run()), # http 接口服务
asyncio.create_task(never_ending())
] ]
self.asyncio_tasks.extend(tasks)
# 挂信号处理 # 挂系统信号处理
import signal import signal
def signal_handler(sig, frame): def signal_handler(sig, frame):
@ -109,7 +130,6 @@ class Application:
signal.signal(signal.SIGINT, signal_handler) signal.signal(signal.SIGINT, signal_handler)
await asyncio.gather(*tasks, return_exceptions=True) await asyncio.gather(*tasks, return_exceptions=True)
except asyncio.CancelledError: except asyncio.CancelledError:
pass pass
except Exception as e: except Exception as e:

View File

@ -1,6 +1,7 @@
from __future__ import print_function from __future__ import print_function
import traceback import traceback
import asyncio
from . import app from . import app
from ..audit import identifier from ..audit import identifier
@ -19,13 +20,15 @@ stage_order = [
] ]
async def make_app() -> app.Application: async def make_app(loop: asyncio.AbstractEventLoop) -> app.Application:
# 生成标识符 # 生成标识符
identifier.init() identifier.init()
ap = app.Application() ap = app.Application()
ap.event_loop = loop
# 执行启动阶段 # 执行启动阶段
for stage_name in stage_order: for stage_name in stage_order:
stage_cls = stage.preregistered_stages[stage_name] stage_cls = stage.preregistered_stages[stage_name]
@ -38,9 +41,9 @@ async def make_app() -> app.Application:
return ap return ap
async def main(): async def main(loop: asyncio.AbstractEventLoop):
try: try:
app_inst = await make_app() app_inst = await make_app(loop)
await app_inst.run() await app_inst.run()
except Exception as e: except Exception as e:
traceback.print_exc() traceback.print_exc()

View File

@ -15,6 +15,9 @@ required_deps = {
"psutil": "psutil", "psutil": "psutil",
"async_lru": "async-lru", "async_lru": "async-lru",
"ollama": "ollama", "ollama": "ollama",
"quart": "quart",
"sqlalchemy": "sqlalchemy[asyncio]",
"aiosqlite": "aiosqlite",
} }

View File

@ -15,7 +15,7 @@ log_colors_config = {
} }
async def init_logging() -> logging.Logger: async def init_logging(extra_handlers: list[logging.Handler] = None) -> logging.Logger:
# 删除所有现有的logger # 删除所有现有的logger
for handler in logging.root.handlers[:]: for handler in logging.root.handlers[:]:
logging.root.removeHandler(handler) logging.root.removeHandler(handler)
@ -41,7 +41,8 @@ async def init_logging() -> logging.Logger:
stream_handler = logging.StreamHandler(sys.stdout) stream_handler = logging.StreamHandler(sys.stdout)
log_handlers: logging.Handler = [stream_handler, logging.FileHandler(log_file_name)] log_handlers: list[logging.Handler] = [stream_handler, logging.FileHandler(log_file_name)]
log_handlers += extra_handlers if extra_handlers is not None else []
for handler in log_handlers: for handler in log_handlers:
handler.setLevel(level) handler.setLevel(level)

View File

@ -0,0 +1,30 @@
from __future__ import annotations
from .. import migration
@migration.migration_class("http-api-config", 13)
class HttpApiConfigMigration(migration.Migration):
"""迁移"""
async def need_migrate(self) -> bool:
"""判断当前环境是否需要运行此迁移"""
return 'http-api' not in self.ap.system_cfg.data or "persistence" not in self.ap.system_cfg.data
async def run(self):
"""执行迁移"""
self.ap.system_cfg.data['http-api'] = {
"enable": True,
"host": "0.0.0.0",
"port": 5300
}
self.ap.system_cfg.data['persistence'] = {
"sqlite": {
"path": "data/persistence.db"
},
"use": "sqlite"
}
await self.ap.system_cfg.dump_config()

View File

@ -15,6 +15,10 @@ from ...provider.sysprompt import sysprompt as llm_prompt_mgr
from ...provider.tools import toolmgr as llm_tool_mgr from ...provider.tools import toolmgr as llm_tool_mgr
from ...provider import runnermgr from ...provider import runnermgr
from ...platform import manager as im_mgr from ...platform import manager as im_mgr
from ...persistence import mgr as persistencemgr
from ...api.http.controller import main as http_controller
from ...utils import logcache
@stage.stage_class("BuildAppStage") @stage.stage_class("BuildAppStage")
class BuildAppStage(stage.BootingStage): class BuildAppStage(stage.BootingStage):
@ -58,6 +62,13 @@ class BuildAppStage(stage.BootingStage):
ap.query_pool = pool.QueryPool() ap.query_pool = pool.QueryPool()
log_cache = logcache.LogCache()
ap.log_cache = log_cache
persistence_mgr_inst = persistencemgr.PersistenceManager(ap)
await persistence_mgr_inst.initialize()
ap.persistence_mgr = persistence_mgr_inst
plugin_mgr_inst = plugin_mgr.PluginManager(ap) plugin_mgr_inst = plugin_mgr.PluginManager(ap)
await plugin_mgr_inst.initialize() await plugin_mgr_inst.initialize()
ap.plugin_mgr = plugin_mgr_inst ap.plugin_mgr = plugin_mgr_inst
@ -95,6 +106,9 @@ class BuildAppStage(stage.BootingStage):
await stage_mgr.initialize() await stage_mgr.initialize()
ap.stage_mgr = stage_mgr ap.stage_mgr = stage_mgr
http_ctrl = http_controller.HTTPController(ap)
await http_ctrl.initialize()
ap.http_ctrl = http_ctrl
ctrl = controller.Controller(ap) ctrl = controller.Controller(ap)
ap.ctrl = ctrl ap.ctrl = ctrl

View File

@ -6,7 +6,7 @@ from .. import stage, app
from .. import migration from .. import migration
from ..migrations import m001_sensitive_word_migration, m002_openai_config_migration, m003_anthropic_requester_cfg_completion, m004_moonshot_cfg_completion from ..migrations import m001_sensitive_word_migration, m002_openai_config_migration, m003_anthropic_requester_cfg_completion, m004_moonshot_cfg_completion
from ..migrations import m005_deepseek_cfg_completion, m006_vision_config, m007_qcg_center_url, m008_ad_fixwin_config_migrate, m009_msg_truncator_cfg from ..migrations import m005_deepseek_cfg_completion, m006_vision_config, m007_qcg_center_url, m008_ad_fixwin_config_migrate, m009_msg_truncator_cfg
from ..migrations import m010_ollama_requester_config, m011_command_prefix_config, m012_runner_config from ..migrations import m010_ollama_requester_config, m011_command_prefix_config, m012_runner_config, m013_http_api_config
@stage.stage_class("MigrationStage") @stage.stage_class("MigrationStage")

View File

@ -1,9 +1,38 @@
from __future__ import annotations from __future__ import annotations
import logging
import asyncio
from datetime import datetime
from .. import stage, app from .. import stage, app
from ..bootutils import log from ..bootutils import log
class PersistenceHandler(logging.Handler, object):
"""
保存日志到数据库
"""
ap: app.Application
def __init__(self, name, ap: app.Application):
logging.Handler.__init__(self)
self.ap = ap
def emit(self, record):
"""
emit函数为自定义handler类时必重写的函数这里可以根据需要对日志消息做一些处理比如发送日志到服务器
发出记录(Emit a record)
"""
try:
msg = self.format(record)
if self.ap.log_cache is not None:
self.ap.log_cache.add_log(msg)
except Exception:
self.handleError(record)
@stage.stage_class("SetupLoggerStage") @stage.stage_class("SetupLoggerStage")
class SetupLoggerStage(stage.BootingStage): class SetupLoggerStage(stage.BootingStage):
"""设置日志器阶段 """设置日志器阶段
@ -12,4 +41,9 @@ class SetupLoggerStage(stage.BootingStage):
async def run(self, ap: app.Application): async def run(self, ap: app.Application):
"""启动 """启动
""" """
ap.logger = await log.init_logging() persistence_handler = PersistenceHandler('LoggerHandler', ap)
extra_handlers = []
extra_handlers = [persistence_handler]
ap.logger = await log.init_logging(extra_handlers)

View File

View File

@ -0,0 +1,40 @@
from __future__ import annotations
import abc
import sqlalchemy.ext.asyncio as sqlalchemy_asyncio
from ..core import app
preregistered_managers: list[type[BaseDatabaseManager]] = []
def manager_class(name: str) -> None:
"""注册一个数据库管理类"""
def decorator(cls: type[BaseDatabaseManager]) -> type[BaseDatabaseManager]:
cls.name = name
preregistered_managers.append(cls)
return cls
return decorator
class BaseDatabaseManager(abc.ABC):
"""基础数据库管理类"""
name: str
ap: app.Application
engine: sqlalchemy_asyncio.AsyncEngine
def __init__(self, ap: app.Application) -> None:
self.ap = ap
@abc.abstractmethod
async def initialize(self) -> None:
pass
def get_engine(self) -> sqlalchemy_asyncio.AsyncEngine:
return self.engine

View File

View File

@ -0,0 +1,13 @@
from __future__ import annotations
import sqlalchemy.ext.asyncio as sqlalchemy_asyncio
from .. import database
@database.manager_class("sqlite")
class SQLiteDatabaseManager(database.BaseDatabaseManager):
"""SQLite 数据库管理类"""
async def initialize(self) -> None:
self.engine = sqlalchemy_asyncio.create_async_engine(f"sqlite+aiosqlite:///{self.ap.system_cfg.data['persistence']['sqlite']['path']}")

55
pkg/persistence/mgr.py Normal file
View File

@ -0,0 +1,55 @@
from __future__ import annotations
import asyncio
import datetime
import sqlalchemy.ext.asyncio as sqlalchemy_asyncio
import sqlalchemy
from . import database
from ..core import app
from .databases import sqlite
class PersistenceManager:
"""持久化模块管理器"""
ap: app.Application
db: database.BaseDatabaseManager
"""数据库管理器"""
meta: sqlalchemy.MetaData
def __init__(self, ap: app.Application):
self.ap = ap
self.meta = sqlalchemy.MetaData()
async def initialize(self):
for manager in database.preregistered_managers:
self.db = manager(self.ap)
await self.db.initialize()
await self.create_tables()
async def create_tables(self):
# TODO: 对扩展友好
# 日志
async with self.get_db_engine().connect() as conn:
await conn.run_sync(self.meta.create_all)
await conn.commit()
async def execute_async(
self,
*args,
**kwargs
):
async with self.get_db_engine().connect() as conn:
await conn.execute(*args, **kwargs)
await conn.commit()
def get_db_engine(self) -> sqlalchemy_asyncio.AsyncEngine:
return self.db.get_engine()

49
pkg/utils/logcache.py Normal file
View File

@ -0,0 +1,49 @@
from __future__ import annotations
import pydantic
LOG_PAGE_SIZE = 20
MAX_CACHED_PAGES = 10
class LogPage(pydantic.BaseModel):
"""日志页"""
cached_count: int = 0
logs: str = ""
def add_log(self, log: str) -> bool:
"""添加日志
Returns:
bool: 是否已满
"""
self.logs += log
self.cached_count += 1
return self.cached_count >= LOG_PAGE_SIZE
class LogCache:
"""由于 logger 是同步的,但实例中的数据库操作是异步的;
同时持久化的日志信息已经写入文件了故做一个缓存来为前端提供日志查询服务"""
log_pages: list[LogPage] = []
"""从前到后,越新的日志页越靠后"""
def __init__(self):
self.log_pages = []
self.log_pages.append(LogPage())
def add_log(self, log: str):
"""添加日志"""
if self.log_pages[-1].add_log(log):
self.log_pages.append(LogPage())
if len(self.log_pages) > MAX_CACHED_PAGES:
self.log_pages.pop(0)
def get_all_logs(self) -> str:
"""获取所有日志"""
return "".join([page.logs for page in self.log_pages])

View File

@ -14,4 +14,7 @@ websockets
urllib3 urllib3
psutil psutil
async-lru async-lru
ollama ollama
quart
sqlalchemy[asyncio]
aiosqlite

View File

@ -11,5 +11,16 @@
}, },
"pipeline-concurrency": 20, "pipeline-concurrency": 20,
"qcg-center-url": "https://api.qchatgpt.rockchin.top/api/v2", "qcg-center-url": "https://api.qchatgpt.rockchin.top/api/v2",
"help-message": "QChatGPT - 😎高稳定性、🧩支持插件、🌏实时联网的 ChatGPT QQ 机器人🤖\n链接https://q.rkcn.top" "help-message": "QChatGPT - 😎高稳定性、🧩支持插件、🌏实时联网的 ChatGPT QQ 机器人🤖\n链接https://q.rkcn.top",
"http-api": {
"enable": true,
"host": "0.0.0.0",
"port": 5300
},
"persistence": {
"sqlite": {
"path": "data/persistence.db"
},
"use": "sqlite"
}
} }