# QChatGPT/pkg/utils/context.py

from __future__ import annotations
import threading
from . import threadctl
from ..database import manager as db_mgr
from ..openai import manager as openai_mgr
from ..qqbot import manager as qqbot_mgr
from ..plugin import host as plugin_host
# Module-level registry of shared objects. Every slot starts out as None;
# the set_*/get_* helpers in this module write and read individual slots.
context = {
    # Manager instances, keyed by a dotted "package.module.Class" string.
    'inst': {
        'database.manager.DatabaseManager': None,
        'openai.manager.OpenAIInteract': None,
        'qqbot.manager.QQBotManager': None,
    },
    'pool_ctl': None,        # slot used by set_thread_ctl / get_thread_ctl
    'logger_handler': None,  # no accessor for this slot in this section
    'config': None,          # slot used by set_config / get_config
    'plugin_host': None,     # slot used by set_plugin_host / get_plugin_host
}

# Lock taken by the accessor helpers around each access to ``context``.
context_lock = threading.Lock()
### NOTE: context is very tightly coupled and needs a major refactor ###
def set_config(inst):
    """Store ``inst`` in the ``'config'`` slot of the shared context.

    Args:
        inst: The configuration object to publish; it is stored as-is.
    """
    # ``with`` releases the lock even if the write raises; a bare
    # acquire()/release() pair would leave it held forever in that case.
    with context_lock:
        context['config'] = inst
def get_config():
    """Return the object currently stored in the ``'config'`` context slot.

    Returns:
        Whatever was last written to ``context['config']``.
    """
    # ``with`` guarantees the lock is released even if the lookup raises.
    with context_lock:
        return context['config']
def set_database_manager(inst: db_mgr.DatabaseManager):
    """Store ``inst`` as the shared database manager.

    The value is written to
    ``context['inst']['database.manager.DatabaseManager']``.

    Args:
        inst: The database manager instance to publish.
    """
    # ``with`` releases the lock even if the nested lookup/write raises.
    with context_lock:
        context['inst']['database.manager.DatabaseManager'] = inst
def get_database_manager() -> db_mgr.DatabaseManager:
    """Return the shared database manager.

    Returns:
        Whatever was last written to
        ``context['inst']['database.manager.DatabaseManager']``.
    """
    # ``with`` guarantees the lock is released even if the lookup raises.
    with context_lock:
        return context['inst']['database.manager.DatabaseManager']
def set_openai_manager(inst: openai_mgr.OpenAIInteract):
    """Store ``inst`` as the shared OpenAI manager.

    The value is written to
    ``context['inst']['openai.manager.OpenAIInteract']``.

    Args:
        inst: The OpenAI interaction manager instance to publish.
    """
    # ``with`` releases the lock even if the nested lookup/write raises.
    with context_lock:
        context['inst']['openai.manager.OpenAIInteract'] = inst
def get_openai_manager() -> openai_mgr.OpenAIInteract:
    """Return the shared OpenAI manager.

    Returns:
        Whatever was last written to
        ``context['inst']['openai.manager.OpenAIInteract']``.
    """
    # ``with`` guarantees the lock is released even if the lookup raises.
    with context_lock:
        return context['inst']['openai.manager.OpenAIInteract']
def set_qqbot_manager(inst: qqbot_mgr.QQBotManager):
    """Store ``inst`` as the shared QQ bot manager.

    The value is written to
    ``context['inst']['qqbot.manager.QQBotManager']``.

    Args:
        inst: The QQ bot manager instance to publish.
    """
    # ``with`` releases the lock even if the nested lookup/write raises.
    with context_lock:
        context['inst']['qqbot.manager.QQBotManager'] = inst
def get_qqbot_manager() -> qqbot_mgr.QQBotManager:
    """Return the shared QQ bot manager.

    Returns:
        Whatever was last written to
        ``context['inst']['qqbot.manager.QQBotManager']``.
    """
    # ``with`` guarantees the lock is released even if the lookup raises.
    with context_lock:
        return context['inst']['qqbot.manager.QQBotManager']
def set_plugin_host(inst: plugin_host.PluginHost):
    """Store ``inst`` in the ``'plugin_host'`` slot of the shared context.

    Args:
        inst: The plugin host instance to publish.
    """
    # ``with`` releases the lock even if the write raises; a bare
    # acquire()/release() pair would leave it held forever in that case.
    with context_lock:
        context['plugin_host'] = inst
def get_plugin_host() -> plugin_host.PluginHost:
    """Return the object stored in the ``'plugin_host'`` context slot.

    Returns:
        Whatever was last written to ``context['plugin_host']``.
    """
    # ``with`` guarantees the lock is released even if the lookup raises.
    with context_lock:
        return context['plugin_host']
def set_thread_ctl(inst: threadctl.ThreadCtl):
    """Store ``inst`` in the ``'pool_ctl'`` slot of the shared context.

    Args:
        inst: The thread controller instance to publish.
    """
    # ``with`` releases the lock even if the write raises; a bare
    # acquire()/release() pair would leave it held forever in that case.
    with context_lock:
        context['pool_ctl'] = inst
def get_thread_ctl() -> threadctl.ThreadCtl:
    """Return the object stored in the ``'pool_ctl'`` context slot.

    Returns:
        Whatever was last written to ``context['pool_ctl']``.
    """
    # ``with`` guarantees the lock is released even if the lookup raises.
    with context_lock:
        return context['pool_ctl']